import os
import json
import re

import faiss
import numpy as np
import requests
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModel

from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
                                   "legal_info_search_model/20240202_204910_ep8/")
# annotated consultations
# data_path_consult = os.environ.get("DATA_PATH_CONSULT",
#                                    global_data_path + "data_jsons_20240202.pkl")
data_path_consult = os.environ.get("DATA_PATH_CONSULT",
                                   global_data_path + "all_docs_dataset_chunk_200_correct_tokenizer_for_develop_new_chunking")

# consultation ids, split into train / valid / test
# data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
#                                        global_data_path + "data_ids.json")

# preprocessed internal documents
# data_path_internal_docs = os.environ.get("DATA_PATH_INTERNAL_DOCS",
#                                          global_data_path + "internal_docs.json")
# database composition
# $ export DB_SUBSETS='["train", "valid", "test"]'
# db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"])

# Marker the LLM is asked to prepend to rewritten teasers; it is stripped
# from responses (see rebuild_teaser_with_llm).
teaser_pattern = r"Переделанный отрывок:"

# Document type selection. List the types that should be kept in the database.
# The env var, when set, holds a JSON list and is decoded below.
# $ export DB_DATA_TYPES='["НКРФ", "ГКРФ", "ТКРФ"]'
_db_data_types_env = os.environ.get("DB_DATA_TYPES")
db_data_types = json.loads(_db_data_types_env) if _db_data_types_env else [
    'НКРФ',
    'ГКРФ',
    'ТКРФ',
    'Федеральный закон',
    'Письмо Минфина',
    'Письмо ФНС',
    'Приказ ФНС',
    'Постановление Правительства',
    'Судебный документ',
    'Внутренний документ'
]
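
# Example (hypothetical): DB_DATA_TYPES='["НКРФ"]' would restrict retrieval
# to Tax Code documents only.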

device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')

# Hugging Face access token. If both the token and a model name are set, the
# model is loaded from the HF Hub instead of the local path.
hf_token = os.environ.get("HF_TOKEN", "")
hf_model_name = os.environ.get("HF_MODEL_NAME", "")
# System prompt for teaser rewriting (kept in Russian by design: the model
# must answer in Russian). It asks the model to condense the passage into a
# rewrite of at most 50 words that answers the question without distorting
# the source, keeping any document names from the original passage.
default_teaser_refactoring_prompt = (
    "Ты профессиональный налоговый консультант. У тебя есть вопрос и отрывок документа. "
    "Тебе нужно переделать этот отрывок так, чтобы в нём содержался ответ на вопрос, но в общих чертах и без конкретики, "
    "а также упоминалась конкретная ситуация, если она была в отрывке. "
    "В результате ты должен получить короткий отрывок, в котором есть только самая важная информация из исходного отрывка, "
    "а также названия документов из исходного отрывка. Тебе запрещается как-либо искажать исходную информацию. "
    "Используй такой формат ответа: 'Переделанный отрывок: *переделанный отрывок*'. "
    "Переделанный отрывок должен быть объёмом не более 50 слов. "
    "Ты больше ничего не пишешь, не рассуждаешь. "
    "Отвечай обязательно только на РУССКОМ языке!!!.")


class SemanticSearch:
    def __init__(self,
                 index_type="IndexFlatIP",
                 do_embedding_norm=True,
                 do_normalization=True):
        self.device = device
        self.do_embedding_norm = do_embedding_norm
        self.do_normalization = do_normalization
        self.load_model()

        indexes = {
            "IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim),
            "IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim)
        }
        # Per-category retrieval parameters tuned offline: `thresh` is the
        # minimum similarity for a hit, `sim_factor` the minimum ratio of a
        # hit's score to the best score in the same category.
        self.optimal_params = {
            'НКРФ': {'thresh': 0.58421, 'sim_factor': 0.89474},
            'ГКРФ': {'thresh': 0.64737, 'sim_factor': 0.89474},
            'ТКРФ': {'thresh': 0.58421, 'sim_factor': 0.94737},
            'Федеральный закон': {'thresh': 0.58421, 'sim_factor': 0.94737},
            'Письмо Минфина': {'thresh': 0.71053, 'sim_factor': 0.84211},
            'Письмо ФНС': {'thresh': 0.64737, 'sim_factor': 0.94737},
            'Приказ ФНС': {'thresh': 0.42632, 'sim_factor': 0.84211},
            'Постановление Правительства': {'thresh': 0.58421, 'sim_factor': 0.94737},
            'Судебный документ': {'thresh': 0.67895, 'sim_factor': 0.89474},
            'Внутренний документ': {'thresh': 0.55263, 'sim_factor': 0.84211}
        }
        self.index_type = index_type
        self.index_docs = indexes[self.index_type]
        self.load_data()

        # faiss expects float32 numpy arrays, so the stacked torch embeddings
        # are converted before being added to the index.
        self.docs_embeddings = [torch.unsqueeze(torch.Tensor(x['doc_embedding']), 0)
                                for x in self.all_docs_info]
        self.docs_embeddings = torch.cat(self.docs_embeddings, dim=0)
        self.index_docs.add(self.docs_embeddings.numpy())

    def rebuild_teaser_with_llm(self,
                                question: str = None,
                                teaser: str = None) -> str:
        teaser_refactoring_prompt = default_teaser_refactoring_prompt + (f"\nВопрос: {question}"
                                                                         f"\nОтрывок: {teaser}")
        response = requests.post(url='https://muryshev-mixtral-api.hf.space/completion',
                                 json={"prompt": f"[INST]{teaser_refactoring_prompt}[/INST]"})
        rebuilt_teaser = re.sub(teaser_pattern, '', response.text)
        return rebuilt_teaser.strip()
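
    # Note: the external /completion endpoint is assumed to return the raw
    # completion as plain text (hence `response.text`), not a JSON envelope.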

    def load_data(self):
        self.all_docs_info = Dataset.load_from_disk(data_path_consult).to_list()
        self.docs_names = [doc['doc_name'] for doc in self.all_docs_info]
        # Average number of references per document type; used as a
        # per-category cap on the number of returned results.
        self.mean_refs_count = {'НКРФ': 4,
                                'ГКРФ': 3,
                                'ТКРФ': 2,
                                'Федеральный закон': 2,
                                'Письмо Минфина': 3,
                                'Письмо ФНС': 2,
                                'Приказ ФНС': 2,
                                'Постановление Правительства': 2,
                                'Судебный документ': 3,
                                'Внутренний документ': 2}

    def load_model(self):
        if hf_token and hf_model_name:
            # Load from the Hugging Face Hub with the provided access token.
            self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, token=hf_token)
            self.model = AutoModel.from_pretrained(hf_model_name, token=hf_token).to(self.device)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
            self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
        self.max_len = self.tokenizer.max_len_single_sentence
        self.embedding_dim = self.model.config.hidden_size

    def search_results_filtering(self, pred, dists):
        all_ctg_preds = []
        all_scores = []
        for ctg in db_data_types:
            ctg_thresh = self.optimal_params[ctg]['thresh']
            ctg_sim_factor = self.optimal_params[ctg]['sim_factor']

            # Keep predictions of this category that pass the similarity threshold.
            ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
                         if ctg in ref and dist > ctg_thresh]
            sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
            sorted_preds = [x[0] for x in sorted_pd]
            sorted_dists = [x[1] for x in sorted_pd]

            if len(sorted_dists):
                if len(sorted_dists) > 1:
                    # Keep hits whose score is close enough to the category's best hit.
                    ratios = (np.array(sorted_dists[1:]) / sorted_dists[0]) >= ctg_sim_factor
                    ratios = np.array([True, *ratios])
                else:
                    ratios = np.array([True])
                main_preds = np.array(sorted_preds)[np.where(ratios)].tolist()
                scores = np.array(sorted_dists)[np.where(ratios)].tolist()
            else:
                main_preds = []
                scores = []

            all_ctg_preds.extend(main_preds[:self.mean_refs_count[ctg]])
            all_scores.extend(scores[:self.mean_refs_count[ctg]])

        # Merge all categories and sort by score, best first.
        sorted_values = sorted(zip(all_ctg_preds, all_scores), key=lambda x: x[1], reverse=True)
        sorted_preds = [x[0] for x in sorted_values]
        sorted_scores = [x[1] for x in sorted_values]
        return sorted_preds, sorted_scores
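
    # Worked example (hypothetical numbers): with sim_factor 0.89474 and
    # category scores [0.80, 0.75, 0.60], the ratios to the best hit are
    # 0.9375 and 0.75, so the first two hits survive and the third is dropped.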

    def get_most_relevant_teaser(self,
                                 question: str = None,
                                 doc_index: int = None):
        teaser_indexes = {
            "IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim),
            "IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim)
        }
        teasers_index = teaser_indexes[self.index_type]

        question_tokens = query_tokenization(question, self.tokenizer)
        question_embedding = query_embed_extraction(question_tokens, self.model,
                                                    self.do_normalization)

        # Build a temporary index over the document's chunk embeddings.
        chunks = self.all_docs_info[doc_index]['chunks_embeddings']
        teasers_texts = [teaser['summary_text'] for teaser in chunks]
        teasers_embeddings = torch.cat(
            [torch.unsqueeze(torch.Tensor(teaser['embedding']), 0) for teaser in chunks], 0)
        teasers_index.add(teasers_embeddings.numpy())

        # Only the single best chunk is used as the teaser.
        distances, indices = teasers_index.search(question_embedding, 1)
        most_relevant_teaser = teasers_texts[indices[0][0]]
        return most_relevant_teaser

    def search(self, query, top=15, use_llm_for_teasers: bool = False):
        query_tokens = query_tokenization(query, self.tokenizer)
        query_embeds = query_embed_extraction(query_tokens, self.model,
                                              self.do_normalization)
        # Score the query against every document, then filter per category.
        distances, indices = self.index_docs.search(query_embeds, len(self.all_docs_info))
        pred = [self.all_docs_info[x]['doc_name'] for x in indices[0]]
        preds, scores = self.search_results_filtering(pred, distances[0])

        teasers = []
        docs = []
        for ref in preds:
            doc_index = self.docs_names.index(ref)
            docs.append(self.all_docs_info[doc_index]['doc_text'])  # relevant document text

            most_relevant_teaser = self.get_most_relevant_teaser(question=query,
                                                                 doc_index=doc_index)
            if use_llm_for_teasers:
                most_relevant_teaser = self.rebuild_teaser_with_llm(question=query,
                                                                    teaser=most_relevant_teaser)
            teasers.append(most_relevant_teaser)  # most relevant teaser per document

        return preds[:top], docs[:top], teasers[:top], scores[:top]
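

# Minimal usage sketch (the query is illustrative; assumes the default model
# and dataset paths above exist on disk):
if __name__ == "__main__":
    searcher = SemanticSearch()
    names, docs, teasers, scores = searcher.search("Каков срок уплаты НДС?", top=5)
    for name, teaser, score in zip(names, teasers, scores):
        print(f"{score:.3f}  {name}\n    {teaser}")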