import os
import json
import torch
import pickle
import numpy as np
import faiss
from transformers import AutoTokenizer, AutoModel
from legal_info_search_utils.utils import get_subsets_for_db, get_subsets_for_qa
from legal_info_search_utils.utils import filter_db_data_types, filter_qa_data_types
from legal_info_search_utils.utils import db_tokenization, qa_tokenization
from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
from legal_info_search_utils.utils import print_metrics, get_final_metrics
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
from legal_info_search_utils.metrics import calculate_metrics_at_k
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
global_model_path = os.environ.get("GLOBAL_MODEL_PATH", "e5_large_rus_finetuned_20240120_122822_ep6")
# annotated consultations
data_path_consult = os.environ.get("DATA_PATH_CONSULT",
global_data_path + "data_jsons_20240119.pkl")
# consultation ids split into train / valid / test
data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
global_data_path + "data_ids.json")
# which subsets make up the DB
# $ export DB_SUBSETS='["train", "valid", "test"]'
db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"])
if isinstance(db_subsets, str):
    db_subsets = json.loads(db_subsets)
# Document type filtering. List the types that should be kept in the DB.
# $ export DB_DATA_TYPES='["НКРФ", "ГКРФ", "ТКРФ"]'
db_data_types = os.environ.get("DB_DATA_TYPES", [
'НКРФ',
'ГКРФ',
'ТКРФ',
'Федеральный закон',
'Письмо Минфина',
'Письмо ФНС',
'Приказ ФНС',
'Постановление Правительства'
])
if isinstance(db_data_types, str):
    db_data_types = json.loads(db_data_types)
device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
# Hugging Face access token. If set (together with HF_MODEL_NAME), the model is loaded from HF.
hf_token = os.environ.get("HF_TOKEN", "")
hf_model_name = os.environ.get("HF_MODEL_NAME", "")
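# Example (placeholder values) of pointing the service at an HF-hosted model,
# mirroring the export examples above; the repo id and token below are hypothetical:
# $ export HF_MODEL_NAME='<org>/<model-repo>'
# $ export HF_TOKEN='<hf-access-token>'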
class SemanticSearch:
def __init__(self, index_type="IndexFlatIP", do_embedding_norm=True,
faiss_batch_size=8, do_normalization=True):
self.device = device
self.do_embedding_norm = do_embedding_norm
self.faiss_batch_size = faiss_batch_size
self.do_normalization = do_normalization
self.load_model()
indexes = {
"IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim),
"IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim)
}
self.index = indexes[index_type]
self.load_data()
        self.preprocess_data()
self.test_search()
def load_data(self):
with open(data_path_consult, "rb") as f:
all_docs = pickle.load(f)
with open(data_path_consult_ids, "r", encoding="utf-8") as f:
data_ids = json.load(f)
db_data = get_subsets_for_db(db_subsets, data_ids, all_docs)
filtered_all_docs = filter_qa_data_types(db_data_types, all_docs)
self.filtered_db_data = filter_db_data_types(db_data_types, db_data)
self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs)
def load_model(self):
if hf_token and hf_model_name:
            self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=hf_token)
            self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=hf_token).to(self.device)
else:
self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
self.max_len = self.tokenizer.max_len_single_sentence
self.embedding_dim = self.model.config.hidden_size
    def preprocess_data(self):
index_keys, index_toks = db_tokenization(self.filtered_db_data, self.tokenizer)
val_questions, val_refs = qa_tokenization(self.all_docs_qa, self.tokenizer)
docs_embeds_faiss, questions_embeds_faiss = extract_text_embeddings(index_toks,
val_questions, self.model, self.do_normalization, self.faiss_batch_size)
self.index.add(docs_embeds_faiss)
self.index_keys = index_keys
self.index_toks = index_toks
self.val_questions = val_questions
self.val_refs = val_refs
self.docs_embeds_faiss = docs_embeds_faiss
self.questions_embeds_faiss = questions_embeds_faiss
self.optimal_params = {
'НКРФ': {
'thresh': 0.613793, 'sim_factor': 0.878947, 'diff_n': 0},
'ГКРФ': {
'thresh': 0.758620, 'sim_factor': 0.878947, 'diff_n': 0},
'ТКРФ': {
'thresh': 0.734482, 'sim_factor': 0.9, 'diff_n': 0},
'Федеральный закон': {
'thresh': 0.734482, 'sim_factor': 0.5, 'diff_n': 0},
'Письмо Минфина': {
'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0},
'Письмо ФНС': {
'thresh': 0.879310, 'sim_factor': 0.5, 'diff_n': 0},
'Приказ ФНС': {
'thresh': 0.806896, 'sim_factor': 0.5, 'diff_n': 0},
'Постановление Правительства': {
'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0}
}
self.ref_categories = {
'all': 'all',
'НКРФ': 'НКРФ',
'ГКРФ': 'ГКРФ',
'ТКРФ': 'ТКРФ',
'Федеральный закон': 'ФЗ',
'Суды': 'Суды',
'Письмо Минфина': 'Письмо МФ',
'Письмо ФНС': 'Письмо ФНС',
'Приказ ФНС': 'Приказ ФНС',
'Постановление Правительства': 'Пост. Прав.'
}
def test_search(self):
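        # Exhaustively search the index for every validation question, filter the
        # raw hits per document category, and print metrics per reference category.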
topk = len(self.filtered_db_data)
pred_raw = {}
true = {}
all_distances = []
for idx, (q_embed, refs) in enumerate(zip(self.questions_embeds_faiss,
self.val_refs.values())):
distances, indices = self.index.search(np.expand_dims(q_embed, 0), topk)
pred_raw[idx] = [self.index_keys[x] for x in indices[0]]
true[idx] = list(refs)
all_distances.append(distances)
pred = {}
for idx, p, d in zip(true.keys(), pred_raw.values(), all_distances):
fp, fs = self.search_results_filtering(p, d[0])
pred[idx] = fp
        # Uncomment the parts to be filtered out. If everything is commented out,
        # metrics are computed "as is", taking the full reference hierarchy into account.
filter_parts = [
# "абз.",
# "пп.",
# "п."
]
filtered_pred = filter_ref_parts(pred, filter_parts)
filtered_true = filter_ref_parts(true, filter_parts)
metrics = get_final_metrics(filtered_pred, filtered_true,
self.ref_categories.keys(), [0],
metrics_func=calculate_metrics_at_k, dynamic_topk=True)
print_metrics(metrics, self.ref_categories)
def search_results_filtering(self, pred, dists):
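        # For each document category: keep hits whose score exceeds the category
        # threshold, keep only hits scoring at least sim_factor of the top hit, and
        # (when diff_n > 0) truncate at a cut-off derived from the diff_n-th order
        # score differences; finally merge all categories and sort by score.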
all_ctg_preds = []
all_scores = []
for ctg in db_data_types:
ctg_thresh = self.optimal_params[ctg]['thresh']
ctg_sim_factor = self.optimal_params[ctg]['sim_factor']
ctg_diff_n = self.optimal_params[ctg]['diff_n']
ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
if ctg in ref and dist > ctg_thresh]
sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
sorted_preds = [x[0] for x in sorted_pd]
sorted_dists = [x[1] for x in sorted_pd]
if len(sorted_dists):
diffs = np.diff(sorted_dists, ctg_diff_n)
if len(diffs):
n_preds = np.argmax(diffs) + ctg_diff_n + 1
else:
n_preds = 0
if len(sorted_dists) > 1:
ratios = (sorted_dists[1:] / sorted_dists[0]) >= ctg_sim_factor
ratios = np.array([True, *ratios])
else:
ratios = np.array([True])
main_preds = np.array(sorted_preds)[np.where(ratios)].tolist()
scores = np.array(sorted_dists)[np.where(ratios)].tolist()
if ctg_diff_n > 0 and n_preds > 0:
main_preds = main_preds[:n_preds]
scores = scores[:n_preds]
else:
main_preds = []
scores = []
all_ctg_preds.extend(main_preds)
all_scores.extend(scores)
sorted_values = [(ref, score) for ref, score in zip(all_ctg_preds, all_scores)]
sorted_values = sorted(sorted_values, key=lambda x: x[1], reverse=True)
sorted_preds = [x[0] for x in sorted_values]
sorted_scores = [x[1] for x in sorted_values]
return sorted_preds, sorted_scores
def search(self, query, top=10):
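        # Embed the query, rank the whole DB against it, apply the per-category
        # filtering, and return the top-N references, their documents and scores.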
query_tokens = query_tokenization(query, self.tokenizer)
query_embeds = query_embed_extraction(query_tokens, self.model,
self.do_normalization)
distances, indices = self.index.search(query_embeds, len(self.filtered_db_data))
pred = [self.index_keys[x] for x in indices[0]]
preds, scores = self.search_results_filtering(pred, distances[0])
docs = [self.filtered_db_data[ref] for ref in preds]
return preds[:top], docs[:top], scores[:top]
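

# Minimal usage sketch (illustrative, not part of the original service code).
# Assumes the data and model paths configured above resolve on this machine, so
# that building the FAISS index in __init__ succeeds; the query string is hypothetical.
if __name__ == "__main__":
    searcher = SemanticSearch()
    refs, docs, scores = searcher.search("вычет НДФЛ при покупке квартиры", top=5)
    for ref, score in zip(refs, scores):
        print(f"{score:.3f}\t{ref}")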