muryshev committed on
Commit
32c50d0
1 Parent(s): c59d707

Search updated

Browse files
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import json
2
- import os
3
  from flask import Flask, jsonify, request
4
  from semantic_search import SemanticSearch
5
 
@@ -24,4 +23,4 @@ def search_route():
24
 
25
  if __name__ == '__main__':
26
 
27
- app.run(debug=False, host='0.0.0.0', port=7868)
 
1
  import json
 
2
  from flask import Flask, jsonify, request
3
  from semantic_search import SemanticSearch
4
 
 
23
 
24
  if __name__ == '__main__':
25
 
26
+ app.run(debug=False, host='0.0.0.0')
legal_info_search_data/data_jsons_20240202.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78cab704acf861b87eec01ba4d575e2e0110ed57ac64f814c04c3de02ef2db88
3
+ size 22359347
legal_info_search_utils/metrics.py CHANGED
@@ -1,7 +1,8 @@
1
  import numpy as np
2
 
3
 
4
- def calculate_metrics_at_k(pred, true, k, dynamic_topk=False):
 
5
  precisions_at_k = []
6
  recalls_at_k = []
7
  f1_scores_at_k = []
@@ -14,11 +15,17 @@ def calculate_metrics_at_k(pred, true, k, dynamic_topk=False):
14
  relevant_documents = set(true[query_id])
15
  true_positives = len(retrieved_documents.intersection(relevant_documents))
16
 
17
- if not len(retrieved_documents) and not len(relevant_documents):
18
  precisions_at_k.append(1)
19
  recalls_at_k.append(1)
20
  f1_scores_at_k.append(1)
21
  continue
 
 
 
 
 
 
22
 
23
  # precision
24
  precision_at_k = true_positives / k if k else 0
 
1
  import numpy as np
2
 
3
 
4
+ def calculate_metrics_at_k(pred, true, k, compensate_div_0=False, dynamic_topk=True,
5
+ skip_empty_trues=False, skip_empty_preds=False):
6
  precisions_at_k = []
7
  recalls_at_k = []
8
  f1_scores_at_k = []
 
15
  relevant_documents = set(true[query_id])
16
  true_positives = len(retrieved_documents.intersection(relevant_documents))
17
 
18
+ if compensate_div_0 and not len(retrieved_documents) and not len(relevant_documents):
19
  precisions_at_k.append(1)
20
  recalls_at_k.append(1)
21
  f1_scores_at_k.append(1)
22
  continue
23
+
24
+ if skip_empty_trues and not len(relevant_documents):
25
+ continue
26
+
27
+ if skip_empty_preds and not len(retrieved_documents):
28
+ continue
29
 
30
  # precision
31
  precision_at_k = true_positives / k if k else 0
legal_info_search_utils/utils.py CHANGED
@@ -10,8 +10,6 @@ from torch.utils.data import Dataset, DataLoader
10
  from torch.cuda.amp import autocast
11
 
12
 
13
- all_types_but_courts = ['НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина',
14
- 'Письмо ФНС', 'Приказ ФНС', 'Постановление Правительства']
15
  court_text_splitter = "Весь текст судебного документа: "
16
 
17
 
@@ -54,12 +52,8 @@ def get_subsets_for_qa(subsets, data_ids, all_docs):
54
  def filter_db_data_types(text_parts, db_data_in):
55
  filtered_db_data = {}
56
  db_data = copy.deepcopy(db_data_in)
57
- check_if_courts = 'Суды' in text_parts
58
- for ref, text in db_data.items():
59
- check_not_other = not any([True for x in all_types_but_courts if x in ref])
60
- court_condition = check_if_courts and check_not_other
61
-
62
- if court_condition or any([True for x in text_parts if x in ref]):
63
  filtered_db_data[ref] = text
64
  return filtered_db_data
65
 
@@ -73,12 +67,8 @@ def filter_qa_data_types(text_parts, all_docs_in):
73
  continue
74
 
75
  filtered_refs = {}
