dorogan committed on
Commit
45d03f9
1 Parent(s): 5ea5121

Update: add new semantic search logic; fix requirements.txt

Files changed (3)
  1. app.py +6 -1
  2. requirements.txt +2 -1
  3. semantic_search.py +89 -167
app.py CHANGED
@@ -5,7 +5,6 @@ from semantic_search import SemanticSearch
from datetime import datetime

search = SemanticSearch()
- search.test_search()

app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False
@@ -20,6 +19,7 @@ if not os.path.exists(LOGS_BASE_PATH):
# Check if logs are enabled
ENABLE_LOGS = os.getenv("ENABLE_LOGS", "0") == "1"

+
def log_query_result(query, top, request_id, result):
    if not ENABLE_LOGS:
        return
@@ -38,10 +38,12 @@ def log_query_result(query, top, request_id, result):
    with open(log_file_path, 'w') as log_file:
        json.dump(log_data, log_file, indent=2)

+
@app.route('/health', methods=['GET'])
def health():
    return jsonify({"status": "ok"})

+
@app.route('/search', methods=['POST'])
def search_route():
    data = request.get_json()
@@ -56,6 +58,7 @@ def search_route():

    return jsonify(result)

+
@app.route('/read_logs', methods=['GET'])
def read_logs():
    logs = []
@@ -66,6 +69,7 @@ def read_logs():
        logs.append(log_data)
    return jsonify(logs)

+
@app.route('/analyze_logs', methods=['GET'])
def analyze_logs():
    logs_by_query_top = {}
@@ -91,5 +95,6 @@ def analyze_logs():

    return jsonify(invalid_logs)

+
if __name__ == '__main__':
    app.run(debug=False, host='0.0.0.0')
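
A minimal client sketch for the routes touched above; it assumes the service runs on Flask's default port 5000 and that /search reads "query" and "top" from the JSON body (the handler body is largely outside these hunks), so the payload field names are illustrative, not confirmed by the diff. It also needs the third-party `requests` package, which is not in requirements.txt.

# Hypothetical smoke test against a running instance of app.py.
import requests

BASE_URL = "http://localhost:5000"  # assumed default Flask host/port

# /health returns {"status": "ok"} per the route above.
print(requests.get(f"{BASE_URL}/health").json())

# /search: "query" and "top" are assumed names, inferred from
# log_query_result(query, top, request_id, result); adjust to the real schema.
payload = {"query": "example legal question", "top": 5}
print(requests.post(f"{BASE_URL}/search", json=payload).json())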
requirements.txt CHANGED
@@ -8,4 +8,5 @@ transformers==4.29.2
# sentencepiece==0.1.99
# six==1.16.0
# tokenizers==0.13.3
- flask==3.0.0
+ flask==3.0.0
+ datasets
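
Since `datasets` is added without a version pin, a quick import check after installing confirms the environment matches the new requirements; a minimal sketch, assuming `pip install -r requirements.txt` has already been run:

# Sanity check that the packages touched in this commit resolve and
# report their installed versions.
from importlib.metadata import version

import datasets  # newly added, unpinned, so the version will vary
import flask     # pinned to 3.0.0 in requirements.txt

print("flask", version("flask"))
print("datasets", version("datasets"))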
semantic_search.py CHANGED
@@ -1,54 +1,58 @@
import os
- import json
+ # import json
import torch
- import pickle
+ # import pickle
import numpy as np
import faiss
-
+ from datasets import Dataset as dataset
from transformers import AutoTokenizer, AutoModel
- from legal_info_search_utils.utils import get_subsets_for_db, get_subsets_for_qa
- from legal_info_search_utils.utils import filter_db_data_types, filter_qa_data_types
- from legal_info_search_utils.utils import db_tokenization, qa_tokenization
- from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
- from legal_info_search_utils.utils import print_metrics, get_final_metrics
- from legal_info_search_utils.utils import court_text_splitter
+ # from legal_info_search_utils.utils import get_subsets_for_db, get_subsets_for_qa
+ # from legal_info_search_utils.utils import filter_db_data_types, filter_qa_data_types
+ # from legal_info_search_utils.utils import db_tokenization, qa_tokenization
+ # from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
+ # from legal_info_search_utils.utils import print_metrics, get_final_metrics
+ # from legal_info_search_utils.utils import court_text_splitter
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
- from legal_info_search_utils.metrics import calculate_metrics_at_k
+
+ # from legal_info_search_utils.metrics import calculate_metrics_at_k


global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
- global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
+ global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
                                   "legal_info_search_model/20240202_204910_ep8/")

# annotated consultations
- data_path_consult = os.environ.get("DATA_PATH_CONSULT",
-                                    global_data_path + "data_jsons_20240202.pkl")
+ # data_path_consult = os.environ.get("DATA_PATH_CONSULT",
+ #                                    global_data_path + "data_jsons_20240202.pkl")
+
+ data_path_consult = os.environ.get("DATA_PATH_CONSULT",
+                                    global_data_path + "court_dataset_chunk_200_correct_tokenizer_for_develop")

# consultation ids split into train / valid / test
- data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
-                                        global_data_path + "data_ids.json")
+ # data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
+ #                                        global_data_path + "data_ids.json")

# preprocessed internal documents
- data_path_internal_docs = os.environ.get("DATA_PATH_INTERNAL_DOCS",
-                                          global_data_path + "internal_docs.json")
+ # data_path_internal_docs = os.environ.get("DATA_PATH_INTERNAL_DOCS",
+ #                                          global_data_path + "internal_docs.json")

# DB composition
# $ export DB_SUBSETS='["train", "valid", "test"]'
- db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"])
+ # db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"])

# Document type selection. In the list, specify the types to keep in the DB.
# $ export DB_DATA_TYPES='["НКРФ", "ГКРФ", "ТКРФ"]'
db_data_types = os.environ.get("DB_DATA_TYPES", [
-     'НКРФ',
-     'ГКРФ',
-     'ТКРФ',
-     'Федеральный закон',
+     # 'НКРФ',
+     # 'ГКРФ',
+     # 'ТКРФ',
+     # 'Федеральный закон',
    'Письмо Минфина',
    'Письмо ФНС',
-     'Приказ ФНС',
-     'Постановление Правительства',
+     # 'Приказ ФНС',
+     # 'Постановление Правительства',
    'Судебный документ',
-     'Внутренний документ'
+     # 'Внутренний документ'
])

device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
@@ -57,67 +61,22 @@ device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
hf_token = os.environ.get("HF_TOKEN", "")
hf_model_name = os.environ.get("HF_MODEL_NAME", "")

+
class SemanticSearch:
-     def __init__(self, index_type="IndexFlatIP", do_embedding_norm=True,
-                  faiss_batch_size=8, do_normalization=True):
+     def __init__(self,
+                  index_type="IndexFlatIP",
+                  do_embedding_norm=True,
+                  do_normalization=True):
        self.device = device
        self.do_embedding_norm = do_embedding_norm
-         self.faiss_batch_size = faiss_batch_size
        self.do_normalization = do_normalization
        self.load_model()
-
+
        indexes = {
            "IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim),
            "IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim)
        }
-         self.index = indexes[index_type]
-         self.load_data()
-         self.preproces_data()
-         self.test_search()

-     def load_data(self):
-         with open(data_path_consult, "rb") as f:
-             all_docs = pickle.load(f)
-
-         with open(data_path_consult_ids, "r", encoding="utf-8") as f:
-             data_ids = json.load(f)
-
-         with open(data_path_internal_docs, "r", encoding="utf-8") as f:
-             internal_docs = json.load(f)
-
-         db_data = get_subsets_for_db(db_subsets, data_ids, all_docs)
-         filtered_all_docs = filter_qa_data_types(db_data_types, all_docs)
-
-         self.mean_refs_count = self.get_mean_refs_counts(db_data_types, filtered_all_docs)
-         self.mean_refs_count['Внутренний документ'] = 3
-         self.filtered_db_data = filter_db_data_types(db_data_types, db_data)
-         self.filtered_db_data.update(internal_docs)
-         self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs)
-
-     def load_model(self):
-         if hf_token and hf_model_name:
-             self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=True)
-             self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=True).to(self.device)
-         else:
-             self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
-             self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
-
-         self.max_len = self.tokenizer.max_len_single_sentence
-         self.embedding_dim = self.model.config.hidden_size
-
-     def preproces_data(self):
-         index_keys, index_toks = db_tokenization(self.filtered_db_data, self.tokenizer)
-         val_questions, val_refs = qa_tokenization(self.all_docs_qa, self.tokenizer)
-         docs_embeds_faiss, questions_embeds_faiss = extract_text_embeddings(index_toks,
-             val_questions, self.model, self.do_normalization, self.faiss_batch_size)
-
-         self.index.add(docs_embeds_faiss)
-         self.index_keys = index_keys
-         self.index_toks = index_toks
-         self.val_questions = val_questions
-         self.val_refs = val_refs
-         self.docs_embeds_faiss = docs_embeds_faiss
-         self.questions_embeds_faiss = questions_embeds_faiss
        self.optimal_params = {
            'НКРФ': {'thresh': 0.58421, 'sim_factor': 0.89474},
            'ГКРФ': {'thresh': 0.64737, 'sim_factor': 0.89474},
@@ -130,61 +89,33 @@ class SemanticSearch:
            'Судебный документ': {'thresh': 0.67895, 'sim_factor': 0.89474},
            'Внутренний документ': {'thresh': 0.55263, 'sim_factor': 0.84211}
        }
-         self.ref_categories = {
-             'all': 'all',
-             'НКРФ': 'НКРФ',
-             'ГКРФ': 'ГКРФ',
-             'ТКРФ': 'ТКРФ',
-             'Федеральный закон': 'ФЗ',
-             'Судебный документ': 'Суды',
-             'Письмо Минфина': 'Письмо МФ',
-             'Письмо ФНС': 'Письмо ФНС',
-             'Приказ ФНС': 'Приказ ФНС',
-             'Постановление Правительства': 'Пост. Прав.',
-             'Внутренний документ': 'Внутр. док.'
-         }
-
-     def test_search(self):
-         topk = len(self.filtered_db_data)
-         pred_raw = {}
-         true = {}
-         all_distances = []
-         for idx, (q_embed, refs) in enumerate(zip(self.questions_embeds_faiss,
-                                                   self.val_refs.values())):
-             distances, indices = self.index.search(np.expand_dims(q_embed, 0), topk)
-             pred_raw[idx] = [self.index_keys[x] for x in indices[0]]
-             true[idx] = list(refs)
-             all_distances.append(distances)
+         self.index_type = index_type
+         self.index_docs = indexes[self.index_type]
+         self.load_data()
+         self.docs_embeddings = [torch.unsqueeze(torch.Tensor(x['doc_embedding']), 0) for x in self.all_docs_info]
+         self.docs_embeddings = torch.cat(self.docs_embeddings, dim=0)
+         self.index_docs.add(self.docs_embeddings)
+         # self.preproces_data()
+         # self.test_search()

-         pred = {}
-         for idx, p, d in zip(true.keys(), pred_raw.values(), all_distances):
-             fp, fs = self.search_results_filtering(p, d[0])
-             pred[idx] = fp
+     def load_data(self):
+         self.all_docs_info = dataset.load_from_disk(data_path_consult).to_list()
+         self.docs_names = [doc['doc_name'] for doc in self.all_docs_info]
+         self.mean_refs_count = {'Письмо Минфина': 3,
+                                 'Письмо ФНС': 2,
+                                 'Судебный документ': 3}

-         # uncomment what you need. If everything is commented out, the metrics
-         # are computed "as is", taking the full hierarchy into account
-         filter_parts = [
-             # "абз.",
-             # "пп.",
-             # "п."
-         ]
-         filtered_pred = filter_ref_parts(pred, filter_parts)
-         filtered_true = filter_ref_parts(true, filter_parts)
+     def load_model(self):
+         if hf_token and hf_model_name:
+             self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=True)
+             self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=True).to(self.device)
+         else:
+             self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
+             self.model = AutoModel.from_pretrained(global_model_path).to(self.device)

-         metrics_func_params = {
-             'compensate_div_0': True,
-             'dynamic_topk': True,
-             'skip_empty_trues': False,
-             'skip_empty_preds': False
-         }
-         metrics = get_final_metrics(filtered_pred, filtered_true,
-                                     self.ref_categories.keys(), [0],
-                                     metrics_func=calculate_metrics_at_k,
-                                     metrics_func_params=metrics_func_params)
-
-         print_metrics(metrics, self.ref_categories)
+         self.max_len = self.tokenizer.max_len_single_sentence
+         self.embedding_dim = self.model.config.hidden_size

-
    def search_results_filtering(self, pred, dists):
        all_ctg_preds = []
        all_scores = []
@@ -192,7 +123,7 @@ class SemanticSearch:
            ctg_thresh = self.optimal_params[ctg]['thresh']
            ctg_sim_factor = self.optimal_params[ctg]['sim_factor']

-             ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
+             ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
                          if ctg in ref and dist > ctg_thresh]

            sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
@@ -222,48 +153,39 @@ class SemanticSearch:

        return sorted_preds, sorted_scores

-     def court_docs_shrinking(self, preds, docs):
-         new_preds = []
-         new_docs = []
-
-         for ref_name, ref_text in zip(preds, docs):
-             is_court = 'Судебный документ' in ref_name
-             has_splitter = court_text_splitter in ref_text
-
-             if is_court and has_splitter:
-                 new_ref_text = ref_text.split(court_text_splitter)[0].strip()
-                 new_preds.append(ref_name)
-                 new_docs.append(new_ref_text)
-             else:
-                 new_preds.append(ref_name)
-                 new_docs.append(ref_text)
-         return new_preds, new_docs
-
-     @staticmethod
-     def get_mean_refs_counts(db_data_types, data):
-         mean_refs_count = {}
-         for tp in db_data_types:
-             all_tp_refs = []
-             for doc in data.values():
-                 tp_refs_len = len([ref for ref in doc['added_refs'] if tp in ref])
-                 if tp_refs_len:
-                     all_tp_refs.append(tp_refs_len)
-
-             mean_refs_count[tp] = np.mean(all_tp_refs)
-
-         for k, v in mean_refs_count.items():
-             mean_refs_count[k] = int(v + 1)
-
-         return mean_refs_count
+     def get_most_relevant_teaser(self,
+                                  question: str = None,
+                                  doc_index: int = None):
+         teaser_indexes = {
+             "IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim),
+             "IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim)
+         }
+         teasers_index = teaser_indexes[self.index_type]
+         question_tokens = query_tokenization(question, self.tokenizer)
+         question_embedding = query_embed_extraction(question_tokens, self.model,
+                                                     self.do_normalization)
+         # question_embedding = self.custom_embedder.create_question_embedding(question)
+         teasers_texts = [teaser['summary_text'] for teaser in self.all_docs_info[doc_index]['chunks_embeddings']]
+         teasers_embeddings = [torch.unsqueeze(torch.Tensor(teaser['embedding']), 0) for teaser in
+                               self.all_docs_info[doc_index]['chunks_embeddings']]
+         teasers_embeddings = torch.cat(teasers_embeddings, 0)
+         teasers_index.add(teasers_embeddings)
+         distances, indices = teasers_index.search(question_embedding, 10)
+         most_relevant_teaser = teasers_texts[indices[0][0]]
+         return most_relevant_teaser

    def search(self, query, top=15):
        query_tokens = query_tokenization(query, self.tokenizer)
-         query_embeds = query_embed_extraction(query_tokens, self.model,
+         query_embeds = query_embed_extraction(query_tokens, self.model,
                                              self.do_normalization)
-         distances, indices = self.index.search(query_embeds, len(self.filtered_db_data))
-         pred = [self.index_keys[x] for x in indices[0]]
+         distances, indices = self.index_docs.search(query_embeds, len(self.all_docs_info))
+         pred = [self.all_docs_info[x]['doc_name'] for x in indices[0]]
        preds, scores = self.search_results_filtering(pred, distances[0])
-         docs = [self.filtered_db_data[ref] for ref in preds]
-         preds, docs = self.court_docs_shrinking(preds, docs)
-
+         # docs = [self.all_docs_info[x][ref] for ref in preds]
+         docs = []
+         for ref in preds:
+             doc_index = self.docs_names.index(ref)
+             most_relevant_teaser = self.get_most_relevant_teaser(question=query,
+                                                                  doc_index=doc_index)
+             docs.append(most_relevant_teaser)
        return preds[:top], docs[:top], scores[:top]
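
The rewritten load_data() reads a datasets.Dataset from disk instead of pickled JSON, so the data on disk now has to carry precomputed embeddings. A minimal sketch of a compatible dataset, inferred from the fields the new code accesses (doc_name, doc_embedding, and chunks_embeddings with summary_text and embedding); the 768 dimension and the example texts are placeholders, not values from this repo:

# Hypothetical builder for a dataset shaped the way the new SemanticSearch
# expects; field names mirror the diff, everything else is illustrative.
from datasets import Dataset

dim = 768  # must equal model.config.hidden_size of the embedding model

records = [
    {
        "doc_name": "Судебный документ: example ruling",
        "doc_embedding": [0.0] * dim,            # document-level vector
        "chunks_embeddings": [                   # per-chunk teasers
            {"summary_text": "teaser for chunk 1", "embedding": [0.1] * dim},
            {"summary_text": "teaser for chunk 2", "embedding": [0.2] * dim},
        ],
    },
]

Dataset.from_list(records).save_to_disk(
    "legal_info_search_data/court_dataset_chunk_200_correct_tokenizer_for_develop"
)

With data in that shape, SemanticSearch() builds one FAISS index over doc_embedding at startup, and search() then builds a small per-document index over chunks_embeddings to pick the teaser returned alongside each document name and score.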