dorogan committed · commit 45d03f9 · 1 parent: 5ea5121

Update: new semantic search logic was provided, requirements.txt file was fixed

Files changed:
- app.py (+6, -1)
- requirements.txt (+2, -1)
- semantic_search.py (+89, -167)
app.py CHANGED

@@ -5,7 +5,6 @@ from semantic_search import SemanticSearch
 from datetime import datetime
 
 search = SemanticSearch()
-search.test_search()
 
 app = Flask(__name__)
 app.config['JSON_AS_ASCII'] = False
@@ -20,6 +19,7 @@ if not os.path.exists(LOGS_BASE_PATH):
 # Check if logs are enabled
 ENABLE_LOGS = os.getenv("ENABLE_LOGS", "0") == "1"
 
+
 def log_query_result(query, top, request_id, result):
     if not ENABLE_LOGS:
         return
@@ -38,10 +38,12 @@ def log_query_result(query, top, request_id, result):
     with open(log_file_path, 'w') as log_file:
         json.dump(log_data, log_file, indent=2)
 
+
 @app.route('/health', methods=['GET'])
 def health():
     return jsonify({"status": "ok"})
 
+
 @app.route('/search', methods=['POST'])
 def search_route():
     data = request.get_json()
@@ -56,6 +58,7 @@ def search_route():
 
     return jsonify(result)
 
+
 @app.route('/read_logs', methods=['GET'])
 def read_logs():
     logs = []
@@ -66,6 +69,7 @@ def read_logs():
         logs.append(log_data)
     return jsonify(logs)
 
+
 @app.route('/analyze_logs', methods=['GET'])
 def analyze_logs():
     logs_by_query_top = {}
@@ -91,5 +95,6 @@ def analyze_logs():
 
     return jsonify(invalid_logs)
 
+
 if __name__ == '__main__':
     app.run(debug=False, host='0.0.0.0')
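The only behavioural change in app.py is that the module-level search.test_search() call is gone, so the Space no longer runs the validation search at import time; the remaining edits just add blank lines between route handlers. For orientation, here is a minimal sketch of how the /search route presumably forwards a request to SemanticSearch.search(); the route body between request.get_json() and return jsonify(result) is not part of the hunks above, so the payload keys and the result shape are assumptions:

    # Hypothetical wiring of the /search route; only the names visible in the
    # diff (search, search_route, data, result) are taken from the source.
    from flask import Flask, request, jsonify
    from semantic_search import SemanticSearch

    search = SemanticSearch()          # no self-test at import time any more

    app = Flask(__name__)

    @app.route('/search', methods=['POST'])
    def search_route():
        data = request.get_json()
        # "query" and "top" keys are assumptions about the request schema
        preds, docs, scores = search.search(data["query"], top=data.get("top", 15))
        result = {"predictions": preds, "teasers": docs, "scores": scores}
        return jsonify(result)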
requirements.txt CHANGED

@@ -8,4 +8,5 @@ transformers==4.29.2
 # sentencepiece==0.1.99
 # six==1.16.0
 # tokenizers==0.13.3
-flask==3.0.0
+flask==3.0.0
+datasets
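The flask line is rewritten with identical content (most likely a line-ending fix, per the commit message), and datasets is the genuinely new dependency: semantic_search.py below now loads its corpus with Dataset.load_from_disk() instead of unpickling JSON/pickle files. As a rough guide, here is a minimal sketch of a corpus record shaped the way the new load_data() and get_most_relevant_teaser() expect. Only the field names (doc_name, doc_embedding, chunks_embeddings with summary_text and embedding) come from the diff; the values are invented placeholders, and the embedding width has to match the encoder's hidden size:

    # Hypothetical corpus record; field names are from the diff, values are dummies.
    from datasets import Dataset

    embedding_dim = 768  # assumption: must equal the model's hidden_size

    records = [{
        "doc_name": "Письмо Минфина ...",                  # identifier returned by search()
        "doc_embedding": [0.0] * embedding_dim,            # document-level embedding
        "chunks_embeddings": [                              # per-chunk teasers
            {"summary_text": "teaser text", "embedding": [0.0] * embedding_dim},
        ],
    }]

    Dataset.from_list(records).save_to_disk(
        "legal_info_search_data/court_dataset_chunk_200_correct_tokenizer_for_develop")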
semantic_search.py CHANGED
(Lines that the diff viewer collapsed or truncated are shown below as "...".)

@@ -1,54 +1,58 @@
 import os
-import json
+# import json
 import torch
-import pickle
+# import pickle
 import numpy as np
 import faiss
-
+from datasets import Dataset as dataset
 from transformers import AutoTokenizer, AutoModel
-from legal_info_search_utils.utils import get_subsets_for_db, get_subsets_for_qa
-from legal_info_search_utils.utils import filter_db_data_types, filter_qa_data_types
-from legal_info_search_utils.utils import db_tokenization, qa_tokenization
-from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
-from legal_info_search_utils.utils import print_metrics, get_final_metrics
-from legal_info_search_utils.utils import court_text_splitter
+# from legal_info_search_utils.utils import get_subsets_for_db, get_subsets_for_qa
+# from legal_info_search_utils.utils import filter_db_data_types, filter_qa_data_types
+# from legal_info_search_utils.utils import db_tokenization, qa_tokenization
+# from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
+# from legal_info_search_utils.utils import print_metrics, get_final_metrics
+# from legal_info_search_utils.utils import court_text_splitter
 from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
-...
+
+# from legal_info_search_utils.metrics import calculate_metrics_at_k
 
 
 global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
-global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
+global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
                                    "legal_info_search_model/20240202_204910_ep8/")
 
 # размеченные консультации (annotated consultations)
-data_path_consult = os.environ.get("DATA_PATH_CONSULT",
-    ...
+# data_path_consult = os.environ.get("DATA_PATH_CONSULT",
+#                                    global_data_path + "data_jsons_20240202.pkl")
+
+data_path_consult = os.environ.get("DATA_PATH_CONSULT",
+    global_data_path + "court_dataset_chunk_200_correct_tokenizer_for_develop")
 
 # id консультаций, побитые на train / valid / test (consultation ids split into train / valid / test)
-data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
-    ...
+# data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
+#                                        global_data_path + "data_ids.json")
 
 # предобработанные внутренние документы (preprocessed internal documents)
-data_path_internal_docs = os.environ.get("DATA_PATH_INTERNAL_DOCS",
-    ...
+# data_path_internal_docs = os.environ.get("DATA_PATH_INTERNAL_DOCS",
+#                                          global_data_path + "internal_docs.json")
 
 # состав БД (DB composition)
 # $ export DB_SUBSETS='["train", "valid", "test"]'
-db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"])
+# db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"])
 
 # Отбор типов документов. В списке указать те, которые нужно оставить в БД. (Document type selection: keep only the listed types in the DB.)
 # $ export DB_DATA_TYPES='["НКРФ", "ГКРФ", "ТКРФ"]'
 db_data_types = os.environ.get("DB_DATA_TYPES", [
-    'НКРФ',
-    'ГКРФ',
-    'ТКРФ',
-    'Федеральный закон',
+    # 'НКРФ',
+    # 'ГКРФ',
+    # 'ТКРФ',
+    # 'Федеральный закон',
     'Письмо Минфина',
     'Письмо ФНС',
-    'Приказ ФНС',
-    'Постановление Правительства',
+    # 'Приказ ФНС',
+    # 'Постановление Правительства',
     'Судебный документ',
-    'Внутренний документ'
+    # 'Внутренний документ'
 ])
 
 device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
@@ -57,67 +61,22 @@ device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
 hf_token = os.environ.get("HF_TOKEN", "")
 hf_model_name = os.environ.get("HF_MODEL_NAME", "")
 
+
 class SemanticSearch:
-    def __init__(self,
-                 ...
+    def __init__(self,
+                 index_type="IndexFlatIP",
+                 do_embedding_norm=True,
+                 do_normalization=True):
         self.device = device
         self.do_embedding_norm = do_embedding_norm
-        self.faiss_batch_size = faiss_batch_size
         self.do_normalization = do_normalization
         self.load_model()
-
+
         indexes = {
             "IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim),
             "IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim)
         }
-        self.index = indexes[index_type]
-        self.load_data()
-        self.preproces_data()
-        self.test_search()
 
-    def load_data(self):
-        with open(data_path_consult, "rb") as f:
-            all_docs = pickle.load(f)
-
-        with open(data_path_consult_ids, "r", encoding="utf-8") as f:
-            data_ids = json.load(f)
-
-        with open(data_path_internal_docs, "r", encoding="utf-8") as f:
-            internal_docs = json.load(f)
-
-        db_data = get_subsets_for_db(db_subsets, data_ids, all_docs)
-        filtered_all_docs = filter_qa_data_types(db_data_types, all_docs)
-
-        self.mean_refs_count = self.get_mean_refs_counts(db_data_types, filtered_all_docs)
-        self.mean_refs_count['Внутренний документ'] = 3
-        self.filtered_db_data = filter_db_data_types(db_data_types, db_data)
-        self.filtered_db_data.update(internal_docs)
-        self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs)
-
-    def load_model(self):
-        if hf_token and hf_model_name:
-            self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=True)
-            self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=True).to(self.device)
-        else:
-            self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
-            self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
-
-        self.max_len = self.tokenizer.max_len_single_sentence
-        self.embedding_dim = self.model.config.hidden_size
-
-    def preproces_data(self):
-        index_keys, index_toks = db_tokenization(self.filtered_db_data, self.tokenizer)
-        val_questions, val_refs = qa_tokenization(self.all_docs_qa, self.tokenizer)
-        docs_embeds_faiss, questions_embeds_faiss = extract_text_embeddings(index_toks,
-            val_questions, self.model, self.do_normalization, self.faiss_batch_size)
-
-        self.index.add(docs_embeds_faiss)
-        self.index_keys = index_keys
-        self.index_toks = index_toks
-        self.val_questions = val_questions
-        self.val_refs = val_refs
-        self.docs_embeds_faiss = docs_embeds_faiss
-        self.questions_embeds_faiss = questions_embeds_faiss
         self.optimal_params = {
             'НКРФ': {'thresh': 0.58421, 'sim_factor': 0.89474},
             'ГКРФ': {'thresh': 0.64737, 'sim_factor': 0.89474},
@@ -130,61 +89,33 @@ class SemanticSearch:
             'Судебный документ': {'thresh': 0.67895, 'sim_factor': 0.89474},
             'Внутренний документ': {'thresh': 0.55263, 'sim_factor': 0.84211}
         }
-        self.
-        ...
-            'Письмо ФНС': 'Письмо ФНС',
-            'Приказ ФНС': 'Приказ ФНС',
-            'Постановление Правительства': 'Пост. Прав.',
-            'Внутренний документ': 'Внутр. док.'
-        }
-
-    def test_search(self):
-        topk = len(self.filtered_db_data)
-        pred_raw = {}
-        true = {}
-        all_distances = []
-        for idx, (q_embed, refs) in enumerate(zip(self.questions_embeds_faiss,
-                                                  self.val_refs.values())):
-            distances, indices = self.index.search(np.expand_dims(q_embed, 0), topk)
-            pred_raw[idx] = [self.index_keys[x] for x in indices[0]]
-            true[idx] = list(refs)
-            all_distances.append(distances)
 
-        ...
-        filtered_pred = filter_ref_parts(pred, filter_parts)
-        filtered_true = filter_ref_parts(true, filter_parts)
 
-        ...
-            'dynamic_topk': True,
-            'skip_empty_trues': False,
-            'skip_empty_preds': False
-        }
-        metrics = get_final_metrics(filtered_pred, filtered_true,
-                                    self.ref_categories.keys(), [0],
-                                    metrics_func=calculate_metrics_at_k,
-                                    metrics_func_params=metrics_func_params)
-
-        print_metrics(metrics, self.ref_categories)
-
+        self.index_type = index_type
+        self.index_docs = indexes[self.index_type]
+        self.load_data()
+        self.docs_embeddings = [torch.unsqueeze(torch.Tensor(x['doc_embedding']), 0) for x in self.all_docs_info]
+        self.docs_embeddings = torch.cat(self.docs_embeddings, dim=0)
+        self.index_docs.add(self.docs_embeddings)
+        # self.preproces_data()
+        # self.test_search()
 
+    def load_data(self):
+        self.all_docs_info = dataset.load_from_disk(data_path_consult).to_list()
+        self.docs_names = [doc['doc_name'] for doc in self.all_docs_info]
+        self.mean_refs_count = {'Письмо Минфина': 3,
+                                'Письмо ФНС': 2,
+                                'Судебный документ': 3}
 
+    def load_model(self):
+        if hf_token and hf_model_name:
+            self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=True)
+            self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=True).to(self.device)
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
+            self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
+
+        self.max_len = self.tokenizer.max_len_single_sentence
+        self.embedding_dim = self.model.config.hidden_size
 
     def search_results_filtering(self, pred, dists):
         all_ctg_preds = []
         all_scores = []
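A note on the indexing step added above: self.index_docs.add(self.docs_embeddings) passes a torch.Tensor straight to FAISS, and the teaser index built later does the same. The stock FAISS Python bindings are documented to take contiguous float32 NumPy arrays; whether a raw torch tensor is accepted depends on the faiss version, and torch tensors are only guaranteed to work after importing faiss.contrib.torch_utils. Converting explicitly removes the ambiguity. A minimal, self-contained sketch of the document-level index with that conversion (names mirror the diff, the data is dummy, and embedding_dim = 768 is an assumption standing in for self.embedding_dim):

    import numpy as np
    import torch
    import faiss

    embedding_dim = 768
    # dummy records carrying the two fields the new __init__/load_data code reads
    all_docs_info = [{"doc_name": f"doc_{i}",
                      "doc_embedding": np.random.rand(embedding_dim).tolist()}
                     for i in range(4)]

    docs_embeddings = torch.cat(
        [torch.unsqueeze(torch.Tensor(x["doc_embedding"]), 0) for x in all_docs_info], dim=0)

    index_docs = faiss.IndexFlatIP(embedding_dim)
    index_docs.add(np.ascontiguousarray(docs_embeddings.numpy(), dtype="float32"))

    # IndexFlatIP returns inner products (larger means more similar), which is what
    # the `dist > ctg_thresh` filtering and the reverse-sorted scores rely on.
    query_embeds = np.random.rand(1, embedding_dim).astype("float32")
    distances, indices = index_docs.search(query_embeds, len(all_docs_info))
    pred = [all_docs_info[i]["doc_name"] for i in indices[0]]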
@@ -192,7 +123,7 @@
             ctg_thresh = self.optimal_params[ctg]['thresh']
             ctg_sim_factor = self.optimal_params[ctg]['sim_factor']
 
-            ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
+            ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
                          if ctg in ref and dist > ctg_thresh]
 
             sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
@@ -222,48 +153,39 @@
 
         return sorted_preds, sorted_scores
 
-    def
-        ...
-        for tp in db_data_types:
-            all_tp_refs = []
-            for doc in data.values():
-                tp_refs_len = len([ref for ref in doc['added_refs'] if tp in ref])
-                if tp_refs_len:
-                    all_tp_refs.append(tp_refs_len)
-
-            mean_refs_count[tp] = np.mean(all_tp_refs)
-
-        for k, v in mean_refs_count.items():
-            mean_refs_count[k] = int(v + 1)
-
-        return mean_refs_count
+    def get_most_relevant_teaser(self,
+                                 question: str = None,
+                                 doc_index: int = None):
+        teaser_indexes = {
+            "IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim),
+            "IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim)
+        }
+        teasers_index = teaser_indexes[self.index_type]
+        question_tokens = query_tokenization(question, self.tokenizer)
+        question_embedding = query_embed_extraction(question_tokens, self.model,
+                                                    self.do_normalization)
+        # question_embedding = self.custom_embedder.create_question_embedding(question)
+        teasers_texts = [teaser['summary_text'] for teaser in self.all_docs_info[doc_index]['chunks_embeddings']]
+        teasers_embeddings = [torch.unsqueeze(torch.Tensor(teaser['embedding']), 0) for teaser in
+                              self.all_docs_info[doc_index]['chunks_embeddings']]
+        teasers_embeddings = torch.cat(teasers_embeddings, 0)
+        teasers_index.add(teasers_embeddings)
+        distances, indices = teasers_index.search(question_embedding, 10)
+        most_relevant_teaser = teasers_texts[indices[0][0]]
+        return most_relevant_teaser
 
     def search(self, query, top=15):
         query_tokens = query_tokenization(query, self.tokenizer)
-        query_embeds = query_embed_extraction(query_tokens, self.model,
+        query_embeds = query_embed_extraction(query_tokens, self.model,
                                               self.do_normalization)
-        distances, indices = self.
-        pred = [self.
+        distances, indices = self.index_docs.search(query_embeds, len(self.all_docs_info))
+        pred = [self.all_docs_info[x]['doc_name'] for x in indices[0]]
         preds, scores = self.search_results_filtering(pred, distances[0])
-        docs = [self.
-        ...
+        # docs = [self.all_docs_info[x][ref] for ref in preds]
+        docs = []
+        for ref in preds:
+            doc_index = self.docs_names.index(ref)
+            most_relevant_teaser = self.get_most_relevant_teaser(question=query,
+                                                                 doc_index=doc_index)
+            docs.append(most_relevant_teaser)
         return preds[:top], docs[:top], scores[:top]
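Taken together, the new query path is: embed the query, search the single document-level FAISS index over all docs, keep hits that clear the per-category thresholds in optimal_params, and then, for each surviving document, build a small FAISS index over that document's chunk embeddings and return the best chunk's summary_text as the teaser. Two side notes, not part of the commit: get_most_relevant_teaser() rebuilds a per-document teaser index on every call, which is cheap for small chunk counts but could be precomputed in load_data() if latency matters, and it asks FAISS for 10 neighbours even though only indices[0][0] is used, so a k of 1 would be equivalent. A minimal usage sketch of the class as it stands after this commit, assuming the dataset described earlier is on disk and the model paths resolve:

    # Hypothetical usage; the query string is an invented example.
    from semantic_search import SemanticSearch

    searcher = SemanticSearch(index_type="IndexFlatIP")
    preds, docs, scores = searcher.search("Какой срок уплаты НДС?", top=5)
    for name, teaser, score in zip(preds, docs, scores):
        print(score, name, teaser[:80])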