# NOTE(review): removed paste artifacts ("Spaces:", "Runtime error" x2) that
# preceded the module and made the file unparseable.
import json
import os

import faiss
import numpy as np
import requests
import torch
from datasets import Dataset as dataset
from transformers import AutoTokenizer, AutoModel

from legal_info_search_utils.utils import query_tokenization, query_embed_extraction

# Base locations for data and the fine-tuned embedding model (env-overridable).
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
global_model_path = os.environ.get(
    "GLOBAL_MODEL_PATH",
    "legal_info_search_model/20240202_204910_ep8/")

# Annotated consultations: a HuggingFace `datasets` dataset saved on disk
# (loaded later via `dataset.load_from_disk`).
data_path_consult = os.environ.get(
    "DATA_PATH_CONSULT",
    global_data_path + "all_docs_dataset_chunk_200_correct_tokenizer_for_develop_new_chunking")

# Document types to keep in the search DB. May be overridden with a JSON list:
#   $ export DB_DATA_TYPES='["НКРФ", "ГКРФ", "ТКРФ"]'
# Bug fix: the original returned the raw env string unparsed, so an override
# would have been iterated character-by-character downstream.
_db_data_types_env = os.environ.get("DB_DATA_TYPES", "")
db_data_types = json.loads(_db_data_types_env) if _db_data_types_env else [
    'НКРФ',
    'ГКРФ',
    'ТКРФ',
    'Федеральный закон',
    'Письмо Минфина',
    'Письмо ФНС',
    'Приказ ФНС',
    'Постановление Правительства',
    'Судебный документ',
    'Внутренний документ'
]

device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')

# HuggingFace access token / model name. If both are set, the model is pulled
# from the Hub instead of the local `global_model_path`.
hf_token = os.environ.get("HF_TOKEN", "")
hf_model_name = os.environ.get("HF_MODEL_NAME", "")

# Prompt sent to the external LLM when rewriting a teaser. Kept verbatim in
# Russian: it is runtime behavior, not a comment. (Plain strings — the
# original used f-strings with no placeholders.)
default_teaser_refactoring_prompt = (
    "Ты профессиональный налоговый консультант. У тебя есть вопрос и отрывок документа. "
    "Тебе нужно переделать этот отрывок так, чтобы в нём содержался ответ на вопрос "
    "и по возможности было название документа, "
    "а также описывался конкретный кейс, если он был в отрывке."
    "Ты больше ничего не пишешь, не рассуждаешь. "
    "Отвечай только на РУССКОМ языке.")
