import copy
import json  # noqa: F401  (not used below; kept for existing importers)

import faiss  # noqa: F401  (not used below; kept for existing importers)
import numpy as np  # noqa: F401  (not used below; kept for existing importers)
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Literal "The full text of the court document: ". Not referenced by the functions below.
# NOTE(review): presumably used by callers to split court-document text — confirm.
court_text_splitter = "Весь текст судебного документа: "


class FaissDocsDataset(Dataset):
    """Minimal map-style Dataset over an indexable sequence of pre-built items.

    Used below to feed lists of per-text tokenizer outputs to a DataLoader.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def preprocess_inputs(inputs, device):
    """Drop axis 1 of every batched tensor in ``inputs`` and move it to ``device``.

    Each text is tokenized on its own with ``return_tensors="pt"`` (see
    ``_tokenize``), so after DataLoader collation each tensor presumably has
    shape (batch, 1, seq_len); ``[:, 0, :]`` yields (batch, seq_len).
    """
    return {k: v[:, 0, :].to(device) for k, v in inputs.items()}


def _contains_any(text, parts):
    """Return True if any element of ``parts`` is a substring of ``text``."""
    return any(part in text for part in parts)


def _select_docs(subsets, data_ids, all_docs):
    """Return the entries of ``all_docs`` whose ``'id'`` is listed in any named subset.

    ``data_ids`` maps a subset name to an iterable of doc ids. The ids are
    collected into a set so each membership test is O(1).
    """
    wanted_ids = {doc_id for ss_name in subsets for doc_id in data_ids[ss_name]}
    return {k: v for k, v in all_docs.items() if v['id'] in wanted_ids}


def get_subsets_for_db(subsets, data_ids, all_docs):
    """Collect ``{ref: text}`` from ``'added_refs'`` of every doc in the given subsets.

    When the same ref occurs in several docs, the text from the last such doc
    wins (plain dict-update semantics).
    """
    all_docs_db = _select_docs(subsets, data_ids, all_docs)
    return {
        ref: text
        for doc in all_docs_db.values()
        for ref, text in doc['added_refs'].items()
    }


def get_subsets_for_qa(subsets, data_ids, all_docs):
    """Return the docs (same keys/values as ``all_docs``) that belong to the given subsets."""
    return _select_docs(subsets, data_ids, all_docs)


def filter_db_data_types(text_parts, db_data_in):
    """Keep only DB entries whose ref contains at least one of ``text_parts``.

    Returns a new dict; ``db_data_in`` is not modified. The values are the
    reference texts (``db_tokenization`` concatenates them with a ``str``),
    so they are immutable and need no deep copy.
    """
    return {
        ref: text
        for ref, text in db_data_in.items()
        if _contains_any(ref, text_parts)
    }


def filter_qa_data_types(text_parts, all_docs_in):
    """Return a deep copy of ``all_docs_in`` with each doc's ``'added_refs'`` filtered.

    Only refs containing at least one of ``text_parts`` are kept. Docs whose
    ``'added_refs'`` is empty are copied unchanged. ``all_docs_in`` is not modified.
    """
    filtered_all_docs = copy.deepcopy(all_docs_in)
    for doc in filtered_all_docs.values():
        if not doc['added_refs']:
            continue
        doc['added_refs'] = {
            ref: text
            for ref, text in doc['added_refs'].items()
            if _contains_any(ref, text_parts)
        }
    return filtered_all_docs


def _tokenize(text, tokenizer, max_len):
    """Tokenize one string to PyTorch tensors, padded/truncated to exactly ``max_len``."""
    return tokenizer(text, return_tensors="pt", padding='max_length',
                     truncation=True, max_length=max_len)


def db_tokenization(filtered_db_data, tokenizer, max_len=510):
    """Tokenize every DB reference text with a ``"passage: "`` prefix.

    Returns:
        index_keys: ``{position: ref}`` in iteration order of ``filtered_db_data``.
        index_toks: ``{position: tokenizer output}`` for the same positions.
    """
    index_keys = {}
    index_toks = {}
    items = tqdm(filtered_db_data.items(), desc="Tokenizing DB refs")
    for key_idx, (ref, text) in enumerate(items):
        index_keys[key_idx] = ref
        # NOTE(review): "passage: "/"query: " prefixes look like the E5 convention — confirm model.
        index_toks[key_idx] = _tokenize("passage: " + text, tokenizer, max_len)
    return index_keys, index_toks


def qa_tokenization(all_docs_qa, tokenizer, max_len=510):
    """Tokenize each doc's ``title + '\\n' + question`` as a query.

    Returns:
        val_questions: list of tokenizer outputs, one per doc, in iteration order.
        val_refs: ``{position: list of that doc's 'added_refs' keys}``.
    """
    val_questions = []
    val_refs = {}
    for idx, doc in enumerate(tqdm(all_docs_qa.values(), desc="Tokenizing QA docs")):
        question = doc['title'] + '\n' + doc['question']
        val_questions.append(query_tokenization(question, tokenizer, max_len))
        val_refs[idx] = list(doc['added_refs'].keys())
    return val_questions, val_refs


def query_tokenization(text, tokenizer, max_len=510):
    """Tokenize a single query string with a ``"query: "`` prefix."""
    return _tokenize("query: " + text, tokenizer, max_len)


def query_embed_extraction(tokens, model, do_normalization=True):
    """Embed one tokenized query and return it as a float32 numpy array.

    The embedding is ``last_hidden_state[:, 0]`` (the first token of each
    sequence), optionally L2-normalized. It is cast to float32 because FAISS
    only accepts float32 arrays, and autocast may produce half precision.
    """
    model.eval()
    device = model.device
    with torch.no_grad():
        with autocast():
            inputs = {k: v.to(device) for k, v in tokens.items()}
            outputs = model(**inputs)
    embedding = outputs.last_hidden_state[:, 0].cpu().float()
    if do_normalization:
        embedding = F.normalize(embedding, dim=-1)
    return embedding.numpy()


def _embed_first_tokens(data_loader, model, device, desc):
    """Run ``model`` over every batch and concatenate the first-token hidden states.

    Must be called under ``torch.no_grad()``. Raises ``RuntimeError`` (from
    ``torch.cat``) if the loader yields no batches.
    """
    batch_embeds = []
    for batch in tqdm(data_loader, desc=desc):
        with autocast():
            outputs = model(**preprocess_inputs(batch, device))
        batch_embeds.append(outputs.last_hidden_state[:, 0].cpu())
    # float32 so the result can be passed to FAISS directly.
    return torch.cat(batch_embeds).float()


def extract_text_embeddings(index_toks, val_questions, model, do_normalization=True,
                            faiss_batch_size=16):
    """Embed all DB passages and all questions in batches.

    Args:
        index_toks: ``{position: tokenizer output}`` as returned by ``db_tokenization``.
        val_questions: list of tokenizer outputs as returned by ``qa_tokenization``.
        model: encoder whose output has ``last_hidden_state``; the first token is used.
        do_normalization: L2-normalize each embedding (inner product == cosine).
        faiss_batch_size: DataLoader batch size for both passes.

    Returns:
        ``(docs_embeds, questions_embeds)`` as float32 numpy arrays, one row per
        input, in input order.

    Raises:
        RuntimeError: from ``torch.cat`` if either input is empty.
    """
    db_data_loader = DataLoader(FaissDocsDataset(list(index_toks.values())),
                                batch_size=faiss_batch_size)
    qu_data_loader = DataLoader(FaissDocsDataset(val_questions),
                                batch_size=faiss_batch_size)
    model.eval()
    device = model.device
    with torch.no_grad():
        docs_embeds_faiss = _embed_first_tokens(db_data_loader, model, device,
                                                "db_embeds_extraction")
        questions_embeds_faiss = _embed_first_tokens(qu_data_loader, model, device,
                                                     "qu_embeds_extraction")
    if do_normalization:
        docs_embeds_faiss = F.normalize(docs_embeds_faiss, dim=-1)
        questions_embeds_faiss = F.normalize(questions_embeds_faiss, dim=-1)
    return docs_embeds_faiss.numpy(), questions_embeds_faiss.numpy()


def filter_ref_parts(ref_dict, filter_parts):
    """Remove from every ref each whitespace-separated token containing any of ``filter_parts``.

    Returns a new ``{key: [cleaned ref, ...]}`` dict; remaining tokens are
    re-joined with single spaces.
    """
    return {
        k: [
            " ".join(tok for tok in ref.split() if not _contains_any(tok, filter_parts))
            for ref in refs
        ]
        for k, refs in ref_dict.items()
    }


def get_final_metrics(pred, true, categories, top_k_values, metrics_func, metrics_func_params):
    """Compute ``metrics_func`` for every (top_k, category) pair.

    ``metrics_func(ctg_pred, ctg_true, top_k, **metrics_func_params)`` must return
    a ``{metric_name: value}`` dict; each value is multiplied by 100 and rounded
    to 6 decimals.

    Returns:
        ``{top_k: {category: {metric_name: value * 100}}}``.
    """
    metrics = {}
    for top_k in top_k_values:
        ctg_metrics = {}
        for ctg in categories:
            ctg_pred, ctg_true = get_exact_ctg_data(pred, true, ctg)
            metrics_at_k = metrics_func(ctg_pred, ctg_true, top_k, **metrics_func_params)
            # Build a new dict rather than mutating the one metrics_func returned.
            ctg_metrics[ctg] = {
                name: round(value * 100, 6) for name, value in metrics_at_k.items()
            }
        metrics[top_k] = ctg_metrics
    return metrics


def get_exact_ctg_data(pred_in, true_in, ctg):
    """Restrict predicted and true ref lists to refs containing the substring ``ctg``.

    ``ctg == "all"`` returns the inputs themselves (not copies).

    NOTE(review): predictions are paired with ground truth by position
    (``zip`` over ``.values()``) and keyed by ``true_in``'s keys, so both dicts
    are assumed to share the same order; extra entries in the longer one are dropped.
    """
    if ctg == "all":
        return pred_in, true_in
    out_pred = {}
    out_true = {}
    for idx, pred, true in zip(true_in.keys(), pred_in.values(), true_in.values()):
        out_true[idx] = [ref for ref in true if ctg in ref]
        out_pred[idx] = [ref for ref in pred if ctg in ref]
    return out_pred, out_true


def print_metrics(metrics, ref_categories):
    """Print a tab-separated table of metrics.

    Args:
        metrics: ``{top_k: {category: {metric_tag: value}}}`` as produced by
            ``get_final_metrics``. Column headers are the metric tags of the
            first (top_k, category) entry with any ``@...`` suffix removed.
        ref_categories: ``{category: short_label}``; one row is printed per
            (category, top_k) pair where that category is present, labelled
            ``short_label@top_k``.

    Prints nothing if ``metrics`` (or its first top_k entry) is empty.
    """
    if not metrics:
        return
    first_top_k = next(iter(metrics.values()))
    if not first_top_k:
        return
    first_ctg = next(iter(first_top_k.values()))
    metric_tags = [tag.split('@')[0] for tag in first_ctg]
    print('\t', *metric_tags, sep='\t')
    for ctg, ctg_short in ref_categories.items():
        for top_k, vals in metrics.items():
            if ctg not in vals:
                continue
            ctg_vals_str = ["{:.3f}".format(x).zfill(6) for x in vals[ctg].values()]
            print(f"{ctg_short}@{top_k}", *ctg_vals_str, sep='\t\t')