Spaces:
Runtime error
Runtime error
import os | |
import json | |
import torch | |
import pickle | |
import numpy as np | |
import faiss | |
from transformers import AutoTokenizer, AutoModel | |
from legal_info_search_utils.utils import get_subsets_for_db, get_subsets_for_qa | |
from legal_info_search_utils.utils import filter_db_data_types, filter_qa_data_types | |
from legal_info_search_utils.utils import db_tokenization, qa_tokenization | |
from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts | |
from legal_info_search_utils.utils import print_metrics, get_final_metrics | |
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction | |
from legal_info_search_utils.metrics import calculate_metrics_at_k | |
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/") | |
global_model_path = os.environ.get("GLOBAL_MODEL_PATH", "e5_large_rus_finetuned_20240120_122822_ep6") | |
# размеченные консультации | |
data_path_consult = os.environ.get("DATA_PATH_CONSULT", | |
global_data_path + "data_jsons_20240119.pkl") | |
# id консультаций, побитые на train / valid / test | |
data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS", | |
global_data_path + "data_ids.json") | |
# состав БД | |
# $ export DB_SUBSETS='["train", "valid", "test"]' | |
db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"]) | |
# Отбор типов документов. В списке указать те, которые нужно оставить в БД. | |
# $ export DB_DATA_TYPES='["НКРФ", "ГКРФ", "ТКРФ"]' | |
db_data_types = os.environ.get("DB_DATA_TYPES", [ | |
'НКРФ', | |
'ГКРФ', | |
'ТКРФ', | |
'Федеральный закон', | |
'Письмо Минфина', | |
'Письмо ФНС', | |
'Приказ ФНС', | |
'Постановление Правительства' | |
]) | |
device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu') | |
# access token huggingface. Если задан, то используется модель с HF | |
hf_token = os.environ.get("HF_TOKEN", "") | |
hf_model_name = os.environ.get("HF_MODEL_NAME", "") | |
class SemanticSearch: | |
def __init__(self, index_type="IndexFlatIP", do_embedding_norm=True, | |
faiss_batch_size=8, do_normalization=True): | |
self.device = device | |
self.do_embedding_norm = do_embedding_norm | |
self.faiss_batch_size = faiss_batch_size | |
self.do_normalization = do_normalization | |
self.load_model() | |
indexes = { | |
"IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim), | |
"IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim) | |
} | |
self.index = indexes[index_type] | |
self.load_data() | |
self.preproces_data() | |
self.test_search() | |
def load_data(self): | |
with open(data_path_consult, "rb") as f: | |
all_docs = pickle.load(f) | |
with open(data_path_consult_ids, "r", encoding="utf-8") as f: | |
data_ids = json.load(f) | |
db_data = get_subsets_for_db(db_subsets, data_ids, all_docs) | |
filtered_all_docs = filter_qa_data_types(db_data_types, all_docs) | |
self.filtered_db_data = filter_db_data_types(db_data_types, db_data) | |
self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs) | |
def load_model(self): | |
if hf_token and hf_model_name: | |
self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=True) | |
self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=True).to(self.device) | |
else: | |
self.tokenizer = AutoTokenizer.from_pretrained(global_model_path) | |
self.model = AutoModel.from_pretrained(global_model_path).to(self.device) | |
self.max_len = self.tokenizer.max_len_single_sentence | |
self.embedding_dim = self.model.config.hidden_size | |
def preproces_data(self): | |
index_keys, index_toks = db_tokenization(self.filtered_db_data, self.tokenizer) | |
val_questions, val_refs = qa_tokenization(self.all_docs_qa, self.tokenizer) | |
docs_embeds_faiss, questions_embeds_faiss = extract_text_embeddings(index_toks, | |
val_questions, self.model, self.do_normalization, self.faiss_batch_size) | |
self.index.add(docs_embeds_faiss) | |
self.index_keys = index_keys | |
self.index_toks = index_toks | |
self.val_questions = val_questions | |
self.val_refs = val_refs | |
self.docs_embeds_faiss = docs_embeds_faiss | |
self.questions_embeds_faiss = questions_embeds_faiss | |
self.optimal_params = { | |
'НКРФ': { | |
'thresh': 0.613793, 'sim_factor': 0.878947, 'diff_n': 0}, | |
'ГКРФ': { | |
'thresh': 0.758620, 'sim_factor': 0.878947, 'diff_n': 0}, | |
'ТКРФ': { | |
'thresh': 0.734482, 'sim_factor': 0.9, 'diff_n': 0}, | |
'Федеральный закон': { | |
'thresh': 0.734482, 'sim_factor': 0.5, 'diff_n': 0}, | |
'Письмо Минфина': { | |
'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0}, | |
'Письмо ФНС': { | |
'thresh': 0.879310, 'sim_factor': 0.5, 'diff_n': 0}, | |
'Приказ ФНС': { | |
'thresh': 0.806896, 'sim_factor': 0.5, 'diff_n': 0}, | |
'Постановление Правительства': { | |
'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0} | |
} | |
self.ref_categories = { | |
'all': 'all', | |
'НКРФ': 'НКРФ', | |
'ГКРФ': 'ГКРФ', | |
'ТКРФ': 'ТКРФ', | |
'Федеральный закон': 'ФЗ', | |
'Суды': 'Суды', | |
'Письмо Минфина': 'Письмо МФ', | |
'Письмо ФНС': 'Письмо ФНС', | |
'Приказ ФНС': 'Приказ ФНС', | |
'Постановление Правительства': 'Пост. Прав.' | |
} | |
def test_search(self): | |
topk = len(self.filtered_db_data) | |
pred_raw = {} | |
true = {} | |
all_distances = [] | |
for idx, (q_embed, refs) in enumerate(zip(self.questions_embeds_faiss, | |
self.val_refs.values())): | |
distances, indices = self.index.search(np.expand_dims(q_embed, 0), topk) | |
pred_raw[idx] = [self.index_keys[x] for x in indices[0]] | |
true[idx] = list(refs) | |
all_distances.append(distances) | |
pred = {} | |
for idx, p, d in zip(true.keys(), pred_raw.values(), all_distances): | |
fp, fs = self.search_results_filtering(p, d[0]) | |
pred[idx] = fp | |
# раскомментировать нужное. Если всё закомментировано - метрики | |
# посчтаются "как есть", с учетом полной иерархии | |
filter_parts = [ | |
# "абз.", | |
# "пп.", | |
# "п." | |
] | |
filtered_pred = filter_ref_parts(pred, filter_parts) | |
filtered_true = filter_ref_parts(true, filter_parts) | |
metrics = get_final_metrics(filtered_pred, filtered_true, | |
self.ref_categories.keys(), [0], | |
metrics_func=calculate_metrics_at_k, dynamic_topk=True) | |
print_metrics(metrics, self.ref_categories) | |
def search_results_filtering(self, pred, dists): | |
all_ctg_preds = [] | |
all_scores = [] | |
for ctg in db_data_types: | |
ctg_thresh = self.optimal_params[ctg]['thresh'] | |
ctg_sim_factor = self.optimal_params[ctg]['sim_factor'] | |
ctg_diff_n = self.optimal_params[ctg]['diff_n'] | |
ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists) | |
if ctg in ref and dist > ctg_thresh] | |
sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True) | |
sorted_preds = [x[0] for x in sorted_pd] | |
sorted_dists = [x[1] for x in sorted_pd] | |
if len(sorted_dists): | |
diffs = np.diff(sorted_dists, ctg_diff_n) | |
if len(diffs): | |
n_preds = np.argmax(diffs) + ctg_diff_n + 1 | |
else: | |
n_preds = 0 | |
if len(sorted_dists) > 1: | |
ratios = (sorted_dists[1:] / sorted_dists[0]) >= ctg_sim_factor | |
ratios = np.array([True, *ratios]) | |
else: | |
ratios = np.array([True]) | |
main_preds = np.array(sorted_preds)[np.where(ratios)].tolist() | |
scores = np.array(sorted_dists)[np.where(ratios)].tolist() | |
if ctg_diff_n > 0 and n_preds > 0: | |
main_preds = main_preds[:n_preds] | |
scores = scores[:n_preds] | |
else: | |
main_preds = [] | |
scores = [] | |
all_ctg_preds.extend(main_preds) | |
all_scores.extend(scores) | |
sorted_values = [(ref, score) for ref, score in zip(all_ctg_preds, all_scores)] | |
sorted_values = sorted(sorted_values, key=lambda x: x[1], reverse=True) | |
sorted_preds = [x[0] for x in sorted_values] | |
sorted_scores = [x[1] for x in sorted_values] | |
return sorted_preds, sorted_scores | |
def search(self, query, top=10): | |
query_tokens = query_tokenization(query, self.tokenizer) | |
query_embeds = query_embed_extraction(query_tokens, self.model, | |
self.do_normalization) | |
distances, indices = self.index.search(query_embeds, len(self.filtered_db_data)) | |
pred = [self.index_keys[x] for x in indices[0]] | |
preds, scores = self.search_results_filtering(pred, distances[0]) | |
docs = [self.filtered_db_data[ref] for ref in preds] | |
return preds[:top], docs[:top], scores[:top] | |