76
- check_if_courts = 'Суды' in text_parts
77
  for ref, text in doc['added_refs'].items():
78
- check_not_other = not any([True for x in all_types_but_courts if x in ref])
79
- court_condition = check_if_courts and check_not_other
80
-
81
- if court_condition or any([True for x in text_parts if x in ref]):
82
  filtered_refs[ref] = text
83
 
84
  filtered_all_docs[doc_key] = doc
@@ -205,16 +195,9 @@ def get_exact_ctg_data(pred_in, true_in, ctg):
205
 
206
  out_pred = {}
207
  out_true = {}
208
- check_if_courts = ctg == "Суды"
209
  for idx, (pred, true) in zip(true_in.keys(), zip(pred_in.values(), true_in.values())):
210
- if check_if_courts:
211
- ctg_refs_true = [ref for ref in true
212
- if not any([True for x in all_types_but_courts if x in ref])]
213
- ctg_refs_pred = [ref for ref in pred
214
- if not any([True for x in all_types_but_courts if x in ref])]
215
- else:
216
- ctg_refs_true = [ref for ref in true if ctg in ref]
217
- ctg_refs_pred = [ref for ref in pred if ctg in ref]
218
 
219
  out_true[idx] = ctg_refs_true
220
  out_pred[idx] = ctg_refs_pred
 
10
  from torch.cuda.amp import autocast
11
 
12
 
 
 
13
  court_text_splitter = "Весь текст судебного документа: "
14
 
15
 
 
52
  def filter_db_data_types(text_parts, db_data_in):
53
  filtered_db_data = {}
54
  db_data = copy.deepcopy(db_data_in)
55
+ for ref, text in db_data.items():
56
+ if any([True for x in text_parts if x in ref]):
 
 
 
 
57
  filtered_db_data[ref] = text
58
  return filtered_db_data
59
 
 
67
  continue
68
 
69
  filtered_refs = {}
 
70
  for ref, text in doc['added_refs'].items():
71
+ if any([True for x in text_parts if x in ref]):
 
 
 
72
  filtered_refs[ref] = text
73
 
74
  filtered_all_docs[doc_key] = doc
 
195
 
196
  out_pred = {}
197
  out_true = {}
 
198
  for idx, (pred, true) in zip(true_in.keys(), zip(pred_in.values(), true_in.values())):
199
+ ctg_refs_true = [ref for ref in true if ctg in ref]
200
+ ctg_refs_pred = [ref for ref in pred if ctg in ref]
 
 
 
 
 
 
201
 
202
  out_true[idx] = ctg_refs_true
203
  out_pred[idx] = ctg_refs_pred
semantic_search.py CHANGED
@@ -11,7 +11,7 @@ from legal_info_search_utils.utils import filter_db_data_types, filter_qa_data_t
11
  from legal_info_search_utils.utils import db_tokenization, qa_tokenization
12
  from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
13
  from legal_info_search_utils.utils import print_metrics, get_final_metrics
14
- from legal_info_search_utils.utils import all_types_but_courts, court_text_splitter
15
  from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
16
  from legal_info_search_utils.metrics import calculate_metrics_at_k
17
 
@@ -22,7 +22,7 @@ global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
22
 
23
  # размеченные консультации
24
  data_path_consult = os.environ.get("DATA_PATH_CONSULT",
25
- global_data_path + "data_jsons_20240131.pkl")
26
 
27
  # id консультаций, побитые на train / valid / test
28
  data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
@@ -43,7 +43,8 @@ db_data_types = os.environ.get("DB_DATA_TYPES", [
43
  'Письмо ФНС',
44
  'Приказ ФНС',
45
  'Постановление Правительства',
46
- 'Суды'
 
47
  ])
48
 
49
  device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
@@ -80,6 +81,7 @@ class SemanticSearch:
80
  db_data = get_subsets_for_db(db_subsets, data_ids, all_docs)
81
  filtered_all_docs = filter_qa_data_types(db_data_types, all_docs)
