-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathhybrid_search.py
144 lines (129 loc) · 4.57 KB
/
hybrid_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import time
import pyodbc
import logging
import json
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
from utilities import get_mssql_connection
load_dotenv()
# Sample corpus that is embedded, stored in dbo.hybrid_search_sample and then
# searched with a hybrid (full-text + vector) query.
SAMPLE_SENTENCES = [
    'The dog is barking',
    'The cat is purring',
    'The bear is growling',
    'A bear growling to a cat',
    'A cat purring to a dog',
    'A dog barking to a bear',
    'A bear growling to a dog',
    'A cat purring to a bear',
    'A wolf howling to a bear',
    'A bear growling to a wolf',
]

# Inserts one document. The embedding is sent as a JSON array string and cast
# to VECTOR(384) server side, so the embedding length must be 384.
INSERT_DOCUMENT_SQL = """
    DECLARE @id INT = ?;
    DECLARE @content NVARCHAR(MAX) = ?;
    DECLARE @embedding VECTOR(384) = CAST(? AS VECTOR(384));
    INSERT INTO dbo.hybrid_search_sample (id, content, embedding) VALUES (@id, @content, @embedding);
"""

# Hybrid search: take the top @k full-text hits and the top @k nearest vectors,
# rank each list from 1 (best), then fuse the two rankings with Reciprocal
# Rank Fusion. Parameters, in order: @k, query text, query embedding (JSON).
HYBRID_SEARCH_SQL = """
    DECLARE @k INT = ?;
    DECLARE @q NVARCHAR(4000) = ?;
    DECLARE @e VECTOR(384) = CAST(? AS VECTOR(384));
    WITH keyword_search AS (
        SELECT TOP(@k)
            id,
            -- A higher FREETEXTTABLE rank means a better match, so the best
            -- hit must be ranked first (DESC) to receive rank 1.
            RANK() OVER (ORDER BY ft_rank DESC) AS rank,
            content
        FROM
            (
                SELECT TOP(@k)
                    sd.id,
                    ftt.[RANK] AS ft_rank,
                    sd.content
                FROM
                    dbo.hybrid_search_sample AS sd
                INNER JOIN
                    FREETEXTTABLE(dbo.hybrid_search_sample, *, @q) AS ftt ON sd.id = ftt.[KEY]
                ORDER BY
                    ft_rank DESC
            ) AS t
        ORDER BY
            rank
    ),
    semantic_search AS
    (
        SELECT TOP(@k)
            id,
            RANK() OVER (ORDER BY distance) AS rank,
            content
        FROM
            (
                SELECT TOP(@k)
                    id,
                    VECTOR_DISTANCE('cosine', embedding, @e) AS distance,
                    content
                FROM
                    dbo.hybrid_search_sample
                ORDER BY
                    distance
            ) AS t
        ORDER BY
            rank
    )
    SELECT TOP(@k)
        COALESCE(ss.id, ks.id) AS id,
        COALESCE(1.0 / (@k + ss.rank), 0.0) +
        COALESCE(1.0 / (@k + ks.rank), 0.0) AS score, -- Reciprocal Rank Fusion (RRF)
        COALESCE(ss.content, ks.content) AS content,
        ss.rank AS semantic_rank,
        ks.rank AS keyword_rank
    FROM
        semantic_search ss
    FULL OUTER JOIN
        keyword_search ks ON ss.id = ks.id
    ORDER BY
        score DESC
"""


def _reset_sample_table(conn):
    """Delete every row from dbo.hybrid_search_sample and commit."""
    # Acquire the cursor before the try so a failing conn.cursor() is reported
    # as-is instead of being masked by a NameError in the finally clause.
    cursor = conn.cursor()
    try:
        cursor.execute("DELETE FROM dbo.hybrid_search_sample;")
        cursor.commit()
    finally:
        cursor.close()


def _save_documents(conn, sentences, embeddings):
    """Insert each sentence with its embedding, using its position as the id.

    Args:
        conn: Open database connection.
        sentences: Sequence of document strings.
        embeddings: Sequence of vectors aligned with ``sentences``; each item
            must provide ``tolist()``.
    """
    cursor = conn.cursor()
    try:
        for doc_id, (sentence, embedding) in enumerate(zip(sentences, embeddings)):
            cursor.execute(
                INSERT_DOCUMENT_SQL,
                doc_id,
                sentence,
                json.dumps(embedding.tolist()),
            )
        # Single commit after the loop: all documents are saved or none are.
        cursor.commit()
    finally:
        cursor.close()


def _hybrid_search(conn, query, embedding, k):
    """Run the hybrid search and return the fetched rows.

    Args:
        conn: Open database connection.
        query: Free-text query string.
        embedding: Query vector providing ``tolist()``.
        k: Number of results; also used as the RRF constant in the SQL.

    Returns:
        A list of rows shaped
        ``(id, score, content, semantic_rank, keyword_rank)``, best score
        first. A rank is NULL/None when the document was found by only one of
        the two searches.
    """
    cursor = conn.cursor()
    try:
        cursor.execute(
            HYBRID_SEARCH_SQL,
            k,
            query,
            json.dumps(embedding.tolist()),
        )
        # Materialize the rows before the cursor is closed.
        return cursor.fetchall()
    finally:
        cursor.close()


def main():
    """Embed the sample sentences, store them, and run one hybrid search."""
    print('Initializing sample...')
    model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1', tokenizer_kwargs={'clean_up_tokenization_spaces': True})

    print('Getting embeddings...')
    embeddings = model.encode(SAMPLE_SENTENCES)

    conn = get_mssql_connection()
    try:
        print('Cleaning up the database...')
        _reset_sample_table(conn)

        print('Saving documents and embeddings in the database...')
        _save_documents(conn, SAMPLE_SENTENCES, embeddings)

        print('Waiting a few seconds to let fulltext index sync...')
        time.sleep(3)

        print('Searching for similar documents...')
        print('Getting embeddings...')
        query = 'a growling bear'
        embedding = model.encode(query)

        k = 5
        print(f'Querying database for {k} similar sentences to "{query}"...')
        rows = _hybrid_search(conn, query, embedding, k)
        for pos, row in enumerate(rows):
            print(f'[{pos}] RRF score: {row[1]:0.4} (Semantic Rank: {row[3]}, Keyword Rank: {row[4]})\tDocument: "{row[2]}", Id: {row[0]}')
    finally:
        # NOTE(review): assumes get_mssql_connection() returns a DB-API
        # connection (pyodbc) exposing close() - confirm in utilities.
        conn.close()

    print("Done.")


if __name__ == '__main__':
    main()