Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/langchain_google_cloud_sql_pg/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ async def _aexecute(self, query: str, params: Optional[dict] = None) -> None:
await conn.commit()

async def _aexecute_outside_tx(self, query: str) -> None:
"""Execute a SQL query."""
"""Execute a SQL query in a new transaction."""
async with self._engine.connect() as conn:
await conn.execute(text("COMMIT"))
await conn.execute(text(query))
Expand All @@ -343,6 +343,18 @@ async def _afetch(

return result_fetch

async def _afetch_with_query_options(
self, query: str, query_options: str
) -> Sequence[RowMapping]:
"""Set temporary database flags and fetch results from a SQL query."""
async with self._engine.connect() as conn:
await conn.execute(text(query_options))
result = await conn.execute(text(query))
result_map = result.mappings()
result_fetch = result_map.fetchall()

return result_fetch

def _execute(self, query: str, params: Optional[dict] = None) -> None:
"""Execute a SQL query."""
return self._run_as_sync(self._aexecute(query, params))
Expand Down
8 changes: 5 additions & 3 deletions src/langchain_google_cloud_sql_pg/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,10 +599,12 @@ async def __query_collection(
filter = f"WHERE {filter}" if filter else ""
stmt = f"SELECT *, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};"
if self.index_query_options:
await self.engine._aexecute(
f"SET LOCAL {self.index_query_options.to_string()};"
query_options_stmt = f"SET LOCAL {self.index_query_options.to_string()};"
results = await self.engine._afetch_with_query_options(
stmt, query_options_stmt
)
results = await self.engine._afetch(stmt)
else:
results = await self.engine._afetch(stmt)
return results

def similarity_search(
Expand Down