82
 
 
83
  self.filtered_db_data = filter_db_data_types(db_data_types, db_data)
84
  self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs)
85
 
@@ -108,24 +110,16 @@ class SemanticSearch:
108
  self.docs_embeds_faiss = docs_embeds_faiss
109
  self.questions_embeds_faiss = questions_embeds_faiss
110
  self.optimal_params = {
111
- 'НКРФ': {
112
- 'thresh': 0.613793, 'sim_factor': 0.878947, 'diff_n': 0},
113
- 'ГКРФ': {
114
- 'thresh': 0.758620, 'sim_factor': 0.878947, 'diff_n': 0},
115
- 'ТКРФ': {
116
- 'thresh': 0.734482, 'sim_factor': 0.9, 'diff_n': 0},
117
- 'Федеральный закон': {
118
- 'thresh': 0.734482, 'sim_factor': 0.5, 'diff_n': 0},
119
- 'Письмо Минфина': {
120
- 'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0},
121
- 'Письмо ФНС': {
122
- 'thresh': 0.879310, 'sim_factor': 0.5, 'diff_n': 0},
123
- 'Приказ ФНС': {
124
- 'thresh': 0.806896, 'sim_factor': 0.5, 'diff_n': 0},
125
- 'Постановление Правительства': {
126
- 'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0},
127
- 'Суды': {
128
- 'thresh': 0.846153, 'sim_factor': 0.939230,'diff_n': 0}
129
  }
130
  self.ref_categories = {
131
  'all': 'all',
@@ -133,11 +127,12 @@ class SemanticSearch:
133
  'ГКРФ': 'ГКРФ',
134
  'ТКРФ': 'ТКРФ',
135
  'Федеральный закон': 'ФЗ',
136
- 'Суды': 'Суды',
137
  'Письмо Минфина': 'Письмо МФ',
138
  'Письмо ФНС': 'Письмо ФНС',
139
  'Приказ ФНС': 'Приказ ФНС',
140
- 'Постановление Правительства': 'Пост. Прав.'
 
141
  }
142
 
143
  def test_search(self):
@@ -157,7 +152,7 @@ class SemanticSearch:
157
  fp, fs = self.search_results_filtering(p, d[0])
158
  pred[idx] = fp
159
 
160
- # раскомментировать нужное. Если всё закомментировано - метрики
161
  # посчитаются "как есть", с учетом полной иерархии
162
  filter_parts = [
163
  # "абз.",
@@ -168,8 +163,10 @@ class SemanticSearch:
168
  filtered_true = filter_ref_parts(true, filter_parts)
169
 
170
  metrics_func_params = {
171
- # 'compensate_div_0': True,
172
- 'dynamic_topk': True
 
 
173
  }
174
  metrics = get_final_metrics(filtered_pred, filtered_true,
175
  self.ref_categories.keys(), [0],
@@ -185,26 +182,15 @@ class SemanticSearch:
185
  for ctg in db_data_types:
186
  ctg_thresh = self.optimal_params[ctg]['thresh']
187
  ctg_sim_factor = self.optimal_params[ctg]['sim_factor']
188
- ctg_diff_n = self.optimal_params[ctg]['diff_n']
189
 
190
- if ctg == 'Суды':
191
- ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists) if dist > ctg_thresh
192
- and not any([True for type_ in all_types_but_courts if type_ in ref])]
193
- else:
194
- ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
195
- if ctg in ref and dist > ctg_thresh]
196
 
197
  sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
198
  sorted_preds = [x[0] for x in sorted_pd]
199
  sorted_dists = [x[1] for x in sorted_pd]
200
 
201
  if len(sorted_dists):
202
- diffs = np.diff(sorted_dists, ctg_diff_n)
203
- if len(diffs):
204
- n_preds = np.argmax(diffs) + ctg_diff_n + 1
205
- else:
206
- n_preds = 0
207
-
208
  if len(sorted_dists) > 1:
209
  ratios = (sorted_dists[1:] / sorted_dists[0]) >= ctg_sim_factor
210
  ratios = np.array([True, *ratios])
@@ -213,15 +199,12 @@ class SemanticSearch:
213
 
214
  main_preds = np.array(sorted_preds)[np.where(ratios)].tolist()
215
  scores = np.array(sorted_dists)[np.where(ratios)].tolist()
216
- if ctg_diff_n > 0 and n_preds > 0:
217
- main_preds = main_preds[:n_preds]
218
- scores = scores[:n_preds]
219
  else:
220
  main_preds = []
221
  scores = []
222
 
223
- all_ctg_preds.extend(main_preds)
224
- all_scores.extend(scores)
225
 
226
  sorted_values = [(ref, score) for ref, score in zip(all_ctg_preds, all_scores)]
227
  sorted_values = sorted(sorted_values, key=lambda x: x[1], reverse=True)
@@ -235,7 +218,7 @@ class SemanticSearch:
235
  new_docs = []
236
 
237
  for ref_name, ref_text in zip(preds, docs):
238
- is_court = not any([True for type_ in all_types_but_courts if type_ in ref_name])
239
  has_splitter = court_text_splitter in ref_text
240
 
241
  if is_court and has_splitter:
@@ -247,7 +230,24 @@ class SemanticSearch:
247
  new_docs.append(ref_text)
248
  return new_preds, new_docs
249
 
250
- def search(self, query, top=10):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  query_tokens = query_tokenization(query, self.tokenizer)
252
  query_embeds = query_embed_extraction(query_tokens, self.model,
253
  self.do_normalization)
 
11
  from legal_info_search_utils.utils import db_tokenization, qa_tokenization
12
  from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
13
  from legal_info_search_utils.utils import print_metrics, get_final_metrics
14
+ from legal_info_search_utils.utils import court_text_splitter
15
  from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
16
  from legal_info_search_utils.metrics import calculate_metrics_at_k
17
 
 
22
 
23
  # размеченные консультации
24
  data_path_consult = os.environ.get("DATA_PATH_CONSULT",
25
+ global_data_path + "data_jsons_20240202.pkl")
26
 
27
  # id консультаций, побитые на train / valid / test
28
  data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
 
43
  'Письмо ФНС',
44
  'Приказ ФНС',
45
  'Постановление Правительства',
46
+ 'Судебный документ',
47
+ 'Внутренний документ'
48
  ])
