diff --git a/src/langchain_google_cloud_sql_pg/engine.py b/src/langchain_google_cloud_sql_pg/engine.py index 36850317..da284af9 100644 --- a/src/langchain_google_cloud_sql_pg/engine.py +++ b/src/langchain_google_cloud_sql_pg/engine.py @@ -327,7 +327,7 @@ async def _aexecute(self, query: str, params: Optional[dict] = None) -> None: await conn.commit() async def _aexecute_outside_tx(self, query: str) -> None: - """Execute a SQL query.""" + """Execute a SQL query in a new transaction.""" async with self._engine.connect() as conn: await conn.execute(text("COMMIT")) await conn.execute(text(query)) @@ -343,6 +343,18 @@ async def _afetch( return result_fetch + async def _afetch_with_query_options( + self, query: str, query_options: str + ) -> Sequence[RowMapping]: + """Set temporary database flags and fetch results from a SQL query.""" + async with self._engine.connect() as conn: + await conn.execute(text(query_options)) + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + + return result_fetch + def _execute(self, query: str, params: Optional[dict] = None) -> None: """Execute a SQL query.""" return self._run_as_sync(self._aexecute(query, params)) diff --git a/src/langchain_google_cloud_sql_pg/vectorstore.py b/src/langchain_google_cloud_sql_pg/vectorstore.py index 1c4685e2..a4a0b53a 100644 --- a/src/langchain_google_cloud_sql_pg/vectorstore.py +++ b/src/langchain_google_cloud_sql_pg/vectorstore.py @@ -599,10 +599,12 @@ async def __query_collection( filter = f"WHERE {filter}" if filter else "" stmt = f"SELECT *, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};" if self.index_query_options: - await self.engine._aexecute( - f"SET LOCAL {self.index_query_options.to_string()};" + query_options_stmt = f"SET LOCAL {self.index_query_options.to_string()};" + results = await self.engine._afetch_with_query_options( + stmt, query_options_stmt ) - results = await self.engine._afetch(stmt) + else: + results = await self.engine._afetch(stmt) return results def similarity_search(