import os
import json
import pickle

import numpy as np
import torch
import faiss
from transformers import AutoTokenizer, AutoModel

from legal_info_search_utils.utils import get_subsets_for_db, get_subsets_for_qa
from legal_info_search_utils.utils import filter_db_data_types, filter_qa_data_types
from legal_info_search_utils.utils import db_tokenization, qa_tokenization
from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
from legal_info_search_utils.utils import print_metrics, get_final_metrics
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
from legal_info_search_utils.metrics import calculate_metrics_at_k

global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
                                   "e5_large_rus_finetuned_20240120_122822_ep6")

# annotated consultations
data_path_consult = os.environ.get("DATA_PATH_CONSULT",
                                   global_data_path + "data_jsons_20240119.pkl")

# consultation ids split into train / valid / test
data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
                                       global_data_path + "data_ids.json")

# which subsets make up the database
# $ export DB_SUBSETS='["train", "valid", "test"]'
db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"])
if isinstance(db_subsets, str):
    db_subsets = json.loads(db_subsets)

# Document type selection. List the types that should be kept in the DB.
# $ export DB_DATA_TYPES='["НКРФ", "ГКРФ", "ТКРФ"]'
db_data_types = os.environ.get("DB_DATA_TYPES", [
    'НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина', 'Письмо ФНС',
    'Приказ ФНС', 'Постановление Правительства'
])
if isinstance(db_data_types, str):
    db_data_types = json.loads(db_data_types)

device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')

# Hugging Face access token. If set (together with HF_MODEL_NAME), the model is loaded from the HF Hub.
hf_token = os.environ.get("HF_TOKEN", "")
hf_model_name = os.environ.get("HF_MODEL_NAME", "")


class SemanticSearch:
    def __init__(self, index_type="IndexFlatIP", do_embedding_norm=True,
                 faiss_batch_size=8, do_normalization=True):
        self.device = device
        self.do_embedding_norm = do_embedding_norm
        self.faiss_batch_size = faiss_batch_size
        self.do_normalization = do_normalization
        self.load_model()

        indexes = {
            "IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim),
            "IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim)
        }
        self.index = indexes[index_type]

        self.load_data()
        self.preprocess_data()
        self.test_search()

    def load_data(self):
        with open(data_path_consult, "rb") as f:
            all_docs = pickle.load(f)

        with open(data_path_consult_ids, "r", encoding="utf-8") as f:
            data_ids = json.load(f)

        db_data = get_subsets_for_db(db_subsets, data_ids, all_docs)
        filtered_all_docs = filter_qa_data_types(db_data_types, all_docs)

        self.filtered_db_data = filter_db_data_types(db_data_types, db_data)
        self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs)

    def load_model(self):
        if hf_token and hf_model_name:
            self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name,
                                                           use_auth_token=hf_token)
            self.model = AutoModel.from_pretrained(hf_model_name,
                                                   use_auth_token=hf_token).to(self.device)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
            self.model = AutoModel.from_pretrained(global_model_path).to(self.device)

        self.max_len = self.tokenizer.max_len_single_sentence
        self.embedding_dim = self.model.config.hidden_size

    def preprocess_data(self):
        index_keys, index_toks = db_tokenization(self.filtered_db_data, self.tokenizer)
        val_questions, val_refs = qa_tokenization(self.all_docs_qa, self.tokenizer)
        docs_embeds_faiss, questions_embeds_faiss = extract_text_embeddings(
            index_toks, val_questions,
            self.model, self.do_normalization, self.faiss_batch_size)

        self.index.add(docs_embeds_faiss)

        self.index_keys = index_keys
        self.index_toks = index_toks
        self.val_questions = val_questions
        self.val_refs = val_refs
        self.docs_embeds_faiss = docs_embeds_faiss
        self.questions_embeds_faiss = questions_embeds_faiss

        # per-category filtering parameters: similarity threshold, relative
        # similarity factor and difference order used in search_results_filtering
        self.optimal_params = {
            'НКРФ': {'thresh': 0.613793, 'sim_factor': 0.878947, 'diff_n': 0},
            'ГКРФ': {'thresh': 0.758620, 'sim_factor': 0.878947, 'diff_n': 0},
            'ТКРФ': {'thresh': 0.734482, 'sim_factor': 0.9, 'diff_n': 0},
            'Федеральный закон': {'thresh': 0.734482, 'sim_factor': 0.5, 'diff_n': 0},
            'Письмо Минфина': {'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0},
            'Письмо ФНС': {'thresh': 0.879310, 'sim_factor': 0.5, 'diff_n': 0},
            'Приказ ФНС': {'thresh': 0.806896, 'sim_factor': 0.5, 'diff_n': 0},
            'Постановление Правительства': {'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0}
        }

        self.ref_categories = {
            'all': 'all',
            'НКРФ': 'НКРФ',
            'ГКРФ': 'ГКРФ',
            'ТКРФ': 'ТКРФ',
            'Федеральный закон': 'ФЗ',
            'Суды': 'Суды',
            'Письмо Минфина': 'Письмо МФ',
            'Письмо ФНС': 'Письмо ФНС',
            'Приказ ФНС': 'Приказ ФНС',
            'Постановление Правительства': 'Пост. Прав.'
        }

    def test_search(self):
        topk = len(self.filtered_db_data)
        pred_raw = {}
        true = {}
        all_distances = []
        for idx, (q_embed, refs) in enumerate(zip(self.questions_embeds_faiss,
                                                  self.val_refs.values())):
            distances, indices = self.index.search(np.expand_dims(q_embed, 0), topk)
            pred_raw[idx] = [self.index_keys[x] for x in indices[0]]
            true[idx] = list(refs)
            all_distances.append(distances)

        pred = {}
        for idx, p, d in zip(true.keys(), pred_raw.values(), all_distances):
            fp, fs = self.search_results_filtering(p, d[0])
            pred[idx] = fp

        # Uncomment the parts that should be stripped from the references.
        # If everything is commented out, the metrics are computed "as is",
        # taking the full reference hierarchy into account.
        filter_parts = [
            # "абз.",
            # "пп.",
            # "п."
        ]
        filtered_pred = filter_ref_parts(pred, filter_parts)
        filtered_true = filter_ref_parts(true, filter_parts)

        metrics = get_final_metrics(filtered_pred, filtered_true,
                                    self.ref_categories.keys(), [0],
                                    metrics_func=calculate_metrics_at_k,
                                    dynamic_topk=True)
        print_metrics(metrics, self.ref_categories)

    def search_results_filtering(self, pred, dists):
        all_ctg_preds = []
        all_scores = []
        for ctg in db_data_types:
            ctg_thresh = self.optimal_params[ctg]['thresh']
            ctg_sim_factor = self.optimal_params[ctg]['sim_factor']
            ctg_diff_n = self.optimal_params[ctg]['diff_n']

            # keep only hits of this category that pass the similarity threshold
            ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
                         if ctg in ref and dist > ctg_thresh]
            sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
            sorted_preds = [x[0] for x in sorted_pd]
            sorted_dists = [x[1] for x in sorted_pd]

            if len(sorted_dists):
                diffs = np.diff(sorted_dists, ctg_diff_n)
                if len(diffs):
                    n_preds = np.argmax(diffs) + ctg_diff_n + 1
                else:
                    n_preds = 0

                # keep hits whose score is close enough to the best hit
                if len(sorted_dists) > 1:
                    ratios = (np.asarray(sorted_dists[1:]) / sorted_dists[0]) >= ctg_sim_factor
                    ratios = np.array([True, *ratios])
                else:
                    ratios = np.array([True])

                main_preds = np.array(sorted_preds)[np.where(ratios)].tolist()
                scores = np.array(sorted_dists)[np.where(ratios)].tolist()

                if ctg_diff_n > 0 and n_preds > 0:
                    main_preds = main_preds[:n_preds]
                    scores = scores[:n_preds]
            else:
                main_preds = []
                scores = []

            all_ctg_preds.extend(main_preds)
            all_scores.extend(scores)

        sorted_values = [(ref, score) for ref, score in zip(all_ctg_preds, all_scores)]
        sorted_values = sorted(sorted_values, key=lambda x: x[1], reverse=True)
        sorted_preds = [x[0] for x in sorted_values]
        sorted_scores = [x[1] for x in sorted_values]

        return sorted_preds, sorted_scores

    def search(self, query, top=10):
        query_tokens = query_tokenization(query, self.tokenizer)
        query_embeds = query_embed_extraction(query_tokens, self.model,
                                              self.do_normalization)
        distances, indices = self.index.search(query_embeds, len(self.filtered_db_data))
        pred = [self.index_keys[x] for x in indices[0]]
        preds, scores = self.search_results_filtering(pred, distances[0])
        docs = [self.filtered_db_data[ref] for ref in preds]

        return preds[:top], docs[:top], scores[:top]
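
# Usage sketch, assuming the consultation pickle, the id-split json and the
# fine-tuned model configured via the environment variables above are available
# locally. Instantiating SemanticSearch loads the model, embeds the document
# base into the FAISS index and prints validation metrics via test_search().
# The query string below is only a hypothetical example.
if __name__ == "__main__":
    searcher = SemanticSearch(index_type="IndexFlatIP")
    refs, docs, scores = searcher.search("налоговый вычет при покупке квартиры", top=5)
    for ref, score in zip(refs, scores):
        print(f"{score:.3f}\t{ref}")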