muryshev committed on
Commit
58de913
1 Parent(s): d84c926

Updated search

Browse files
legal_info_search_data/internal_docs.json ADDED
The diff for this file is too large to render. See raw diff
 
semantic_search.py CHANGED
@@ -18,7 +18,7 @@ from legal_info_search_utils.metrics import calculate_metrics_at_k
18
 
19
  global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
20
  global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
21
- "legal_info_search_model/20240120_122822_ep6/")
22
 
23
  # размеченные консультации
24
  data_path_consult = os.environ.get("DATA_PATH_CONSULT",
@@ -28,6 +28,10 @@ data_path_consult = os.environ.get("DATA_PATH_CONSULT",
28
  data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
29
  global_data_path + "data_ids.json")
30
 
 
 
 
 
31
  # состав БД
32
  # $ export DB_SUBSETS='["train", "valid", "test"]'
33
  db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"])
@@ -78,20 +82,23 @@ class SemanticSearch:
78
  with open(data_path_consult_ids, "r", encoding="utf-8") as f:
79
  data_ids = json.load(f)
80
 
 
 
 
81
  db_data = get_subsets_for_db(db_subsets, data_ids, all_docs)
82
  filtered_all_docs = filter_qa_data_types(db_data_types, all_docs)
83
 
84
  self.mean_refs_count = self.get_mean_refs_counts(db_data_types, filtered_all_docs)
 
85
  self.filtered_db_data = filter_db_data_types(db_data_types, db_data)
 
86
  self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs)
87
 
88
  def load_model(self):
89
  if hf_token and hf_model_name:
90
- print('Using model '+hf_model_name)
91
  self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=True)
92
  self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=True).to(self.device)
93
  else:
94
- print('Using model '+global_model_path)
95
  self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
96
  self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
97
 
 
18
 
19
  global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
20
  global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
21
+ "legal_info_search_model/20240202_204910_ep8/")
22
 
23
  # размеченные консультации
24
  data_path_consult = os.environ.get("DATA_PATH_CONSULT",
 
28
  data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
29
  global_data_path + "data_ids.json")
30
 
31
+ # предобработанные внутренние документы
32
+ data_path_internal_docs = os.environ.get("DATA_PATH_INTERNAL_DOCS",
33
+ global_data_path + "internal_docs.json")
34
+
35
  # состав БД
36
  # $ export DB_SUBSETS='["train", "valid", "test"]'
37
  db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"])
 
82
  with open(data_path_consult_ids, "r", encoding="utf-8") as f:
83
  data_ids = json.load(f)
84
 
85
+ with open(data_path_internal_docs, "r", encoding="utf-8") as f:
86
+ internal_docs = json.load(f)
87
+
88
  db_data = get_subsets_for_db(db_subsets, data_ids, all_docs)
89
  filtered_all_docs = filter_qa_data_types(db_data_types, all_docs)
90
 
91
  self.mean_refs_count = self.get_mean_refs_counts(db_data_types, filtered_all_docs)
92
+ self.mean_refs_count['Внутренний документ'] = 3
93
  self.filtered_db_data = filter_db_data_types(db_data_types, db_data)
94
+ self.filtered_db_data.update(internal_docs)
95
  self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs)
96
 
97
  def load_model(self):
98
  if hf_token and hf_model_name:
 
99
  self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=True)
100
  self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=True).to(self.device)
101
  else:
 
102
  self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
103
  self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
104