Spaces:
Runtime error
Runtime error
Search updated
Browse files
- app.py +1 -2
- legal_info_search_data/data_jsons_20240202.pkl +3 -0
- legal_info_search_utils/metrics.py +9 -2
- legal_info_search_utils/utils.py +5 -22
- semantic_search.py +46 -46
app.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import json
|
2 |
-
import os
|
3 |
from flask import Flask, jsonify, request
|
4 |
from semantic_search import SemanticSearch
|
5 |
|
@@ -24,4 +23,4 @@ def search_route():
|
|
24 |
|
25 |
if __name__ == '__main__':
|
26 |
|
27 |
-
app.run(debug=False, host='0.0.0.0'
|
|
|
1 |
import json
|
|
|
2 |
from flask import Flask, jsonify, request
|
3 |
from semantic_search import SemanticSearch
|
4 |
|
|
|
23 |
|
24 |
if __name__ == '__main__':
|
25 |
|
26 |
+
app.run(debug=False, host='0.0.0.0')
|
legal_info_search_data/data_jsons_20240202.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:78cab704acf861b87eec01ba4d575e2e0110ed57ac64f814c04c3de02ef2db88
|
3 |
+
size 22359347
|
legal_info_search_utils/metrics.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
import numpy as np
|
2 |
|
3 |
|
4 |
-
def calculate_metrics_at_k(pred, true, k, dynamic_topk=False):
|
|
|
5 |
precisions_at_k = []
|
6 |
recalls_at_k = []
|
7 |
f1_scores_at_k = []
|
@@ -14,11 +15,17 @@ def calculate_metrics_at_k(pred, true, k, dynamic_topk=False):
|
|
14 |
relevant_documents = set(true[query_id])
|
15 |
true_positives = len(retrieved_documents.intersection(relevant_documents))
|
16 |
|
17 |
-
if not len(retrieved_documents) and not len(relevant_documents):
|
18 |
precisions_at_k.append(1)
|
19 |
recalls_at_k.append(1)
|
20 |
f1_scores_at_k.append(1)
|
21 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
# precision
|
24 |
precision_at_k = true_positives / k if k else 0
|
|
|
1 |
import numpy as np
|
2 |
|
3 |
|
4 |
+
def calculate_metrics_at_k(pred, true, k, compensate_div_0=False, dynamic_topk=True,
|
5 |
+
skip_empty_trues=False, skip_empty_preds=False):
|
6 |
precisions_at_k = []
|
7 |
recalls_at_k = []
|
8 |
f1_scores_at_k = []
|
|
|
15 |
relevant_documents = set(true[query_id])
|
16 |
true_positives = len(retrieved_documents.intersection(relevant_documents))
|
17 |
|
18 |
+
if compensate_div_0 and not len(retrieved_documents) and not len(relevant_documents):
|
19 |
precisions_at_k.append(1)
|
20 |
recalls_at_k.append(1)
|
21 |
f1_scores_at_k.append(1)
|
22 |
continue
|
23 |
+
|
24 |
+
if skip_empty_trues and not len(relevant_documents):
|
25 |
+
continue
|
26 |
+
|
27 |
+
if skip_empty_preds and not len(retrieved_documents):
|
28 |
+
continue
|
29 |
|
30 |
# precision
|
31 |
precision_at_k = true_positives / k if k else 0
|
legal_info_search_utils/utils.py
CHANGED
@@ -10,8 +10,6 @@ from torch.utils.data import Dataset, DataLoader
|
|
10 |
from torch.cuda.amp import autocast
|
11 |
|
12 |
|
13 |
-
all_types_but_courts = ['НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина',
|
14 |
-
'Письмо ФНС', 'Приказ ФНС', 'Постановление Правительства']
|
15 |
court_text_splitter = "Весь текст судебного документа: "
|
16 |
|
17 |
|
@@ -54,12 +52,8 @@ def get_subsets_for_qa(subsets, data_ids, all_docs):
|
|
54 |
def filter_db_data_types(text_parts, db_data_in):
|
55 |
filtered_db_data = {}
|
56 |
db_data = copy.deepcopy(db_data_in)
|
57 |
-
|
58 |
-
|
59 |
-
check_not_other = not any([True for x in all_types_but_courts if x in ref])
|
60 |
-
court_condition = check_if_courts and check_not_other
|
61 |
-
|
62 |
-
if court_condition or any([True for x in text_parts if x in ref]):
|
63 |
filtered_db_data[ref] = text
|
64 |
return filtered_db_data
|
65 |
|
@@ -73,12 +67,8 @@ def filter_qa_data_types(text_parts, all_docs_in):
|
|
73 |
continue
|
74 |
|
75 |
filtered_refs = {}
|
76 |
-
check_if_courts = 'Суды' in text_parts
|
77 |
for ref, text in doc['added_refs'].items():
|
78 |
-
|
79 |
-
court_condition = check_if_courts and check_not_other
|
80 |
-
|
81 |
-
if court_condition or any([True for x in text_parts if x in ref]):
|
82 |
filtered_refs[ref] = text
|
83 |
|
84 |
filtered_all_docs[doc_key] = doc
|
@@ -205,16 +195,9 @@ def get_exact_ctg_data(pred_in, true_in, ctg):
|
|
205 |
|
206 |
out_pred = {}
|
207 |
out_true = {}
|
208 |
-
check_if_courts = ctg == "Суды"
|
209 |
for idx, (pred, true) in zip(true_in.keys(), zip(pred_in.values(), true_in.values())):
|
210 |
-
if
|
211 |
-
|
212 |
-
if not any([True for x in all_types_but_courts if x in ref])]
|
213 |
-
ctg_refs_pred = [ref for ref in pred
|
214 |
-
if not any([True for x in all_types_but_courts if x in ref])]
|
215 |
-
else:
|
216 |
-
ctg_refs_true = [ref for ref in true if ctg in ref]
|
217 |
-
ctg_refs_pred = [ref for ref in pred if ctg in ref]
|
218 |
|
219 |
out_true[idx] = ctg_refs_true
|
220 |
out_pred[idx] = ctg_refs_pred
|
|
|
10 |
from torch.cuda.amp import autocast
|
11 |
|
12 |
|
|
|
|
|
13 |
court_text_splitter = "Весь текст судебного документа: "
|
14 |
|
15 |
|
|
|
52 |
def filter_db_data_types(text_parts, db_data_in):
|
53 |
filtered_db_data = {}
|
54 |
db_data = copy.deepcopy(db_data_in)
|
55 |
+
for ref, text in db_data.items():
|
56 |
+
if any([True for x in text_parts if x in ref]):
|
|
|
|
|
|
|
|
|
57 |
filtered_db_data[ref] = text
|
58 |
return filtered_db_data
|
59 |
|
|
|
67 |
continue
|
68 |
|
69 |
filtered_refs = {}
|
|
|
70 |
for ref, text in doc['added_refs'].items():
|
71 |
+
if any([True for x in text_parts if x in ref]):
|
|
|
|
|
|
|
72 |
filtered_refs[ref] = text
|
73 |
|
74 |
filtered_all_docs[doc_key] = doc
|
|
|
195 |
|
196 |
out_pred = {}
|
197 |
out_true = {}
|
|
|
198 |
for idx, (pred, true) in zip(true_in.keys(), zip(pred_in.values(), true_in.values())):
|
199 |
+
ctg_refs_true = [ref for ref in true if ctg in ref]
|
200 |
+
ctg_refs_pred = [ref for ref in pred if ctg in ref]
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
out_true[idx] = ctg_refs_true
|
203 |
out_pred[idx] = ctg_refs_pred
|
semantic_search.py
CHANGED
@@ -11,7 +11,7 @@ from legal_info_search_utils.utils import filter_db_data_types, filter_qa_data_t
|
|
11 |
from legal_info_search_utils.utils import db_tokenization, qa_tokenization
|
12 |
from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
|
13 |
from legal_info_search_utils.utils import print_metrics, get_final_metrics
|
14 |
-
from legal_info_search_utils.utils import
|
15 |
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
|
16 |
from legal_info_search_utils.metrics import calculate_metrics_at_k
|
17 |
|
@@ -22,7 +22,7 @@ global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
|
|
22 |
|
23 |
# размеченные консультации
|
24 |
data_path_consult = os.environ.get("DATA_PATH_CONSULT",
|
25 |
-
global_data_path + "
|
26 |
|
27 |
# id консультаций, побитые на train / valid / test
|
28 |
data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
|
@@ -43,7 +43,8 @@ db_data_types = os.environ.get("DB_DATA_TYPES", [
|
|
43 |
'Письмо ФНС',
|
44 |
'Приказ ФНС',
|
45 |
'Постановление Правительства',
|
46 |
-
'
|
|
|
47 |
])
|
48 |
|
49 |
device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
|
@@ -80,6 +81,7 @@ class SemanticSearch:
|
|
80 |
db_data = get_subsets_for_db(db_subsets, data_ids, all_docs)
|
81 |
filtered_all_docs = filter_qa_data_types(db_data_types, all_docs)
|
82 |
|
|
|
83 |
self.filtered_db_data = filter_db_data_types(db_data_types, db_data)
|
84 |
self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs)
|
85 |
|
@@ -108,24 +110,16 @@ class SemanticSearch:
|
|
108 |
self.docs_embeds_faiss = docs_embeds_faiss
|
109 |
self.questions_embeds_faiss = questions_embeds_faiss
|
110 |
self.optimal_params = {
|
111 |
-
'НКРФ': {
|
112 |
-
'
|
113 |
-
'
|
114 |
-
'
|
115 |
-
'
|
116 |
-
'
|
117 |
-
'
|
118 |
-
'
|
119 |
-
'
|
120 |
-
'
|
121 |
-
'Письмо ФНС': {
|
122 |
-
'thresh': 0.879310, 'sim_factor': 0.5, 'diff_n': 0},
|
123 |
-
'Приказ ФНС': {
|
124 |
-
'thresh': 0.806896, 'sim_factor': 0.5, 'diff_n': 0},
|
125 |
-
'Постановление Правительства': {
|
126 |
-
'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0},
|
127 |
-
'Суды': {
|
128 |
-
'thresh': 0.846153, 'sim_factor': 0.939230,'diff_n': 0}
|
129 |
}
|
130 |
self.ref_categories = {
|
131 |
'all': 'all',
|
@@ -133,11 +127,12 @@ class SemanticSearch:
|
|
133 |
'ГКРФ': 'ГКРФ',
|
134 |
'ТКРФ': 'ТКРФ',
|
135 |
'Федеральный закон': 'ФЗ',
|
136 |
-
'
|
137 |
'Письмо Минфина': 'Письмо МФ',
|
138 |
'Письмо ФНС': 'Письмо ФНС',
|
139 |
'Приказ ФНС': 'Приказ ФНС',
|
140 |
-
'Постановление Правительства': 'Пост. Прав.'
|
|
|
141 |
}
|
142 |
|
143 |
def test_search(self):
|
@@ -157,7 +152,7 @@ class SemanticSearch:
|
|
157 |
fp, fs = self.search_results_filtering(p, d[0])
|
158 |
pred[idx] = fp
|
159 |
|
160 |
-
# раскомментировать нужное. Если
|
161 |
# посчитаются "как есть", с учетом полной иерархии
|
162 |
filter_parts = [
|
163 |
# "абз.",
|
@@ -168,8 +163,10 @@ class SemanticSearch:
|
|
168 |
filtered_true = filter_ref_parts(true, filter_parts)
|
169 |
|
170 |
metrics_func_params = {
|
171 |
-
|
172 |
-
'dynamic_topk': True
|
|
|
|
|
173 |
}
|
174 |
metrics = get_final_metrics(filtered_pred, filtered_true,
|
175 |
self.ref_categories.keys(), [0],
|
@@ -185,26 +182,15 @@ class SemanticSearch:
|
|
185 |
for ctg in db_data_types:
|
186 |
ctg_thresh = self.optimal_params[ctg]['thresh']
|
187 |
ctg_sim_factor = self.optimal_params[ctg]['sim_factor']
|
188 |
-
ctg_diff_n = self.optimal_params[ctg]['diff_n']
|
189 |
|
190 |
-
|
191 |
-
|
192 |
-
and not any([True for type_ in all_types_but_courts if type_ in ref])]
|
193 |
-
else:
|
194 |
-
ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
|
195 |
-
if ctg in ref and dist > ctg_thresh]
|
196 |
|
197 |
sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
|
198 |
sorted_preds = [x[0] for x in sorted_pd]
|
199 |
sorted_dists = [x[1] for x in sorted_pd]
|
200 |
|
201 |
if len(sorted_dists):
|
202 |
-
diffs = np.diff(sorted_dists, ctg_diff_n)
|
203 |
-
if len(diffs):
|
204 |
-
n_preds = np.argmax(diffs) + ctg_diff_n + 1
|
205 |
-
else:
|
206 |
-
n_preds = 0
|
207 |
-
|
208 |
if len(sorted_dists) > 1:
|
209 |
ratios = (sorted_dists[1:] / sorted_dists[0]) >= ctg_sim_factor
|
210 |
ratios = np.array([True, *ratios])
|
@@ -213,15 +199,12 @@ class SemanticSearch:
|
|
213 |
|
214 |
main_preds = np.array(sorted_preds)[np.where(ratios)].tolist()
|
215 |
scores = np.array(sorted_dists)[np.where(ratios)].tolist()
|
216 |
-
if ctg_diff_n > 0 and n_preds > 0:
|
217 |
-
main_preds = main_preds[:n_preds]
|
218 |
-
scores = scores[:n_preds]
|
219 |
else:
|
220 |
main_preds = []
|
221 |
scores = []
|
222 |
|
223 |
-
all_ctg_preds.extend(main_preds)
|
224 |
-
all_scores.extend(scores)
|
225 |
|
226 |
sorted_values = [(ref, score) for ref, score in zip(all_ctg_preds, all_scores)]
|
227 |
sorted_values = sorted(sorted_values, key=lambda x: x[1], reverse=True)
|
@@ -235,7 +218,7 @@ class SemanticSearch:
|
|
235 |
new_docs = []
|
236 |
|
237 |
for ref_name, ref_text in zip(preds, docs):
|
238 |
-
is_court =
|
239 |
has_splitter = court_text_splitter in ref_text
|
240 |
|
241 |
if is_court and has_splitter:
|
@@ -247,7 +230,24 @@ class SemanticSearch:
|
|
247 |
new_docs.append(ref_text)
|
248 |
return new_preds, new_docs
|
249 |
|
250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
query_tokens = query_tokenization(query, self.tokenizer)
|
252 |
query_embeds = query_embed_extraction(query_tokens, self.model,
|
253 |
self.do_normalization)
|
|
|
11 |
from legal_info_search_utils.utils import db_tokenization, qa_tokenization
|
12 |
from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
|
13 |
from legal_info_search_utils.utils import print_metrics, get_final_metrics
|
14 |
+
from legal_info_search_utils.utils import court_text_splitter
|
15 |
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
|
16 |
from legal_info_search_utils.metrics import calculate_metrics_at_k
|
17 |
|
|
|
22 |
|
23 |
# размеченные консультации
|
24 |
data_path_consult = os.environ.get("DATA_PATH_CONSULT",
|
25 |
+
global_data_path + "data_jsons_20240202.pkl")
|
26 |
|
27 |
# id консультаций, побитые на train / valid / test
|
28 |
data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
|
|
|
43 |
'Письмо ФНС',
|
44 |
'Приказ ФНС',
|
45 |
'Постановление Правительства',
|
46 |
+
'Судебный документ',
|
47 |
+
'Внутренний документ'
|
48 |
])
|
49 |
|
50 |
device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
81 |
db_data = get_subsets_for_db(db_subsets, data_ids, all_docs)
|
82 |
filtered_all_docs = filter_qa_data_types(db_data_types, all_docs)
|
83 |
|
84 |
+
self.mean_refs_count = self.get_mean_refs_counts(db_data_types, filtered_all_docs)
|
85 |
self.filtered_db_data = filter_db_data_types(db_data_types, db_data)
|
86 |
self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs)
|
87 |
|
|
|
110 |
self.docs_embeds_faiss = docs_embeds_faiss
|
111 |
self.questions_embeds_faiss = questions_embeds_faiss
|
112 |
self.optimal_params = {
|
113 |
+
'НКРФ': {'thresh': 0.61579, 'sim_factor': 0.84211},
|
114 |
+
'ГКРФ': {'thresh': 0.55263, 'sim_factor': 0.0},
|
115 |
+
'ТКРФ': {'thresh': 0.48947, 'sim_factor': 1.0},
|
116 |
+
'Федеральный закон': {'thresh': 0.52105, 'sim_factor': 0.94737},
|
117 |
+
'Письмо Минфина': {'thresh': 0.71053, 'sim_factor': 0.0},
|
118 |
+
'Письмо ФНС': {'thresh': 0.61579, 'sim_factor': 0.84211},
|
119 |
+
'Приказ ФНС': {'thresh': 0.52105, 'sim_factor': 0.94737},
|
120 |
+
'Постановление Правительства': {'thresh': 0.45789, 'sim_factor': 0.89474},
|
121 |
+
'Судебный документ': {'thresh': 0.80526, 'sim_factor': 0.89474},
|
122 |
+
'Внутренний документ': {'thresh': 0.71053, 'sim_factor': 0.0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
}
|
124 |
self.ref_categories = {
|
125 |
'all': 'all',
|
|
|
127 |
'ГКРФ': 'ГКРФ',
|
128 |
'ТКРФ': 'ТКРФ',
|
129 |
'Федеральный закон': 'ФЗ',
|
130 |
+
'Судебный документ': 'Суды',
|
131 |
'Письмо Минфина': 'Письмо МФ',
|
132 |
'Письмо ФНС': 'Письмо ФНС',
|
133 |
'Приказ ФНС': 'Приказ ФНС',
|
134 |
+
'Постановление Правительства': 'Пост. Прав.',
|
135 |
+
'Внутренний документ': 'Внутр. док.'
|
136 |
}
|
137 |
|
138 |
def test_search(self):
|
|
|
152 |
fp, fs = self.search_results_filtering(p, d[0])
|
153 |
pred[idx] = fp
|
154 |
|
155 |
+
# раскомментировать нужное. Если всё закомментировано - метрики
|
156 |
# посчитаются "как есть", с учетом полной иерархии
|
157 |
filter_parts = [
|
158 |
# "абз.",
|
|
|
163 |
filtered_true = filter_ref_parts(true, filter_parts)
|
164 |
|
165 |
metrics_func_params = {
|
166 |
+
'compensate_div_0': True,
|
167 |
+
'dynamic_topk': True,
|
168 |
+
'skip_empty_trues': False,
|
169 |
+
'skip_empty_preds': False
|
170 |
}
|
171 |
metrics = get_final_metrics(filtered_pred, filtered_true,
|
172 |
self.ref_categories.keys(), [0],
|
|
|
182 |
for ctg in db_data_types:
|
183 |
ctg_thresh = self.optimal_params[ctg]['thresh']
|
184 |
ctg_sim_factor = self.optimal_params[ctg]['sim_factor']
|
|
|
185 |
|
186 |
+
ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
|
187 |
+
if ctg in ref and dist > ctg_thresh]
|
|
|
|
|
|
|
|
|
188 |
|
189 |
sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
|
190 |
sorted_preds = [x[0] for x in sorted_pd]
|
191 |
sorted_dists = [x[1] for x in sorted_pd]
|
192 |
|
193 |
if len(sorted_dists):
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
if len(sorted_dists) > 1:
|
195 |
ratios = (sorted_dists[1:] / sorted_dists[0]) >= ctg_sim_factor
|
196 |
ratios = np.array([True, *ratios])
|
|
|
199 |
|
200 |
main_preds = np.array(sorted_preds)[np.where(ratios)].tolist()
|
201 |
scores = np.array(sorted_dists)[np.where(ratios)].tolist()
|
|
|
|
|
|
|
202 |
else:
|
203 |
main_preds = []
|
204 |
scores = []
|
205 |
|
206 |
+
all_ctg_preds.extend(main_preds[:self.mean_refs_count[ctg]])
|
207 |
+
all_scores.extend(scores[:self.mean_refs_count[ctg]])
|
208 |
|
209 |
sorted_values = [(ref, score) for ref, score in zip(all_ctg_preds, all_scores)]
|
210 |
sorted_values = sorted(sorted_values, key=lambda x: x[1], reverse=True)
|
|
|
218 |
new_docs = []
|
219 |
|
220 |
for ref_name, ref_text in zip(preds, docs):
|
221 |
+
is_court = 'Судебный документ' in ref_name
|
222 |
has_splitter = court_text_splitter in ref_text
|
223 |
|
224 |
if is_court and has_splitter:
|
|
|
230 |
new_docs.append(ref_text)
|
231 |
return new_preds, new_docs
|
232 |
|
233 |
+
@staticmethod
def get_mean_refs_counts(db_data_types, data):
    """Compute a per-type cap on the number of references to keep.

    For every type name in ``db_data_types`` the documents in ``data`` are
    scanned. A reference counts toward a type when the type name is a
    substring of the reference key found in ``doc['added_refs']``. Only
    documents that contain at least one reference of the type contribute
    to the mean.

    Args:
        db_data_types: iterable of type-name strings.
        data: mapping whose values are dicts with an ``'added_refs'``
            entry that iterates over reference keys (strings).

    Returns:
        dict mapping each type name to ``int(mean + 1)``, where ``mean``
        is the average number of matching references per document that
        has any. A type that never occurs gets a mean of 0, i.e. a cap
        of 1.
    """
    mean_refs_count = {}
    for tp in db_data_types:
        per_doc_counts = []
        for doc in data.values():
            matches = sum(1 for ref in doc['added_refs'] if tp in ref)
            if matches:
                per_doc_counts.append(matches)

        # np.mean([]) is NaN (with a RuntimeWarning) and int(NaN) raises
        # ValueError, so a type absent from the data must not reach np.mean.
        mean_value = np.mean(per_doc_counts) if per_doc_counts else 0.0
        mean_refs_count[tp] = int(mean_value + 1)

    return mean_refs_count
|
249 |
+
|
250 |
+
def search(self, query, top=15):
|
251 |
query_tokens = query_tokenization(query, self.tokenizer)
|
252 |
query_embeds = query_embed_extraction(query_tokens, self.model,
|
253 |
self.do_normalization)
|