Spaces:
Runtime error
Runtime error
Updated search
Browse files
legal_info_search_data/internal_docs.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
semantic_search.py
CHANGED
@@ -18,7 +18,7 @@ from legal_info_search_utils.metrics import calculate_metrics_at_k
|
|
18 |
|
19 |
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
|
20 |
global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
|
21 |
-
"legal_info_search_model/
|
22 |
|
23 |
# размеченные консультации
|
24 |
data_path_consult = os.environ.get("DATA_PATH_CONSULT",
|
@@ -28,6 +28,10 @@ data_path_consult = os.environ.get("DATA_PATH_CONSULT",
|
|
28 |
data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
|
29 |
global_data_path + "data_ids.json")
|
30 |
|
|
|
|
|
|
|
|
|
31 |
# состав БД
|
32 |
# $ export DB_SUBSETS='["train", "valid", "test"]'
|
33 |
db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"])
|
@@ -78,20 +82,23 @@ class SemanticSearch:
|
|
78 |
with open(data_path_consult_ids, "r", encoding="utf-8") as f:
|
79 |
data_ids = json.load(f)
|
80 |
|
|
|
|
|
|
|
81 |
db_data = get_subsets_for_db(db_subsets, data_ids, all_docs)
|
82 |
filtered_all_docs = filter_qa_data_types(db_data_types, all_docs)
|
83 |
|
84 |
self.mean_refs_count = self.get_mean_refs_counts(db_data_types, filtered_all_docs)
|
|
|
85 |
self.filtered_db_data = filter_db_data_types(db_data_types, db_data)
|
|
|
86 |
self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs)
|
87 |
|
88 |
def load_model(self):
|
89 |
if hf_token and hf_model_name:
|
90 |
-
print('Using model '+hf_model_name)
|
91 |
self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=True)
|
92 |
self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=True).to(self.device)
|
93 |
else:
|
94 |
-
print('Using model '+global_model_path)
|
95 |
self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
|
96 |
self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
|
97 |
|
|
|
18 |
|
19 |
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
|
20 |
global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
|
21 |
+
"legal_info_search_model/20240202_204910_ep8/")
|
22 |
|
23 |
# размеченные консультации
|
24 |
data_path_consult = os.environ.get("DATA_PATH_CONSULT",
|
|
|
28 |
data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
|
29 |
global_data_path + "data_ids.json")
|
30 |
|
31 |
+
# предобработанные внутренние документы
|
32 |
+
data_path_internal_docs = os.environ.get("DATA_PATH_INTERNAL_DOCS",
|
33 |
+
global_data_path + "internal_docs.json")
|
34 |
+
|
35 |
# состав БД
|
36 |
# $ export DB_SUBSETS='["train", "valid", "test"]'
|
37 |
db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"])
|
|
|
82 |
with open(data_path_consult_ids, "r", encoding="utf-8") as f:
|
83 |
data_ids = json.load(f)
|
84 |
|
85 |
+
with open(data_path_internal_docs, "r", encoding="utf-8") as f:
|
86 |
+
internal_docs = json.load(f)
|
87 |
+
|
88 |
db_data = get_subsets_for_db(db_subsets, data_ids, all_docs)
|
89 |
filtered_all_docs = filter_qa_data_types(db_data_types, all_docs)
|
90 |
|
91 |
self.mean_refs_count = self.get_mean_refs_counts(db_data_types, filtered_all_docs)
|
92 |
+
self.mean_refs_count['Внутренний документ'] = 3
|
93 |
self.filtered_db_data = filter_db_data_types(db_data_types, db_data)
|
94 |
+
self.filtered_db_data.update(internal_docs)
|
95 |
self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs)
|
96 |
|
97 |
def load_model(self):
|
98 |
if hf_token and hf_model_name:
|
|
|
99 |
self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=True)
|
100 |
self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=True).to(self.device)
|
101 |
else:
|
|
|
102 |
self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
|
103 |
self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
|
104 |
|