import os # import json import torch # import pickle import numpy as np import faiss from datasets import Dataset as dataset from transformers import AutoTokenizer, AutoModel from legal_info_search_utils.utils import query_tokenization, query_embed_extraction import requests import re global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/") global_model_path = os.environ.get("GLOBAL_MODEL_PATH", "legal_info_search_model/20240202_204910_ep8/") # размеченные консультации # data_path_consult = os.environ.get("DATA_PATH_CONSULT", # global_data_path + "data_jsons_20240202.pkl") data_path_consult = os.environ.get("DATA_PATH_CONSULT", global_data_path + "all_docs_dataset_chunk_200_correct_tokenizer_for_develop_new_chunking") # id консультаций, побитые на train / valid / test # data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS", # global_data_path + "data_ids.json") # предобработанные внутренние документы # data_path_internal_docs = os.environ.get("DATA_PATH_INTERNAL_DOCS", # global_data_path + "internal_docs.json") # состав БД # $ export DB_SUBSETS='["train", "valid", "test"]' # db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"]) # Отбор типов документов. В списке указать те, которые нужно оставить в БД. # $ export DB_DATA_TYPES='["НКРФ", "ГКРФ", "ТКРФ"]' teaser_pattern = r"Переделанный отрывок:" db_data_types = os.environ.get("DB_DATA_TYPES", [ 'НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина', 'Письмо ФНС', 'Приказ ФНС', 'Постановление Правительства', 'Судебный документ', 'Внутренний документ' ]) device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu') # access token huggingface. Если задан, то используется модель с HF hf_token = os.environ.get("HF_TOKEN", "") hf_model_name = os.environ.get("HF_MODEL_NAME", "") default_teaser_refactoring_prompt = ( f"Ты профессиональный налоговый консультант. У тебя есть вопрос и отрывок документа. " f"Тебе нужно переделать этот отрывок так, чтобы в нём содержался ответ на вопрос, но в общих чертах и без конкретики, " f"а также упоминалась конкретная ситуация, если она была в отрывке. " f"В результате ты должен получить короткий отрывок, в котором есть только самая важная информация из исходного отрывка, а также названия документов из исходного отрывка. Тебе запрещается как-либо искажать исходную информацию. " f"Используй такой формат ответа: 'Переделанный отрывок: *переделанный отрывок*'. " f"Переделанный отрывок должен быть объёмом не более 50 слов. " f"Ты больше ничего не пишешь, не рассуждаешь. " f"Отвечай обязательно только на РУССКОМ языке!!!.") class SemanticSearch: def __init__(self, index_type="IndexFlatIP", do_embedding_norm=True, do_normalization=True): self.device = device self.do_embedding_norm = do_embedding_norm self.do_normalization = do_normalization self.load_model() indexes = { "IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim), "IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim) } self.optimal_params = { 'НКРФ': {'thresh': 0.58421, 'sim_factor': 0.89474}, 'ГКРФ': {'thresh': 0.64737, 'sim_factor': 0.89474}, 'ТКРФ': {'thresh': 0.58421, 'sim_factor': 0.94737}, 'Федеральный закон': {'thresh': 0.58421, 'sim_factor': 0.94737}, 'Письмо Минфина': {'thresh': 0.71053, 'sim_factor': 0.84211}, 'Письмо ФНС': {'thresh': 0.64737, 'sim_factor': 0.94737}, 'Приказ ФНС': {'thresh': 0.42632, 'sim_factor': 0.84211}, 'Постановление Правительства': {'thresh': 0.58421, 'sim_factor': 0.94737}, 'Судебный документ': {'thresh': 0.67895, 'sim_factor': 0.89474}, 'Внутренний документ': {'thresh': 0.55263, 'sim_factor': 0.84211} } self.index_type = index_type self.index_docs = indexes[self.index_type] self.load_data() self.docs_embeddings = [torch.unsqueeze(torch.Tensor(x['doc_embedding']), 0) for x in self.all_docs_info] self.docs_embeddings = torch.cat(self.docs_embeddings, dim=0) self.index_docs.add(self.docs_embeddings) @staticmethod def rebuild_teaser_with_llm(question: str = None, teaser: str = None) -> str: teaser_refactoring_prompt = default_teaser_refactoring_prompt + (f"\nВопрос: {question}" f"\nОтрывок: {teaser}") response = requests.post(url='https://muryshev-mixtral-api.hf.space/completion', json={"prompt": f"[INST]{teaser_refactoring_prompt}[/INST]"}) rebuilded_teaser = re.sub(teaser_pattern, '', response.text) return rebuilded_teaser def load_data(self): self.all_docs_info = dataset.load_from_disk(data_path_consult).to_list() self.docs_names = [doc['doc_name'] for doc in self.all_docs_info] self.mean_refs_count = {'НКРФ': 4, 'ГКРФ': 3, 'ТКРФ': 2, 'Федеральный закон': 2, 'Письмо Минфина': 3, 'Письмо ФНС': 2, 'Приказ ФНС': 2, 'Постановление Правительства': 2, 'Судебный документ': 3, 'Внутренний документ': 2} def load_model(self): if hf_token and hf_model_name: self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=True) self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=True).to(self.device) else: self.tokenizer = AutoTokenizer.from_pretrained(global_model_path) self.model = AutoModel.from_pretrained(global_model_path).to(self.device) self.max_len = self.tokenizer.max_len_single_sentence self.embedding_dim = self.model.config.hidden_size def search_results_filtering(self, pred, dists): all_ctg_preds = [] all_scores = [] for ctg in db_data_types: ctg_thresh = self.optimal_params[ctg]['thresh'] ctg_sim_factor = self.optimal_params[ctg]['sim_factor'] ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists) if ctg in ref and dist > ctg_thresh] sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True) sorted_preds = [x[0] for x in sorted_pd] sorted_dists = [x[1] for x in sorted_pd] if len(sorted_dists): if len(sorted_dists) > 1: ratios = (sorted_dists[1:] / sorted_dists[0]) >= ctg_sim_factor ratios = np.array([True, *ratios]) else: ratios = np.array([True]) main_preds = np.array(sorted_preds)[np.where(ratios)].tolist() scores = np.array(sorted_dists)[np.where(ratios)].tolist() else: main_preds = [] scores = [] all_ctg_preds.extend(main_preds[:self.mean_refs_count[ctg]]) all_scores.extend(scores[:self.mean_refs_count[ctg]]) sorted_values = [(ref, score) for ref, score in zip(all_ctg_preds, all_scores)] sorted_values = sorted(sorted_values, key=lambda x: x[1], reverse=True) sorted_preds = [x[0] for x in sorted_values] sorted_scores = [x[1] for x in sorted_values] return sorted_preds, sorted_scores def get_most_relevant_teaser(self, question: str = None, doc_index: int = None): teaser_indexes = { "IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim), "IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim) } teasers_index = teaser_indexes[self.index_type] question_tokens = query_tokenization(question, self.tokenizer) question_embedding = query_embed_extraction(question_tokens, self.model, self.do_normalization) teasers_texts = [teaser['summary_text'] for teaser in self.all_docs_info[doc_index]['chunks_embeddings']] teasers_embeddings = [torch.unsqueeze(torch.Tensor(teaser['embedding']), 0) for teaser in self.all_docs_info[doc_index]['chunks_embeddings']] teasers_embeddings = torch.cat(teasers_embeddings, 0) teasers_index.add(teasers_embeddings) distances, indices = teasers_index.search(question_embedding, 10) most_relevant_teaser = teasers_texts[indices[0][0]] return most_relevant_teaser def search(self, query, top=15, use_llm_for_teasers: bool = False): query_tokens = query_tokenization(query, self.tokenizer) query_embeds = query_embed_extraction(query_tokens, self.model, self.do_normalization) distances, indices = self.index_docs.search(query_embeds, len(self.all_docs_info)) pred = [self.all_docs_info[x]['doc_name'] for x in indices[0]] preds, scores = self.search_results_filtering(pred, distances[0]) teasers = [] docs = [] for ref in preds: doc_index = self.docs_names.index(ref) doc_text = self.all_docs_info[doc_index]['doc_text'] docs.append(doc_text) # Add the relevant document text most_relevant_teaser = self.get_most_relevant_teaser(question=query, doc_index=doc_index) if use_llm_for_teasers: most_relevant_teaser = self.rebuild_teaser_with_llm(question=query, teaser=most_relevant_teaser) teasers.append(most_relevant_teaser) # Add the most relevant teaser return preds[:top], docs[:top], teasers[:top], scores[:top]