49
 
50
  device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
 
81
  db_data = get_subsets_for_db(db_subsets, data_ids, all_docs)
82
  filtered_all_docs = filter_qa_data_types(db_data_types, all_docs)
83
 
84
+ self.mean_refs_count = self.get_mean_refs_counts(db_data_types, filtered_all_docs)
85
  self.filtered_db_data = filter_db_data_types(db_data_types, db_data)
86
  self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs)
87
 
 
110
  self.docs_embeds_faiss = docs_embeds_faiss
111
  self.questions_embeds_faiss = questions_embeds_faiss
112
  self.optimal_params = {
113
+ 'НКРФ': {'thresh': 0.61579, 'sim_factor': 0.84211},
114
+ 'ГКРФ': {'thresh': 0.55263, 'sim_factor': 0.0},
115
+ 'ТКРФ': {'thresh': 0.48947, 'sim_factor': 1.0},
116
+ 'Федеральный закон': {'thresh': 0.52105, 'sim_factor': 0.94737},
117
+ 'Письмо Минфина': {'thresh': 0.71053, 'sim_factor': 0.0},
118
+ 'Письмо ФНС': {'thresh': 0.61579, 'sim_factor': 0.84211},
119
+ 'Приказ ФНС': {'thresh': 0.52105, 'sim_factor': 0.94737},
120
+ 'Постановление Правительства': {'thresh': 0.45789, 'sim_factor': 0.89474},
121
+ 'Судебный документ': {'thresh': 0.80526, 'sim_factor': 0.89474},
122
+ 'Внутренний документ': {'thresh': 0.71053, 'sim_factor': 0.0}
 
 
 
 
 
 
 
 
123
  }
