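"""Semantic search over a database of Russian legal documents.

Builds a FAISS index over transformer embeddings of the document base,
evaluates retrieval quality on the validation split, and serves query-time
search with per-category score filtering.
"""
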
import os
import json
import torch
import pickle
import numpy as np
import faiss
from transformers import AutoTokenizer, AutoModel
from legal_info_search_utils.utils import get_subsets_for_db, get_subsets_for_qa
from legal_info_search_utils.utils import filter_db_data_types, filter_qa_data_types
from legal_info_search_utils.utils import db_tokenization, qa_tokenization
from legal_info_search_utils.utils import extract_text_embeddings, filter_ref_parts
from legal_info_search_utils.utils import print_metrics, get_final_metrics
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
from legal_info_search_utils.metrics import calculate_metrics_at_k
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "legal_info_search_data/")
global_model_path = os.environ.get("GLOBAL_MODEL_PATH",
                                   "legal_info_search_model/20240120_122822_ep6/")
# annotated consultations
data_path_consult = os.environ.get("DATA_PATH_CONSULT",
                                   global_data_path + "data_jsons_20240119.pkl")
# consultation ids split into train / valid / test
data_path_consult_ids = os.environ.get("DATA_PATH_CONSULT_IDS",
                                       global_data_path + "data_ids.json")
# subsets that make up the DB (a JSON list when passed via the environment)
# $ export DB_SUBSETS='["train", "valid", "test"]'
db_subsets = os.environ.get("DB_SUBSETS", ["train", "valid", "test"])
if isinstance(db_subsets, str):
    # the environment variable arrives as a JSON string
    db_subsets = json.loads(db_subsets)
# Document type selection. List the types that should be kept in the DB.
# $ export DB_DATA_TYPES='["НКРФ", "ГКРФ", "ТКРФ"]'
db_data_types = os.environ.get("DB_DATA_TYPES", [
    'НКРФ',
    'ГКРФ',
    'ТКРФ',
    'Федеральный закон',
    'Письмо Минфина',
    'Письмо ФНС',
    'Приказ ФНС',
    'Постановление Правительства'
])
if isinstance(db_data_types, str):
    # the environment variable arrives as a JSON string
    db_data_types = json.loads(db_data_types)
device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')
# Hugging Face access token. If set (together with HF_MODEL_NAME), the model is loaded from the HF Hub.
hf_token = os.environ.get("HF_TOKEN", "")
hf_model_name = os.environ.get("HF_MODEL_NAME", "")

class SemanticSearch:
    def __init__(self, device, index_type="IndexFlatIP", do_embedding_norm=True,
                 faiss_batch_size=8, do_normalization=True):
        self.device = device
        self.do_embedding_norm = do_embedding_norm
        self.faiss_batch_size = faiss_batch_size
        self.do_normalization = do_normalization
        self.load_model()

        indexes = {
            "IndexFlatL2": faiss.IndexFlatL2(self.embedding_dim),
            "IndexFlatIP": faiss.IndexFlatIP(self.embedding_dim)
        }
        self.index = indexes[index_type]

        self.load_data()
        self.preprocess_data()
        self.test_search()
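
    # Load the pickled consultation corpus and the train/valid/test id split,
    # then build the document DB subset and the validation QA subset.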
    def load_data(self):
        with open(data_path_consult, "rb") as f:
            all_docs = pickle.load(f)

        with open(data_path_consult_ids, "r", encoding="utf-8") as f:
            data_ids = json.load(f)

        db_data = get_subsets_for_db(db_subsets, data_ids, all_docs)
        filtered_all_docs = filter_qa_data_types(db_data_types, all_docs)
        self.filtered_db_data = filter_db_data_types(db_data_types, db_data)
        self.all_docs_qa = get_subsets_for_qa(["valid"], data_ids, filtered_all_docs)
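
    # Load the tokenizer and encoder either from the Hugging Face Hub
    # (when HF_TOKEN and HF_MODEL_NAME are set) or from the local checkpoint.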
    def load_model(self):
        if hf_token and hf_model_name:
            self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_auth_token=hf_token)
            self.model = AutoModel.from_pretrained(hf_model_name, use_auth_token=hf_token).to(self.device)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
            self.model = AutoModel.from_pretrained(global_model_path).to(self.device)

        self.max_len = self.tokenizer.max_len_single_sentence
        self.embedding_dim = self.model.config.hidden_size
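
    # Tokenize and embed the DB documents and validation questions, add the
    # document embeddings to the FAISS index, and define the per-category
    # filtering parameters used by search_results_filtering.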
    def preprocess_data(self):
        index_keys, index_toks = db_tokenization(self.filtered_db_data, self.tokenizer)
        val_questions, val_refs = qa_tokenization(self.all_docs_qa, self.tokenizer)
        docs_embeds_faiss, questions_embeds_faiss = extract_text_embeddings(
            index_toks, val_questions, self.model,
            self.do_normalization, self.faiss_batch_size)
        self.index.add(docs_embeds_faiss)

        self.index_keys = index_keys
        self.index_toks = index_toks
        self.val_questions = val_questions
        self.val_refs = val_refs
        self.docs_embeds_faiss = docs_embeds_faiss
        self.questions_embeds_faiss = questions_embeds_faiss

        self.optimal_params = {
            'НКРФ': {
                'thresh': 0.613793, 'sim_factor': 0.878947, 'diff_n': 0},
            'ГКРФ': {
                'thresh': 0.758620, 'sim_factor': 0.878947, 'diff_n': 0},
            'ТКРФ': {
                'thresh': 0.734482, 'sim_factor': 0.9, 'diff_n': 0},
            'Федеральный закон': {
                'thresh': 0.734482, 'sim_factor': 0.5, 'diff_n': 0},
            'Письмо Минфина': {
                'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0},
            'Письмо ФНС': {
                'thresh': 0.879310, 'sim_factor': 0.5, 'diff_n': 0},
            'Приказ ФНС': {
                'thresh': 0.806896, 'sim_factor': 0.5, 'diff_n': 0},
            'Постановление Правительства': {
                'thresh': 0.782758, 'sim_factor': 0.5, 'diff_n': 0}
        }

        self.ref_categories = {
            'all': 'all',
            'НКРФ': 'НКРФ',
            'ГКРФ': 'ГКРФ',
            'ТКРФ': 'ТКРФ',
            'Федеральный закон': 'ФЗ',
            'Суды': 'Суды',
            'Письмо Минфина': 'Письмо МФ',
            'Письмо ФНС': 'Письмо ФНС',
            'Приказ ФНС': 'Приказ ФНС',
            'Постановление Правительства': 'Пост. Прав.'
        }
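
    # Retrieval check on the validation questions: search the full index,
    # filter the hits per category, and print the resulting metrics.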
    def test_search(self):
        topk = len(self.filtered_db_data)
        pred_raw = {}
        true = {}
        all_distances = []
        for idx, (q_embed, refs) in enumerate(zip(self.questions_embeds_faiss,
                                                  self.val_refs.values())):
            distances, indices = self.index.search(np.expand_dims(q_embed, 0), topk)
            pred_raw[idx] = [self.index_keys[x] for x in indices[0]]
            true[idx] = list(refs)
            all_distances.append(distances)

        pred = {}
        for idx, p, d in zip(true.keys(), pred_raw.values(), all_distances):
            fp, fs = self.search_results_filtering(p, d[0])
            pred[idx] = fp

        # Uncomment the parts to strip from references. If everything stays
        # commented out, the metrics are computed "as is", with the full
        # reference hierarchy taken into account.
        filter_parts = [
            # "абз.",
            # "пп.",
            # "п."
        ]
        filtered_pred = filter_ref_parts(pred, filter_parts)
        filtered_true = filter_ref_parts(true, filter_parts)

        metrics = get_final_metrics(filtered_pred, filtered_true,
                                    self.ref_categories.keys(), [0],
                                    metrics_func=calculate_metrics_at_k, dynamic_topk=True)
        print_metrics(metrics, self.ref_categories)
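
    # Per-category filtering of raw FAISS hits. For each document type the hits
    # are thresholded ('thresh'), sorted by score, and kept while their score
    # stays within 'sim_factor' of the top hit; with 'diff_n' > 0 the list is
    # additionally cut at the argmax of the n-th order difference of the scores.
    # E.g. with thresh=0.7 and sim_factor=0.9, scores [0.95, 0.90, 0.80, 0.65]
    # keep only the first two: 0.65 fails the threshold and 0.80/0.95 < 0.9.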
    def search_results_filtering(self, pred, dists):
        all_ctg_preds = []
        all_scores = []
        for ctg in db_data_types:
            ctg_thresh = self.optimal_params[ctg]['thresh']
            ctg_sim_factor = self.optimal_params[ctg]['sim_factor']
            ctg_diff_n = self.optimal_params[ctg]['diff_n']

            ctg_preds = [(ref, dist) for ref, dist in zip(pred, dists)
                         if ctg in ref and dist > ctg_thresh]
            sorted_pd = sorted(ctg_preds, key=lambda x: x[1], reverse=True)
            sorted_preds = [x[0] for x in sorted_pd]
            sorted_dists = [x[1] for x in sorted_pd]

            if len(sorted_dists):
                diffs = np.diff(sorted_dists, ctg_diff_n)
                if len(diffs):
                    n_preds = np.argmax(diffs) + ctg_diff_n + 1
                else:
                    n_preds = 0

                if len(sorted_dists) > 1:
                    ratios = (np.array(sorted_dists[1:]) / sorted_dists[0]) >= ctg_sim_factor
                    ratios = np.array([True, *ratios])
                else:
                    ratios = np.array([True])

                main_preds = np.array(sorted_preds)[np.where(ratios)].tolist()
                scores = np.array(sorted_dists)[np.where(ratios)].tolist()

                if ctg_diff_n > 0 and n_preds > 0:
                    main_preds = main_preds[:n_preds]
                    scores = scores[:n_preds]
            else:
                main_preds = []
                scores = []

            all_ctg_preds.extend(main_preds)
            all_scores.extend(scores)

        sorted_values = [(ref, score) for ref, score in zip(all_ctg_preds, all_scores)]
        sorted_values = sorted(sorted_values, key=lambda x: x[1], reverse=True)
        sorted_preds = [x[0] for x in sorted_values]
        sorted_scores = [x[1] for x in sorted_values]

        return sorted_preds, sorted_scores
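
    # Query-time entry point: embed the query, search the full index, filter
    # the hits per category, and return the top `top` references, documents
    # and scores.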
    def search(self, query, top=10):
        query_tokens = query_tokenization(query, self.tokenizer)
        query_embeds = query_embed_extraction(query_tokens, self.model,
                                              self.do_normalization)
        distances, indices = self.index.search(query_embeds, len(self.filtered_db_data))
        pred = [self.index_keys[x] for x in indices[0]]
        preds, scores = self.search_results_filtering(pred, distances[0])
        docs = [self.filtered_db_data[ref] for ref in preds]
        return preds[:top], docs[:top], scores[:top]
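

# Minimal usage sketch (assumes the configured model and data paths exist;
# the query string below is only an illustrative example).
if __name__ == "__main__":
    searcher = SemanticSearch(device)
    refs, docs, scores = searcher.search("Каков срок уплаты НДФЛ налоговым агентом?", top=5)
    for ref, score in zip(refs, scores):
        print(f"{score:.3f}\t{ref}")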