dorogan
Update: changes in search API (teasers and docs texts were separated)
ac7cbfc
import os
# import json
import torch
# import pickle
import numpy as np
import faiss
from datasets import Dataset as dataset
from transformers import AutoTokenizer, AutoModel
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
import requests
import re
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
"legal_info_search_model/20240202_204910_ep8/")
# размеченные консультации
# data_path_consult = os.environ.get("DATA_PATH_CONSULT",
# global_data_path + "data_jsons_20240202.pkl")
data_path_consult = os.environ.get("DATA_PATH_CONSULT",
global_data_path + "all_docs_dataset_chunk_200_correct_tokenizer_for_develop_new_chunking")
# id консультаций, побитые на train / valid / test
# data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
# global_data_path + "data_ids.json")
# предобработанные внутренние документы
# data_path_internal_docs = os.environ.get("DATA_PATH_INTERNAL_DOCS",
# global_data_path + "internal_docs.json")
# состав БД
# $ export DB_SUBSETS='["train", "valid", "test"]'
# db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"])
# Отбор типов документов. В списке указать те, которые нужно оставить в БД.
# $ export DB_DATA_TYPES='["НКРФ", "ГКРФ", "ТКРФ"]'
teaser_pattern = r"Переделанный отрывок:"
db_data_types = os.environ.get("DB_DATA_TYPES", [
'НКРФ',
'ГКРФ',
'ТКРФ',
'Федеральный закон',
'Письмо Минфина',
'Письмо ФНС',
'Приказ ФНС',
'Постановление Правительства',
'Судебный документ',
'Внутренний документ'
])
device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
# access token huggingface. Если задан, то используется модель с HF
hf_token = os.environ.get("HF_TOKEN", "")
hf_model_name = os.environ.get("HF_MODEL_NAME", "")
default_teaser_refactoring_prompt = (
f"Ты профессиональный налоговый консультант. У тебя есть вопрос и отрывок документа. "
f"Тебе нужно переделать этот отрывок так, чтобы в нём содержался ответ на вопрос, но в общих чертах и без конкретики, "
f"а также упоминалась конкретная ситуация, если она была в отрывке. "
f"В результате ты должен получить короткий отрывок, в котором есть только самая важная информация из исходного отрывка, а также названия документов из исходного отрывка. Тебе запрещается как-либо искажать исходную информацию. "
f"Используй такой формат ответа: 'Переделанный отрывок: *переделанный отрывок*'. "
f"Переделанный отрывок должен быть объёмом не более 50 слов. "
f"Ты больше ничего не пишешь, не рассуждаешь. "
f"Отвечай обязательно только на РУССКОМ языке!!!.")
class SemanticSearch:
def __init__(self,
index_type="IndexFlatIP",
do_embedding_norm=True,
do_normalization=True):
self.device = device
self.do_embedding_norm = do_embedding_norm
self.do_normalization = do_normalization
self.load_model()
indexes = {
"IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim),
"IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim)
}
self.optimal_params = {
'НКРФ': {'thresh': 0.58421, 'sim_factor': 0.89474},
'ГКРФ': {'thresh': 0.64737, 'sim_factor': 0.89474},
'ТКРФ': {'thresh': 0.58421, 'sim_factor': 0.94737},
'Федеральный закон': {'thresh': 0.58421, 'sim_factor': 0.94737},
'Письмо Минфина': {'thresh': 0.71053, 'sim_factor': 0.84211},
'Письмо ФНС': {'thresh': 0.64737, 'sim_factor': 0.94737},
'Приказ ФНС': {'thresh': 0.42632, 'sim_factor': 0.84211},
'Постановление Правительства': {'thresh': 0.58421, 'sim_factor': 0.94737},
'Судебный документ': {'thresh': 0.67895, 'sim_factor': 0.89474},
'Внутренний документ': {'thresh': 0.55263, 'sim_factor': 0.84211}
}
self.index_type = index_type
self.index_docs = indexes[self.index_type]
self.load_data()
self.docs_embeddings = [torch.unsqueeze(torch.Tensor(x['doc_embedding']), 0) for x in self.all_docs_info]
self.docs_embeddings = torch.cat(self.docs_embeddings, dim=0)
self.index_docs.add(self.docs_embeddings)
@staticmethod
def rebuild_teaser_with_llm(question: str = None,
teaser: str = None) -> str:
teaser_refactoring_prompt = default_teaser_refactoring_prompt + (f"\nВопрос: {question}"
f"\nОтрывок: {teaser}")
response = requests.post(url='https://muryshev-mixtral-api.hf.space/completion',
json={"prompt": f"[INST]{teaser_refactoring_prompt}[/INST]"})
rebuilded_teaser = re.sub(teaser_pattern, '', response.text)
return rebuilded_teaser
def load_data(self):
self.all_docs_info = dataset.load_from_disk(data_path_consult).to_list()
self.docs_names = [doc['doc_name'] for doc in self.all_docs_info]
self.mean_refs_count = {'НКРФ': 4,
'ГКРФ': 3,
'ТКРФ': 2,
'Федеральный закон': 2,
'Письмо Минфина': 3,
'Письмо ФНС': 2,
'Приказ ФНС': 2,
'Постановление Правительства': 2,
'Судебный документ': 3,
'Внутренний документ': 2}
def load_model(self):
if hf_token and hf_model_name:
self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=True)
self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=True).to(self.device)
else:
self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
self.max_len = self.tokenizer.max_len_single_sentence
self.embedding_dim = self.model.config.hidden_size
def search_results_filtering(self, pred, dists):
all_ctg_preds = []
all_scores = []
for ctg in db_data_types:
ctg_thresh = self.optimal_params[ctg]['thresh']
ctg_sim_factor = self.optimal_params[ctg]['sim_factor']
ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
if ctg in ref and dist > ctg_thresh]
sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
sorted_preds = [x[0] for x in sorted_pd]
sorted_dists = [x[1] for x in sorted_pd]
if len(sorted_dists):
if len(sorted_dists) > 1:
ratios = (sorted_dists[1:] / sorted_dists[0]) >= ctg_sim_factor
ratios = np.array([True, *ratios])
else:
ratios = np.array([True])
main_preds = np.array(sorted_preds)[np.where(ratios)].tolist()
scores = np.array(sorted_dists)[np.where(ratios)].tolist()
else:
main_preds = []
scores = []
all_ctg_preds.extend(main_preds[:self.mean_refs_count[ctg]])
all_scores.extend(scores[:self.mean_refs_count[ctg]])
sorted_values = [(ref, score) for ref, score in zip(all_ctg_preds, all_scores)]
sorted_values = sorted(sorted_values, key=lambda x: x[1], reverse=True)
sorted_preds = [x[0] for x in sorted_values]
sorted_scores = [x[1] for x in sorted_values]
return sorted_preds, sorted_scores
def get_most_relevant_teaser(self,
question: str = None,
doc_index: int = None):
teaser_indexes = {
"IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim),
"IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim)
}
teasers_index = teaser_indexes[self.index_type]
question_tokens = query_tokenization(question, self.tokenizer)
question_embedding = query_embed_extraction(question_tokens, self.model,
self.do_normalization)
teasers_texts = [teaser['summary_text'] for teaser in self.all_docs_info[doc_index]['chunks_embeddings']]
teasers_embeddings = [torch.unsqueeze(torch.Tensor(teaser['embedding']), 0) for teaser in
self.all_docs_info[doc_index]['chunks_embeddings']]
teasers_embeddings = torch.cat(teasers_embeddings, 0)
teasers_index.add(teasers_embeddings)
distances, indices = teasers_index.search(question_embedding, 10)
most_relevant_teaser = teasers_texts[indices[0][0]]
return most_relevant_teaser
def search(self, query, top=15, use_llm_for_teasers: bool = False):
query_tokens = query_tokenization(query, self.tokenizer)
query_embeds = query_embed_extraction(query_tokens, self.model,
self.do_normalization)
distances, indices = self.index_docs.search(query_embeds, len(self.all_docs_info))
pred = [self.all_docs_info[x]['doc_name'] for x in indices[0]]
preds, scores = self.search_results_filtering(pred, distances[0])
teasers = []
docs = []
for ref in preds:
doc_index = self.docs_names.index(ref)
doc_text = self.all_docs_info[doc_index]['doc_text']
docs.append(doc_text) # Add the relevant document text
most_relevant_teaser = self.get_most_relevant_teaser(question=query,
doc_index=doc_index)
if use_llm_for_teasers:
most_relevant_teaser = self.rebuild_teaser_with_llm(question=query,
teaser=most_relevant_teaser)
teasers.append(most_relevant_teaser) # Add the most relevant teaser
return preds[:top], docs[:top], teasers[:top], scores[:top]