import json
import copy
import torch
import torch.nn.functional as F
import numpy as np
import faiss
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast
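
# Reference types other than court practice: Russian tax/civil/labour codes,
# federal laws, and letters/orders of the Ministry of Finance, the Federal Tax
# Service and the Government. A reference matching none of these is treated as
# belonging to the court-practice ("Суды") category.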
all_types_but_courts = ['НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина',
'Письмо ФНС', 'Приказ ФНС', 'Постановление Правительства']
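
# Thin Dataset wrapper so pre-tokenized items can be batched by a DataLoader.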
class FaissDocsDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
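
# Each tokenized item has shape (1, seq_len); the default collate stacks them
# into (batch, 1, seq_len), so the extra dimension is dropped here before the
# tensors are moved to the model's device.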
def preprocess_inputs(inputs, device):
return {k: v[:, 0, :].to(device) for k, v in inputs.items()}
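
# Collect the reference texts (added_refs) of every document belonging to the
# requested subsets; these become the passages indexed by FAISS.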
def get_subsets_for_db(subsets, data_ids, all_docs):
subsets = [data_ids[ss_name] for ss_name in subsets]
subsets = [x for ss in subsets for x in ss]
all_docs_db = {k: v for k, v in all_docs.items()
if v['id'] in subsets}
    # the dict comprehension already deduplicates references by key
    db_data = {ref: text for doc in all_docs_db.values()
               for ref, text in doc['added_refs'].items()}
    return db_data
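
# Collect the question documents of the requested subsets; their added_refs
# serve as the ground-truth relevant passages during evaluation.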
def get_subsets_for_qa(subsets, data_ids, all_docs):
subsets = [data_ids[ss_name] for ss_name in subsets]
subsets = [x for ss in subsets for x in ss]
all_docs_qa = {k: v for k, v in all_docs.items()
if v['id'] in subsets}
return all_docs_qa
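
# Keep only passages whose reference string matches one of the requested
# text_parts; "Суды" selects everything that is not one of the known
# non-court types.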
def filter_db_data_types(text_parts, db_data_in):
filtered_db_data = {}
db_data = copy.deepcopy(db_data_in)
check_if_courts = 'Суды' in text_parts
for ref, text in db_data.items():
        # a ref matching none of the known non-court types counts as court practice
        check_not_other = not any(x in ref for x in all_types_but_courts)
        court_condition = check_if_courts and check_not_other
        if court_condition or any(x in ref for x in text_parts):
filtered_db_data[ref] = text
return filtered_db_data
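
# Apply the same type filtering to the added_refs of each QA document.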
def filter_qa_data_types(text_parts, all_docs_in):
filtered_all_docs = {}
all_docs = copy.deepcopy(all_docs_in)
for doc_key, doc in all_docs.items():
if not len(doc['added_refs']):
filtered_all_docs[doc_key] = doc
continue
filtered_refs = {}
check_if_courts = 'Суды' in text_parts
for ref, text in doc['added_refs'].items():
            check_not_other = not any(x in ref for x in all_types_but_courts)
            court_condition = check_if_courts and check_not_other
            if court_condition or any(x in ref for x in text_parts):
filtered_refs[ref] = text
filtered_all_docs[doc_key] = doc
filtered_all_docs[doc_key]['added_refs'] = filtered_refs
return filtered_all_docs
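
# Tokenize every passage with the "passage: " prefix used by E5-style encoders;
# index_keys maps FAISS row ids back to reference strings.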
def db_tokenization(filtered_db_data, tokenizer, max_len=510):
index_keys = {}
index_toks = {}
for key_idx, (ref, text) in enumerate(tqdm(filtered_db_data.items(),
desc="Tokenizing DB refs")):
index_keys[key_idx] = ref
text = "passage: " + text
index_toks[key_idx] = tokenizer(text, return_tensors="pt",
padding='max_length', truncation=True,
max_length=max_len)
return index_keys, index_toks
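
# Tokenize each question (title + body) with the "query: " prefix and keep its
# ground-truth reference keys for evaluation.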
def qa_tokenization(all_docs_qa, tokenizer, max_len=510):
ss_docs = []
for doc in tqdm(all_docs_qa.values(), desc="Tokenizing QA docs"):
text = doc['title'] + '\n' + doc['question']
text = "query: " + text
text = tokenizer(text, return_tensors="pt",
padding='max_length', truncation=True,
max_length=max_len)
ss_docs.append([text, list(doc['added_refs'].keys())])
val_questions = [x[0] for x in ss_docs]
val_refs = {idx: x[1] for idx, x in enumerate(ss_docs)}
return val_questions, val_refs
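
# Tokenize a single free-form query at inference time.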
def query_tokenization(text, tokenizer, max_len=510):
text = "query: " + text
text = tokenizer(text, return_tensors="pt",
padding='max_length', truncation=True,
max_length=max_len)
return text
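
# Embed a single tokenized query: take the first-token ([CLS]) hidden state,
# optionally L2-normalize it, and return it as a numpy array.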
def query_embed_extraction(tokens, model, do_normalization=True):
model.eval()
device = model.device
with torch.no_grad():
with autocast():
inputs = {k: v[:, :].to(device) for k, v in tokens.items()}
outputs = model(**inputs)
embedding = outputs.last_hidden_state[:, 0].cpu()
if do_normalization:
embedding = F.normalize(embedding, dim=-1)
return embedding.numpy()
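
# Batch-embed the indexed passages and the evaluation questions with the same
# encoder; returns two numpy matrices ready to be added to / searched against
# a FAISS index.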
def extract_text_embeddings(index_toks, val_questions, model,
do_normalization=True, faiss_batch_size=16):
faiss_dataset = FaissDocsDataset(list(index_toks.values()))
db_data_loader = DataLoader(faiss_dataset, batch_size=faiss_batch_size)
ss_val_dataset = FaissDocsDataset(val_questions)
qu_data_loader = DataLoader(ss_val_dataset, batch_size=faiss_batch_size)
model.eval()
device = model.device
docs_embeds = []
questions_embeds = []
with torch.no_grad():
for batch in tqdm(db_data_loader, desc="db_embeds_extraction"):
with autocast():
outputs = model(**preprocess_inputs(batch, device))
docs_embeds.extend(outputs.last_hidden_state[:, 0].cpu())
for batch in tqdm(qu_data_loader, desc="qu_embeds_extraction"):
with autocast():
outputs = model(**preprocess_inputs(batch, device))
questions_embeds.extend(outputs.last_hidden_state[:, 0].cpu())
    # stack per-item embeddings into (n, hidden_size) matrices
    docs_embeds_faiss = torch.stack(docs_embeds)
    questions_embeds_faiss = torch.stack(questions_embeds)
    if do_normalization:
        docs_embeds_faiss = F.normalize(docs_embeds_faiss, dim=-1)
        questions_embeds_faiss = F.normalize(questions_embeds_faiss, dim=-1)
    return docs_embeds_faiss.numpy(), questions_embeds_faiss.numpy()
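
# End-to-end evaluation: tokenize and embed passages and questions, fill the
# FAISS index, and retrieve the top-k references for every question.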
def run_semantic_search(index, model, tokenizer, filtered_db_data, all_docs_qa,
do_normalization=True, faiss_batch_size=16, topk=100):
index_keys, index_toks = db_tokenization(filtered_db_data, tokenizer)
val_questions, val_refs = qa_tokenization(all_docs_qa, tokenizer)
docs_embeds_faiss, questions_embeds_faiss = extract_text_embeddings(index_toks,
val_questions, model, do_normalization, faiss_batch_size)
index.add(docs_embeds_faiss)
pred = {}
true = {}
all_distances = []
for idx, (q_embed, refs) in enumerate(zip(questions_embeds_faiss, val_refs.values())):
distances, indices = index.search(np.expand_dims(q_embed, 0), topk)
pred[idx] = [index_keys[x] for x in indices[0]]
true[idx] = list(refs)
all_distances.append(distances)
return pred, true, all_distances
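
# Remove every token that contains one of filter_parts from each reference
# string, e.g. to compare references at a coarser granularity.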
def filter_ref_parts(ref_dict, filter_parts):
filtered_dict = {}
for k, refs in ref_dict.items():
        # drop every token of a reference that contains any of the filter parts
        filtered_refs = [" ".join(x for x in ref.split()
                                  if not any(part in x for part in filter_parts))
                         for ref in refs]
filtered_dict[k] = filtered_refs
return filtered_dict
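
# Compute retrieval metrics per reference category and top-k cutoff;
# metrics_func(pred, true, top_k, dynamic_topk) should return a dict of
# metric values, which are reported here as percentages.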
def get_final_metrics(pred, true, categories, top_k_values,
metrics_func, dynamic_topk=False):
metrics = {}
for top_k in top_k_values:
ctg_metrics = {}
for ctg in categories:
ctg_pred, ctg_true = get_exact_ctg_data(pred, true, ctg)
metrics_at_k = metrics_func(ctg_pred, ctg_true, top_k, dynamic_topk)
for mk in metrics_at_k.keys():
metrics_at_k[mk] = round(metrics_at_k[mk] * 100, 6)
ctg_metrics[ctg] = metrics_at_k
metrics[top_k] = ctg_metrics
return metrics
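
# Restrict predictions and ground truth to a single reference category;
# "Суды" means everything that is not one of the known non-court types.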
def get_exact_ctg_data(pred_in, true_in, ctg):
if ctg == "all":
return pred_in, true_in
out_pred = {}
out_true = {}
check_if_courts = ctg == "Суды"
    # pred_in and true_in are built in the same loop, so their keys align
    for idx, pred, true in zip(true_in.keys(), pred_in.values(), true_in.values()):
        if check_if_courts:
            ctg_refs_true = [ref for ref in true
                             if not any(x in ref for x in all_types_but_courts)]
            ctg_refs_pred = [ref for ref in pred
                             if not any(x in ref for x in all_types_but_courts)]
else:
ctg_refs_true = [ref for ref in true if ctg in ref]
ctg_refs_pred = [ref for ref in pred if ctg in ref]
out_true[idx] = ctg_refs_true
out_pred[idx] = ctg_refs_pred
return out_pred, out_true
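
# Print the metrics as a tab-separated table: one row per (category, top_k)
# pair, one column per metric.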
def print_metrics(metrics, ref_categories):
    # pull the metric names from the first (top_k, category) entry
    first_ctg = next(iter(metrics.values()))
    metric_tags = list(next(iter(first_ctg.values())).keys())
    metric_tags = [x.split('@')[0] for x in metric_tags]
print('\t', *metric_tags, sep='\t')
for ctg, ctg_short in ref_categories.items():
for top_k, vals in metrics.items():
for ctg_tag, ctg_val in vals.items():
if ctg_tag == ctg:
ctg_vals_str = ["{:.3f}".format(x).zfill(6) for x in ctg_val.values()]
print(f"{ctg_short}@{top_k}", *ctg_vals_str, sep='\t\t')
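

# --- Usage sketch -------------------------------------------------------------
# A minimal illustration of how the functions above fit together, assuming an
# E5-style encoder from the Hugging Face hub and JSON files shaped like the
# structures used above. The model name, file paths, subset names and
# text_parts below are placeholders, not fixed values of this pipeline.
if __name__ == "__main__":
    from transformers import AutoTokenizer, AutoModel

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_name = "intfloat/multilingual-e5-base"  # assumption: any E5-style encoder
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)

    # all_docs: {doc_key: {"id": ..., "title": ..., "question": ..., "added_refs": {ref: text}}}
    # data_ids: {subset_name: [doc ids]} -- both file paths are placeholders
    with open("all_docs.json", encoding="utf-8") as f:
        all_docs = json.load(f)
    with open("data_ids.json", encoding="utf-8") as f:
        data_ids = json.load(f)

    subsets = ["val"]  # placeholder subset name
    text_parts = ["НКРФ", "Письмо Минфина", "Суды"]

    filtered_db_data = filter_db_data_types(
        text_parts, get_subsets_for_db(subsets, data_ids, all_docs))
    all_docs_qa = filter_qa_data_types(
        text_parts, get_subsets_for_qa(subsets, data_ids, all_docs))

    # inner-product index == cosine similarity once embeddings are normalized
    index = faiss.IndexFlatIP(model.config.hidden_size)
    pred, true, _ = run_semantic_search(index, model, tokenizer,
                                        filtered_db_data, all_docs_qa, topk=10)
    print(pred[0][:5], true[0])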