Spaces:
Runtime error
Runtime error
updated search
Browse files- legal_info_search_utils/utils.py +3 -22
- semantic_search.py +40 -8
legal_info_search_utils/utils.py
CHANGED
@@ -12,6 +12,7 @@ from torch.cuda.amp import autocast
|
|
12 |
|
13 |
all_types_but_courts = ['НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина',
|
14 |
'Письмо ФНС', 'Приказ ФНС', 'Постановление Правительства']
|
|
|
15 |
|
16 |
|
17 |
class FaissDocsDataset(Dataset):
|
@@ -172,26 +173,6 @@ def extract_text_embeddings(index_toks, val_questions, model,
|
|
172 |
return docs_embeds_faiss.numpy(), questions_embeds_faiss.numpy()
|
173 |
|
174 |
|
175 |
-
def run_semantic_search(index, model, tokenizer, filtered_db_data, all_docs_qa,
                        do_normalization=True, faiss_batch_size=16, topk=100):
    """Embed documents and questions, fill ``index`` and query it per question.

    The database entries and the QA set are tokenized, both are embedded with
    ``extract_text_embeddings``, the document embeddings are added to
    ``index`` and every question embedding is then searched for its ``topk``
    nearest documents.

    Args:
        index: object exposing ``add`` and ``search`` (presumably a FAISS
            index -- confirm against the caller); it is modified by ``add``.
        model: passed through to ``extract_text_embeddings``.
        tokenizer: passed through to the tokenization helpers.
        filtered_db_data: document collection given to ``db_tokenization``.
        all_docs_qa: question/answer collection given to ``qa_tokenization``.
        do_normalization: forwarded to ``extract_text_embeddings``.
        faiss_batch_size: forwarded to ``extract_text_embeddings``.
        topk: number of neighbours requested from ``index.search``.

    Returns:
        A tuple ``(pred, true, all_distances)``: ``pred`` maps the question
        position to the list of retrieved document keys, ``true`` maps the
        same position to the list of reference keys, and ``all_distances``
        holds the first value returned by ``index.search`` for each question.
    """
    doc_keys, doc_tokens = db_tokenization(filtered_db_data, tokenizer)
    question_tokens, question_refs = qa_tokenization(all_docs_qa, tokenizer)

    doc_vectors, question_vectors = extract_text_embeddings(
        doc_tokens, question_tokens, model, do_normalization, faiss_batch_size)
    index.add(doc_vectors)

    pred, true, all_distances = {}, {}, []
    paired = zip(question_vectors, question_refs.values())
    for position, (vector, references) in enumerate(paired):
        # index.search is called with a batch of one query vector.
        query_batch = np.expand_dims(vector, 0)
        distances, neighbours = index.search(query_batch, topk)

        retrieved = []
        for doc_position in neighbours[0]:
            retrieved.append(doc_keys[doc_position])

        pred[position] = retrieved
        true[position] = list(references)
        all_distances.append(distances)

    return pred, true, all_distances
|
193 |
-
|
194 |
-
|
195 |
def filter_ref_parts(ref_dict, filter_parts):
|
196 |
filtered_dict = {}
|
197 |
for k, refs in ref_dict.items():
|
@@ -203,13 +184,13 @@ def filter_ref_parts(ref_dict, filter_parts):
|
|
203 |
|
204 |
|
205 |
def get_final_metrics(pred, true, categories, top_k_values,
|
206 |
-
metrics_func,
|
207 |
metrics = {}
|
208 |
for top_k in top_k_values:
|
209 |
ctg_metrics = {}
|
210 |
for ctg in categories:
|
211 |
ctg_pred, ctg_true = get_exact_ctg_data(pred, true, ctg)
|
212 |
-
metrics_at_k = metrics_func(ctg_pred, ctg_true, top_k,
|
213 |
for mk in metrics_at_k.keys():
|
214 |
metrics_at_k[mk] = round(metrics_at_k[mk] * 100, 6)
|
215 |
ctg_metrics[ctg] = metrics_at_k
|
|
|
12 |
|
13 |
all_types_but_courts = ['НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина',
|
14 |
'Письмо ФНС', 'Приказ ФНС', 'Постановление Правительства']
|
15 |
+
court_text_splitter = "Весь текст судебного документа: "
|
16 |
|
17 |
|
18 |
class FaissDocsDataset(Dataset):
|
|
|
173 |
return docs_embeds_faiss.numpy(), questions_embeds_faiss.numpy()
|
174 |
|
175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
def filter_ref_parts(ref_dict, filter_parts):
|
177 |
filtered_dict = {}
|
178 |
for k, refs in ref_dict.items():
|
|
|
184 |
|
185 |
|
186 |
def get_final_metrics(pred, true, categories, top_k_values,
|
187 |
+
metrics_func, metrics_func_params):
|
188 |
metrics = {}
|
189 |
for top_k in top_k_values:
|
190 |
ctg_metrics = {}
|
191 |
for ctg in categories:
|
192 |
ctg_pred, ctg_true = get_exact_ctg_data(pred, true, ctg)
|
193 |
+
metrics_at_k = metrics_func(ctg_pred, ctg_true, top_k, **metrics_func_params)
|
194 |
for mk in metrics_at_k.keys():
|
195 |
metrics_at_k[mk] = round(metrics_at_k[mk] * 100, 6)
|
196 |
ctg_metrics[ctg] = metrics_at_k
|
semantic_search.py
CHANGED
@@ -11,16 +11,18 @@ from legal_info_search_utils.utils import filter_db_data_types, filter_qa_data_t
|
|
11 |
from legal_info_search_utils.utils import db_tokenization, qa_tokenization
|
12 |
from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
|
13 |
from legal_info_search_utils.utils import print_metrics, get_final_metrics
|
|
|
14 |
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
|
15 |
from legal_info_search_utils.metrics import calculate_metrics_at_k
|
16 |
|
17 |
|
18 |
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
|
19 |
-
global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
|
|
|
20 |
|
21 |
# размеченные консультации
|
22 |
data_path_consult = os.environ.get("DATA_PATH_CONSULT",
|
23 |
-
global_data_path + "
|
24 |
|
25 |
# id консультаций, побитые на train / valid / test
|
26 |
data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
|
@@ -40,7 +42,8 @@ db_data_types = os.environ.get("DB_DATA_TYPES", [
|
|
40 |
'Письмо Минфина',
|
41 |
'Письмо ФНС',
|
42 |
'Приказ ФНС',
|
43 |
-
'Постановление Правительства'
|
|
|
44 |
])
|
45 |
|
46 |
device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
|
@@ -120,7 +123,9 @@ class SemanticSearch:
|
|
120 |
'Приказ ФНС': {
|
121 |
'thresh': 0.806896, 'sim_factor': 0.5, 'diff_n': 0},
|
122 |
'Постановление Правительства': {
|
123 |
-
'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0}
|
|
|
|
|
124 |
}
|
125 |
self.ref_categories = {
|
126 |
'all': 'all',
|
@@ -162,9 +167,14 @@ class SemanticSearch:
|
|
162 |
filtered_pred = filter_ref_parts(pred, filter_parts)
|
163 |
filtered_true = filter_ref_parts(true, filter_parts)
|
164 |
|
|
|
|
|
|
|
|
|
165 |
metrics = get_final_metrics(filtered_pred, filtered_true,
|
166 |
self.ref_categories.keys(), [0],
|
167 |
-
metrics_func=calculate_metrics_at_k,
|
|
|
168 |
|
169 |
print_metrics(metrics, self.ref_categories)
|
170 |
|
@@ -176,9 +186,13 @@ class SemanticSearch:
|
|
176 |
ctg_thresh = self.optimal_params[ctg]['thresh']
|
177 |
ctg_sim_factor = self.optimal_params[ctg]['sim_factor']
|
178 |
ctg_diff_n = self.optimal_params[ctg]['diff_n']
|
179 |
-
|
180 |
-
|
181 |
-
|
|
|
|
|
|
|
|
|
182 |
|
183 |
sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
|
184 |
sorted_preds = [x[0] for x in sorted_pd]
|
@@ -216,6 +230,23 @@ class SemanticSearch:
|
|
216 |
|
217 |
return sorted_preds, sorted_scores
|
218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
def search(self, query, top=10):
|
220 |
query_tokens = query_tokenization(query, self.tokenizer)
|
221 |
query_embeds = query_embed_extraction(query_tokens, self.model,
|
@@ -224,5 +255,6 @@ class SemanticSearch:
|
|
224 |
pred = [self.index_keys[x] for x in indices[0]]
|
225 |
preds, scores = self.search_results_filtering(pred, distances[0])
|
226 |
docs = [self.filtered_db_data[ref] for ref in preds]
|
|
|
227 |
|
228 |
return preds[:top], docs[:top], scores[:top]
|
|
|
11 |
from legal_info_search_utils.utils import db_tokenization, qa_tokenization
|
12 |
from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
|
13 |
from legal_info_search_utils.utils import print_metrics, get_final_metrics
|
14 |
+
from legal_info_search_utils.utils import all_types_but_courts, court_text_splitter
|
15 |
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
|
16 |
from legal_info_search_utils.metrics import calculate_metrics_at_k
|
17 |
|
18 |
|
19 |
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
|
20 |
+
global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
|
21 |
+
"legal_info_search_model/20240120_122822_ep6/")
|
22 |
|
23 |
# размеченные консультации
|
24 |
data_path_consult = os.environ.get("DATA_PATH_CONSULT",
|
25 |
+
global_data_path + "data_jsons_20240131.pkl")
|
26 |
|
27 |
# id консультаций, побитые на train / valid / test
|
28 |
data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
|
|
|
42 |
'Письмо Минфина',
|
43 |
'Письмо ФНС',
|
44 |
'Приказ ФНС',
|
45 |
+
'Постановление Правительства',
|
46 |
+
'Суды'
|
47 |
])
|
48 |
|
49 |
device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
123 |
'Приказ ФНС': {
|
124 |
'thresh': 0.806896, 'sim_factor': 0.5, 'diff_n': 0},
|
125 |
'Постановление Правительства': {
|
126 |
+
'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0},
|
127 |
+
'Суды': {
|
128 |
+
'thresh': 0.846153, 'sim_factor': 0.939230,'diff_n': 0}
|
129 |
}
|
130 |
self.ref_categories = {
|
131 |
'all': 'all',
|
|
|
167 |
filtered_pred = filter_ref_parts(pred, filter_parts)
|
168 |
filtered_true = filter_ref_parts(true, filter_parts)
|
169 |
|
170 |
+
metrics_func_params = {
|
171 |
+
# 'compensate_div_0': True,
|
172 |
+
'dynamic_topk': True
|
173 |
+
}
|
174 |
metrics = get_final_metrics(filtered_pred, filtered_true,
|
175 |
self.ref_categories.keys(), [0],
|
176 |
+
metrics_func=calculate_metrics_at_k,
|
177 |
+
metrics_func_params=metrics_func_params)
|
178 |
|
179 |
print_metrics(metrics, self.ref_categories)
|
180 |
|
|
|
186 |
ctg_thresh = self.optimal_params[ctg]['thresh']
|
187 |
ctg_sim_factor = self.optimal_params[ctg]['sim_factor']
|
188 |
ctg_diff_n = self.optimal_params[ctg]['diff_n']
|
189 |
+
|
190 |
+
if ctg == 'Суды':
|
191 |
+
ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists) if dist > ctg_thresh
|
192 |
+
and not any([True for type_ in all_types_but_courts if type_ in ref])]
|
193 |
+
else:
|
194 |
+
ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
|
195 |
+
if ctg in ref and dist > ctg_thresh]
|
196 |
|
197 |
sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
|
198 |
sorted_preds = [x[0] for x in sorted_pd]
|
|
|
230 |
|
231 |
return sorted_preds, sorted_scores
|
232 |
|
233 |
+
def court_docs_shrinking(self, preds, docs):
    """Cut court document texts at the ``court_text_splitter`` marker.

    A reference is treated as a court document when its name contains none
    of the entries of ``all_types_but_courts``. For such a reference, if the
    text contains ``court_text_splitter``, only the part before the first
    occurrence of the marker is kept, with surrounding whitespace stripped.
    Every other text is passed through unchanged.

    Args:
        preds: reference names, paired positionally with ``docs``.
        docs: document texts.

    Returns:
        A tuple ``(new_preds, new_docs)`` of new lists: the reference names
        in their original order and the (possibly shortened) texts.
    """
    kept_names = []
    kept_texts = []

    for name, text in zip(preds, docs):
        names_other_type = any(doc_type in name for doc_type in all_types_but_courts)
        marker_present = court_text_splitter in text

        if names_other_type or not marker_present:
            result_text = text
        else:
            # Keep only what precedes the first marker occurrence.
            result_text = text.split(court_text_splitter)[0].strip()

        kept_names.append(name)
        kept_texts.append(result_text)

    return kept_names, kept_texts
|
249 |
+
|
250 |
def search(self, query, top=10):
|
251 |
query_tokens = query_tokenization(query, self.tokenizer)
|
252 |
query_embeds = query_embed_extraction(query_tokens, self.model,
|
|
|
255 |
pred = [self.index_keys[x] for x in indices[0]]
|
256 |
preds, scores = self.search_results_filtering(pred, distances[0])
|
257 |
docs = [self.filtered_db_data[ref] for ref in preds]
|
258 |
+
preds, docs = self.court_docs_shrinking(preds, docs)
|
259 |
|
260 |
return preds[:top], docs[:top], scores[:top]
|