muryshev committed on
Commit
9b1050d
1 Parent(s): 3f1b4d8

updated search

legal_info_search_utils/utils.py CHANGED
@@ -12,6 +12,7 @@ from torch.cuda.amp import autocast
 
 all_types_but_courts = ['НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина',
                         'Письмо ФНС', 'Приказ ФНС', 'Постановление Правительства']
+court_text_splitter = "Весь текст судебного документа: "
 
 
 class FaissDocsDataset(Dataset):
@@ -172,26 +173,6 @@ def extract_text_embeddings(index_toks, val_questions, model,
     return docs_embeds_faiss.numpy(), questions_embeds_faiss.numpy()
 
 
-def run_semantic_search(index, model, tokenizer, filtered_db_data, all_docs_qa,
-                        do_normalization=True, faiss_batch_size=16, topk=100):
-    index_keys, index_toks = db_tokenization(filtered_db_data, tokenizer)
-    val_questions, val_refs = qa_tokenization(all_docs_qa, tokenizer)
-    docs_embeds_faiss, questions_embeds_faiss = extract_text_embeddings(index_toks,
-        val_questions, model, do_normalization, faiss_batch_size)
-    index.add(docs_embeds_faiss)
-
-    pred = {}
-    true = {}
-    all_distances = []
-    for idx, (q_embed, refs) in enumerate(zip(questions_embeds_faiss, val_refs.values())):
-        distances, indices = index.search(np.expand_dims(q_embed, 0), topk)
-        pred[idx] = [index_keys[x] for x in indices[0]]
-        true[idx] = list(refs)
-        all_distances.append(distances)
-
-    return pred, true, all_distances
-
-
 def filter_ref_parts(ref_dict, filter_parts):
     filtered_dict = {}
     for k, refs in ref_dict.items():
@@ -203,13 +184,13 @@ def filter_ref_parts(ref_dict, filter_parts):
 
 
 def get_final_metrics(pred, true, categories, top_k_values,
-                      metrics_func, dynamic_topk=False):
+                      metrics_func, metrics_func_params):
     metrics = {}
     for top_k in top_k_values:
         ctg_metrics = {}
         for ctg in categories:
             ctg_pred, ctg_true = get_exact_ctg_data(pred, true, ctg)
-            metrics_at_k = metrics_func(ctg_pred, ctg_true, top_k, dynamic_topk)
+            metrics_at_k = metrics_func(ctg_pred, ctg_true, top_k, **metrics_func_params)
             for mk in metrics_at_k.keys():
                 metrics_at_k[mk] = round(metrics_at_k[mk] * 100, 6)
             ctg_metrics[ctg] = metrics_at_k
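
The get_final_metrics change replaces the hard-coded dynamic_topk flag with a metrics_func_params dict that is unpacked into whatever metrics_func is passed, so new metric options no longer require touching the signature. A minimal self-contained sketch of that pass-through pattern; toy_recall_at_k and its sample data are illustrative stand-ins, not repository code:

    def toy_recall_at_k(pred, true, top_k, dynamic_topk=False):
        # Stand-in metric: with dynamic_topk, cut predictions at the
        # number of true references instead of a fixed top_k.
        hits, total = 0, 0
        for idx, refs in true.items():
            k = len(refs) if dynamic_topk else top_k
            hits += len(set(pred[idx][:k]) & set(refs))
            total += len(refs)
        return {'recall': hits / total if total else 0.0}

    pred = {0: ['НКРФ ст.146', 'ГКРФ ст.310'], 1: ['Дело № А40-1/2023']}
    true = {0: ['НКРФ ст.146'], 1: ['Дело № А40-1/2023']}

    # Extra keyword arguments now travel through a single dict,
    # mirroring how get_final_metrics forwards **metrics_func_params:
    metrics_func_params = {'dynamic_topk': True}
    print(toy_recall_at_k(pred, true, 0, **metrics_func_params))  # {'recall': 1.0}
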
semantic_search.py CHANGED
@@ -11,16 +11,18 @@ from legal_info_search_utils.utils import filter_db_data_types, filter_qa_data_t
 from legal_info_search_utils.utils import db_tokenization, qa_tokenization
 from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
 from legal_info_search_utils.utils import print_metrics, get_final_metrics
+from legal_info_search_utils.utils import all_types_but_courts, court_text_splitter
 from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
 from legal_info_search_utils.metrics import calculate_metrics_at_k
 
 
 global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
-global_model_path = os.environ.get("GLOBAL_MODEL_PATH", "e5_large_rus_finetuned_20240120_122822_ep6")
+global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
+                                   "legal_info_search_model/20240120_122822_ep6/")
 
 # annotated consultations
 data_path_consult = os.environ.get("DATA_PATH_CONSULT",
-                                   global_data_path + "data_jsons_20240119.pkl")
+                                   global_data_path + "data_jsons_20240131.pkl")
 
 # consultation ids, split into train / valid / test
 data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
@@ -40,7 +42,8 @@ db_data_types = os.environ.get("DB_DATA_TYPES", [
     'Письмо Минфина',
     'Письмо ФНС',
     'Приказ ФНС',
-    'Постановление Правительства'
+    'Постановление Правительства',
+    'Суды'
 ])
 
 device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
@@ -120,7 +123,9 @@ class SemanticSearch:
             'Приказ ФНС': {
                 'thresh': 0.806896, 'sim_factor': 0.5, 'diff_n': 0},
             'Постановление Правительства': {
-                'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0}
+                'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0},
+            'Суды': {
+                'thresh': 0.846153, 'sim_factor': 0.939230, 'diff_n': 0}
         }
         self.ref_categories = {
             'all': 'all',
@@ -162,9 +167,14 @@ class SemanticSearch:
         filtered_pred = filter_ref_parts(pred, filter_parts)
         filtered_true = filter_ref_parts(true, filter_parts)
 
+        metrics_func_params = {
+            # 'compensate_div_0': True,
+            'dynamic_topk': True
+        }
         metrics = get_final_metrics(filtered_pred, filtered_true,
                                     self.ref_categories.keys(), [0],
-                                    metrics_func=calculate_metrics_at_k, dynamic_topk=True)
+                                    metrics_func=calculate_metrics_at_k,
+                                    metrics_func_params=metrics_func_params)
 
         print_metrics(metrics, self.ref_categories)
 
@@ -176,9 +186,13 @@ class SemanticSearch:
             ctg_thresh = self.optimal_params[ctg]['thresh']
             ctg_sim_factor = self.optimal_params[ctg]['sim_factor']
             ctg_diff_n = self.optimal_params[ctg]['diff_n']
-
-            ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
-                         if ctg in ref and dist > ctg_thresh]
+
+            if ctg == 'Суды':
+                ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists) if dist > ctg_thresh
+                             and not any([True for type_ in all_types_but_courts if type_ in ref])]
+            else:
+                ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
+                             if ctg in ref and dist > ctg_thresh]
 
             sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
             sorted_preds = [x[0] for x in sorted_pd]
@@ -216,6 +230,23 @@ class SemanticSearch:
 
         return sorted_preds, sorted_scores
 
+    def court_docs_shrinking(self, preds, docs):
+        new_preds = []
+        new_docs = []
+
+        for ref_name, ref_text in zip(preds, docs):
+            is_court = not any([True for type_ in all_types_but_courts if type_ in ref_name])
+            has_splitter = court_text_splitter in ref_text
+
+            if is_court and has_splitter:
+                new_ref_text = ref_text.split(court_text_splitter)[0].strip()
+                new_preds.append(ref_name)
+                new_docs.append(new_ref_text)
+            else:
+                new_preds.append(ref_name)
+                new_docs.append(ref_text)
+        return new_preds, new_docs
+
     def search(self, query, top=10):
         query_tokens = query_tokenization(query, self.tokenizer)
         query_embeds = query_embed_extraction(query_tokens, self.model,
@@ -224,5 +255,6 @@ class SemanticSearch:
         pred = [self.index_keys[x] for x in indices[0]]
         preds, scores = self.search_results_filtering(pred, distances[0])
         docs = [self.filtered_db_data[ref] for ref in preds]
+        preds, docs = self.court_docs_shrinking(preds, docs)
 
         return preds[:top], docs[:top], scores[:top]
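
In search_results_filtering, court hits are now selected by negation: a reference counts as 'Суды' when its key contains none of the known non-court types from all_types_but_courts. A standalone sketch of that membership test; the sample reference keys are invented for illustration:

    all_types_but_courts = ['НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина',
                            'Письмо ФНС', 'Приказ ФНС', 'Постановление Правительства']

    def is_court_ref(ref):
        # A reference is a court document when no known non-court
        # type name occurs in its key.
        return not any(type_ in ref for type_ in all_types_but_courts)

    refs = ['НКРФ ст. 146', 'Постановление Правительства № 1137', 'Дело № А40-12345/2023']
    print([r for r in refs if is_court_ref(r)])  # ['Дело № А40-12345/2023']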
 
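
court_docs_shrinking keeps only the annotation that precedes the court_text_splitter marker, so search returns a court document's summary rather than its full text. A small self-contained demonstration of the shrinking rule; the document body is invented:

    court_text_splitter = "Весь текст судебного документа: "

    doc = ("Annotation: the court denied the VAT deduction. "
           + court_text_splitter
           + "Full text of the ruling follows...")

    # Everything from the marker onward is dropped, as in court_docs_shrinking.
    summary = doc.split(court_text_splitter)[0].strip()
    print(summary)  # Annotation: the court denied the VAT deduction.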