class SemanticSearch:
    """Dense semantic search over a legal-document database.

    Loads precomputed document embeddings from disk, builds a FAISS index
    over them, and answers free-text queries with the most relevant document
    names, relevance scores, and a per-document "teaser" snippet (optionally
    rewritten by an external LLM service).
    """

    def __init__(self,
                 index_type="IndexFlatIP",
                 do_embedding_norm=True,
                 do_normalization=True):
        """Build the document index and load the tokenizer/encoder model.

        Args:
            index_type: FAISS index flavour — "IndexFlatIP" (inner product)
                or "IndexFlatL2" (euclidean).
            do_embedding_norm: stored only; not read elsewhere in this file.
            do_normalization: forwarded to `query_embed_extraction` when
                embedding queries.
        """
        self.device = device
        self.do_embedding_norm = do_embedding_norm
        self.do_normalization = do_normalization
        self.load_model()
        indexes = {
            "IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim),
            "IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim)
        }
        # Per-category relevance threshold and similarity-ratio factor.
        # Values look grid-searched offline; tuning source is not in this file.
        self.optimal_params = {
            'НКРФ': {'thresh': 0.58421, 'sim_factor': 0.89474},
            'ГКРФ': {'thresh': 0.64737, 'sim_factor': 0.89474},
            'ТКРФ': {'thresh': 0.58421, 'sim_factor': 0.94737},
            'Федеральный закон': {'thresh': 0.58421, 'sim_factor': 0.94737},
            'Письмо Минфина': {'thresh': 0.71053, 'sim_factor': 0.84211},
            'Письмо ФНС': {'thresh': 0.64737, 'sim_factor': 0.94737},
            'Приказ ФНС': {'thresh': 0.42632, 'sim_factor': 0.84211},
            'Постановление Правительства': {'thresh': 0.58421, 'sim_factor': 0.94737},
            'Судебный документ': {'thresh': 0.67895, 'sim_factor': 0.89474},
            'Внутренний документ': {'thresh': 0.55263, 'sim_factor': 0.84211}
        }
        self.index_type = index_type
        self.index_docs = indexes[self.index_type]
        self.load_data()
        self.docs_embeddings = torch.cat(
            [torch.unsqueeze(torch.Tensor(x['doc_embedding']), 0)
             for x in self.all_docs_info],
            dim=0)
        # Bug fix: faiss.Index.add expects a float32 numpy array; the
        # original passed a torch.Tensor directly.
        self.index_docs.add(self.docs_embeddings.numpy())

    def rebuild_teaser_with_llm(self, question: str = None,
                                teaser: str = None) -> str:
        """Rewrite `teaser` via an external LLM endpoint so it answers
        `question`; returns the raw response text.

        Bug fix: `self` was missing from the signature, so the call
        `self.rebuild_teaser_with_llm(question=..., teaser=...)` in
        `search` raised TypeError (multiple values for 'question').
        """
        teaser_refactoring_prompt = default_teaser_refactoring_prompt + (
            f"\nВопрос: {question}"
            f"\nОтрывок: {teaser}")
        # NOTE(review): no timeout/error handling — a slow or failing remote
        # service propagates to the caller; confirm this is intended.
        response = requests.post(
            url='https://muryshev-mixtral-api.hf.space/completion',
            json={"prompt": f"[INST]{teaser_refactoring_prompt}[/INST]"})
        return response.text

    def load_data(self):
        """Load the on-disk documents dataset and derived lookups."""
        self.all_docs_info = dataset.load_from_disk(data_path_consult).to_list()
        self.docs_names = [doc['doc_name'] for doc in self.all_docs_info]
        # Per-category cap on the number of references returned; categories
        # absent here are uncapped (see search_results_filtering).
        self.mean_refs_count = {'Письмо Минфина': 3,
                                'Письмо ФНС': 2,
                                'Судебный документ': 3}

    def load_model(self):
        """Load tokenizer + encoder, from the HF Hub when credentials are
        configured, otherwise from the local model path."""
        if hf_token and hf_model_name:
            self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=True)
            self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=True).to(self.device)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
            self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
        self.max_len = self.tokenizer.max_len_single_sentence
        self.embedding_dim = self.model.config.hidden_size

    def search_results_filtering(self, pred, dists):
        """Filter raw search hits per document category.

        For each category: keep hits above the category threshold, then keep
        only hits whose score is within `sim_factor` of the category's best
        hit, then cap the count for capped categories. Returns
        (names, scores) sorted by descending score.

        Args:
            pred: document names, index-aligned with `dists`.
            dists: similarity scores from the FAISS search.
        """
        all_ctg_preds = []
        all_scores = []
        for ctg in db_data_types:
            ctg_thresh = self.optimal_params[ctg]['thresh']
            ctg_sim_factor = self.optimal_params[ctg]['sim_factor']
            ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
                         if ctg in ref and dist > ctg_thresh]
            sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
            sorted_preds = [x[0] for x in sorted_pd]
            sorted_dists = [x[1] for x in sorted_pd]
            if sorted_dists:
                dists_arr = np.array(sorted_dists)
                # Bug fix: the original divided a Python list by a float
                # (`sorted_dists[1:] / sorted_dists[0]`) — TypeError.
                keep = np.ones(len(dists_arr), dtype=bool)
                if len(dists_arr) > 1:
                    keep[1:] = (dists_arr[1:] / dists_arr[0]) >= ctg_sim_factor
                main_preds = np.array(sorted_preds)[keep].tolist()
                scores = dists_arr[keep].tolist()
            else:
                main_preds = []
                scores = []
            # Bug fix: direct indexing raised KeyError for the 7 categories
            # without a configured cap; .get -> None means "no cap" (a
            # [:None] slice keeps everything).
            limit = self.mean_refs_count.get(ctg)
            all_ctg_preds.extend(main_preds[:limit])
            all_scores.extend(scores[:limit])
        sorted_values = sorted(zip(all_ctg_preds, all_scores),
                               key=lambda x: x[1], reverse=True)
        sorted_preds = [x[0] for x in sorted_values]
        sorted_scores = [x[1] for x in sorted_values]
        return sorted_preds, sorted_scores

    def get_most_relevant_teaser(self,
                                 question: str = None,
                                 doc_index: int = None):
        """Return the chunk ("teaser") of document `doc_index` closest to
        `question`, using a temporary FAISS index over the doc's chunks."""
        teaser_indexes = {
            "IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim),
            "IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim)
        }
        teasers_index = teaser_indexes[self.index_type]
        question_tokens = query_tokenization(question, self.tokenizer)
        question_embedding = query_embed_extraction(question_tokens, self.model,
                                                    self.do_normalization)
        chunks = self.all_docs_info[doc_index]['chunks_embeddings']
        teasers_texts = [teaser['summary_text'] for teaser in chunks]
        teasers_embeddings = torch.cat(
            [torch.unsqueeze(torch.Tensor(teaser['embedding']), 0)
             for teaser in chunks],
            0)
        # Bug fix: faiss expects a numpy array, not a torch.Tensor.
        teasers_index.add(teasers_embeddings.numpy())
        distances, indices = teasers_index.search(question_embedding, 10)
        most_relevant_teaser = teasers_texts[indices[0][0]]
        return most_relevant_teaser

    def search(self, query, top=15, use_llm_for_teasers: bool = False):
        """Run a full query: embed, rank all docs, filter per category,
        and build a teaser per returned document.

        Args:
            query: free-text user question.
            top: maximum number of results returned.
            use_llm_for_teasers: when True, each teaser is rewritten by the
                external LLM service (one HTTP call per result).

        Returns:
            (doc_names, teasers, scores), each truncated to `top`.
        """
        query_tokens = query_tokenization(query, self.tokenizer)
        query_embeds = query_embed_extraction(query_tokens, self.model,
                                              self.do_normalization)
        # Rank every document; category-level filtering happens afterwards.
        distances, indices = self.index_docs.search(query_embeds, len(self.all_docs_info))
        pred = [self.all_docs_info[x]['doc_name'] for x in indices[0]]
        preds, scores = self.search_results_filtering(pred, distances[0])
        docs = []
        for ref in preds:
            doc_index = self.docs_names.index(ref)
            most_relevant_teaser = self.get_most_relevant_teaser(question=query,
                                                                 doc_index=doc_index)
            if use_llm_for_teasers:
                most_relevant_teaser = self.rebuild_teaser_with_llm(question=query,
                                                                    teaser=most_relevant_teaser)
            docs.append(most_relevant_teaser)
        return preds[:top], docs[:top], scores[:top]