124
  self.ref_categories = {
125
  'all': 'all',
 
127
  'ГКРФ': 'ГКРФ',
128
  'ТКРФ': 'ТКРФ',
129
  'Федеральный закон': 'ФЗ',
130
+ 'Судебный документ': 'Суды',
131
  'Письмо Минфина': 'Письмо МФ',
132
  'Письмо ФНС': 'Письмо ФНС',
133
  'Приказ ФНС': 'Приказ ФНС',
134
+ 'Постановление Правительства': 'Пост. Прав.',
135
+ 'Внутренний документ': 'Внутр. док.'
136
  }
137
 
138
  def test_search(self):
 
152
  fp, fs = self.search_results_filtering(p, d[0])
153
  pred[idx] = fp
154
 
155
+ # раскомментировать нужное. Если всё закомментировано - метрики
156
  # посчитаются "как есть", с учетом полной иерархии
157
  filter_parts = [
158
  # "абз.",
 
163
  filtered_true = filter_ref_parts(true, filter_parts)
164
 
165
  metrics_func_params = {
166
+ 'compensate_div_0': True,
167
+ 'dynamic_topk': True,
168
+ 'skip_empty_trues': False,
169
+ 'skip_empty_preds': False
170
  }
171
  metrics = get_final_metrics(filtered_pred, filtered_true,
172
  self.ref_categories.keys(), [0],
 
182
  for ctg in db_data_types:
183
  ctg_thresh = self.optimal_params[ctg]['thresh']
184
  ctg_sim_factor = self.optimal_params[ctg]['sim_factor']
 
185
 
186
+ ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
187
+ if ctg in ref and dist > ctg_thresh]
 
 
 
 
188
 
189
  sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
190
  sorted_preds = [x[0] for x in sorted_pd]
191
  sorted_dists = [x[1] for x in sorted_pd]
192
 
193
  if len(sorted_dists):
 
 
 
 
 
 
194
  if len(sorted_dists) > 1:
195
  ratios = (sorted_dists[1:] / sorted_dists[0]) >= ctg_sim_factor
196
  ratios = np.array([True, *ratios])
 
199
 
200
  main_preds = np.array(sorted_preds)[np.where(ratios)].tolist()
201
  scores = np.array(sorted_dists)[np.where(ratios)].tolist()
 
 
 
202
  else:
203
  main_preds = []
204
  scores = []
205
 
206
+ all_ctg_preds.extend(main_preds[:self.mean_refs_count[ctg]])
207
+ all_scores.extend(scores[:self.mean_refs_count[ctg]])
208
 
209
  sorted_values = [(ref, score) for ref, score in zip(all_ctg_preds, all_scores)]
210
  sorted_values = sorted(sorted_values, key=lambda x: x[1], reverse=True)
 
218
  new_docs = []
219
 
220
  for ref_name, ref_text in zip(preds, docs):
221
+ is_court = 'Судебный документ' in ref_name
222
  has_splitter = court_text_splitter in ref_text
223
 
224
  if is_court and has_splitter:
 
230
  new_docs.append(ref_text)
231
  return new_preds, new_docs
232
 
233
+ @staticmethod
234
+ def get_mean_refs_counts(db_data_types, data):
235
+ mean_refs_count = {}
236
+ for tp in db_data_types:
237
+ all_tp_refs = []
238
+ for doc in data.values():
239
+ tp_refs_len = len([ref for ref in doc['added_refs'] if tp in ref])
240
+ if tp_refs_len:
241
+ all_tp_refs.append(tp_refs_len)
242
+
243
+ mean_refs_count[tp] = np.mean(all_tp_refs)
244
+
245
+ for k, v in mean_refs_count.items():
246
+ mean_refs_count[k] = int(v + 1)
247
+
248
+ return mean_refs_count
249
+
250
+ def search(self, query, top=15):
251
  query_tokens = query_tokenization(query, self.tokenizer)
252
  query_embeds = query_embed_extraction(query_tokens, self.model,
253
  self.do_normalization)