Spaces:
Runtime error
Runtime error
File size: 7,726 Bytes
729d130 9b1050d 729d130 32c50d0 729d130 32c50d0 729d130 9b1050d 729d130 9b1050d 729d130 32c50d0 729d130 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
import json
import copy
import torch
import torch.nn.functional as F
import numpy as np
import faiss
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast
# Russian for "Full text of the court document: ". Presumably the marker that
# separates a document header from the full court text -- it is not referenced
# anywhere in the visible part of this file, so confirm against its callers.
court_text_splitter = "Весь текст судебного документа: "
class FaissDocsDataset(Dataset):
    """Map-style dataset that serves the items of an indexable collection as-is."""

    def __init__(self, data):
        # The collection is stored by reference; no copy is made.
        self.data = data

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of stored items."""
        return len(self.data)
def preprocess_inputs(inputs, device):
    """Select index 0 along the second axis of every tensor and move it to ``device``.

    Args:
        inputs: Mapping of input name to a tensor with at least three axes.
            Presumably shaped (batch, 1, seq_len), i.e. per-sample tokenizer
            output stacked by a DataLoader -- TODO confirm with the caller.
        device: Target device passed to ``Tensor.to``.

    Returns:
        A new dict with the same keys and the sliced, moved tensors.
    """
    moved = {}
    for name, tensor in inputs.items():
        moved[name] = tensor[:, 0, :].to(device)
    return moved
def get_subsets_for_db(subsets, data_ids, all_docs):
    """Collect the reference database for the documents of the given subsets.

    Args:
        subsets: Iterable of subset names; each must be a key of ``data_ids``.
        data_ids: Mapping of subset name to an iterable of document ids.
            Ids must be hashable because they are gathered into a set.
        all_docs: Mapping of document key to a document dict that has an
            ``'id'`` entry and an ``'added_refs'`` mapping of ref -> text.

    Returns:
        A dict of ref -> text merged over every selected document. When the
        same ref occurs in several documents, the text of the document that
        comes last in ``all_docs`` wins.

    Raises:
        KeyError: If a subset name is missing from ``data_ids`` or a selected
            document lacks ``'id'`` / ``'added_refs'``.
    """
    # Build the id set once so each membership test is O(1) instead of a scan
    # over the flattened id list for every document.
    selected_ids = {doc_id for ss_name in subsets for doc_id in data_ids[ss_name]}

    db_data = {}
    for doc in all_docs.values():
        if doc['id'] in selected_ids:
            # Later documents overwrite earlier ones on duplicate refs.
            db_data.update(doc['added_refs'])
    return db_data
def get_subsets_for_qa(subsets, data_ids, all_docs):
    """Select the documents whose id belongs to any of the given subsets.

    Args:
        subsets: Iterable of subset names; each must be a key of ``data_ids``.
        data_ids: Mapping of subset name to an iterable of document ids.
            Ids must be hashable because they are gathered into a set.
        all_docs: Mapping of document key to a document dict with an ``'id'``.

    Returns:
        A new dict holding the selected entries in the order of ``all_docs``.
        The document objects themselves are shared with the input, not copied.

    Raises:
        KeyError: If a subset name is missing from ``data_ids`` or a document
            has no ``'id'`` entry.
    """
    # Build the id set once so each membership test is O(1) instead of a scan
    # over the flattened id list for every document.
    selected_ids = {doc_id for ss_name in subsets for doc_id in data_ids[ss_name]}
    return {doc_key: doc for doc_key, doc in all_docs.items()
            if doc['id'] in selected_ids}
def filter_db_data_types(text_parts, db_data_in):
    """Keep only the database entries whose ref contains one of ``text_parts``.

    Args:
        text_parts: Iterable of substrings; a ref is kept when at least one
            of them occurs in it. An empty iterable keeps nothing.
        db_data_in: Mapping of ref -> text. It is not modified.

    Returns:
        A new dict with the matching entries, deep-copied so the result is
        independent of ``db_data_in``.
    """
    kept = {ref: text for ref, text in db_data_in.items()
            if any(part in ref for part in text_parts)}
    # Copy only what survives the filter; copying the whole database first
    # would duplicate entries that are thrown away immediately afterwards.
    return copy.deepcopy(kept)
def filter_qa_data_types(text_parts, all_docs_in):
    """Return a deep copy of the documents with ``added_refs`` filtered by ref.

    Args:
        text_parts: Iterable of substrings; a ref is kept when at least one
            of them occurs in it.
        all_docs_in: Mapping of document key to a document dict that has an
            ``'added_refs'`` entry. It is not modified.

    Returns:
        A dict with every document of the input (deep-copied). Documents
        whose ``'added_refs'`` is empty are passed through untouched; for the
        rest ``'added_refs'`` is replaced by a dict of the matching refs.
    """
    docs_copy = copy.deepcopy(all_docs_in)
    result = {}
    for doc_key, doc in docs_copy.items():
        result[doc_key] = doc
        refs = doc['added_refs']
        if len(refs) == 0:
            # Nothing to filter; the empty container is left exactly as it is.
            continue
        doc['added_refs'] = {
            ref: text
            for ref, text in refs.items()
            if any(part in ref for part in text_parts)
        }
    return result
def db_tokenization(filtered_db_data, tokenizer, max_len=510):
    """Tokenize every database text and index the results by position.

    Args:
        filtered_db_data: Mapping of ref -> text.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt",
            padding='max_length', truncation=True, max_length=max_len)``.
        max_len: Length passed to the tokenizer as ``max_length``.

    Returns:
        A pair ``(index_keys, index_toks)``: position -> ref and
        position -> tokenizer output, where position is the 0-based
        iteration order of ``filtered_db_data``.
    """
    index_keys, index_toks = {}, {}
    progress = tqdm(filtered_db_data.items(), desc="Tokenizing DB refs")
    for position, (ref, text) in enumerate(progress):
        index_keys[position] = ref
        # Every database text is prefixed with "passage: " before encoding.
        index_toks[position] = tokenizer(
            "passage: " + text,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_len,
        )
    return index_keys, index_toks
def qa_tokenization(all_docs_qa, tokenizer, max_len=510):
    """Tokenize the question of every document and list its reference keys.

    Args:
        all_docs_qa: Mapping of document key to a document dict with
            ``'title'``, ``'question'`` and an ``'added_refs'`` mapping.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt",
            padding='max_length', truncation=True, max_length=max_len)``.
        max_len: Length passed to the tokenizer as ``max_length``.

    Returns:
        A pair ``(val_questions, val_refs)``: the list of tokenizer outputs in
        document order, and a dict of list position -> list of that
        document's ``'added_refs'`` keys.
    """
    val_questions = []
    val_refs = {}
    progress = tqdm(all_docs_qa.values(), desc="Tokenizing QA docs")
    for position, doc in enumerate(progress):
        # The query is the title and the question joined by a newline,
        # prefixed with "query: ".
        query = "query: " + (doc['title'] + '\n' + doc['question'])
        encoded = tokenizer(
            query,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_len,
        )
        val_questions.append(encoded)
        val_refs[position] = list(doc['added_refs'].keys())
    return val_questions, val_refs
def query_tokenization(text, tokenizer, max_len=510):
    """Tokenize a single query string.

    Args:
        text: Raw query text; it is prefixed with "query: " before encoding.
        tokenizer: Callable invoked with the prefixed text and the keyword
            arguments ``return_tensors="pt"``, ``padding='max_length'``,
            ``truncation=True`` and ``max_length=max_len``.
        max_len: Length passed to the tokenizer as ``max_length``.

    Returns:
        Whatever ``tokenizer`` returns for the prefixed text.
    """
    encode_kwargs = dict(
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=max_len,
    )
    return tokenizer("query: " + text, **encode_kwargs)
def query_embed_extraction(tokens, model, do_normalization=True):
    """Embed tokenized input with ``model`` and return a NumPy array.

    The model is put in eval mode and run under ``torch.no_grad()`` and
    ``autocast()``. The embedding is position 0 along the second axis of
    ``last_hidden_state`` (presumably the CLS token -- confirm for the model
    in use), moved to the CPU.

    Args:
        tokens: Mapping of input name to a tensor with at least two axes.
        model: Callable exposing ``eval()`` and ``device`` whose output has a
            ``last_hidden_state`` attribute.
        do_normalization: When true, normalize each embedding along the last
            axis with ``F.normalize``.

    Returns:
        The embeddings as a NumPy array.
    """
    model.eval()
    target_device = model.device
    with torch.no_grad(), autocast():
        batch = {}
        for name, tensor in tokens.items():
            batch[name] = tensor[:, :].to(target_device)
        first_token_state = model(**batch).last_hidden_state[:, 0].cpu()
    if do_normalization:
        first_token_state = F.normalize(first_token_state, dim=-1)
    return first_token_state.numpy()
def extract_text_embeddings(index_toks, val_questions, model,
                            do_normalization=True, faiss_batch_size=16):
    """Embed all database passages and all questions with ``model``.

    Both collections are batched through a ``DataLoader`` and run under
    ``torch.no_grad()`` and ``autocast()``. The embedding of each item is
    position 0 along the second axis of ``last_hidden_state`` (presumably the
    CLS token -- confirm for the model in use), moved to the CPU.

    Args:
        index_toks: Mapping of position -> tokenized database passage.
        val_questions: Sequence of tokenized questions.
        model: Callable exposing ``eval()`` and ``device`` whose output has a
            ``last_hidden_state`` attribute.
        do_normalization: When true, normalize both embedding matrices along
            the last axis with ``F.normalize``.
        faiss_batch_size: Batch size of both data loaders.

    Returns:
        A pair of NumPy arrays ``(docs_embeds, questions_embeds)``.
    """
    db_data_loader = DataLoader(FaissDocsDataset(list(index_toks.values())),
                                batch_size=faiss_batch_size)
    qu_data_loader = DataLoader(FaissDocsDataset(val_questions),
                                batch_size=faiss_batch_size)

    model.eval()
    device = model.device

    def _embed_batches(loader, desc):
        # Returns one CPU tensor of first-position hidden states per batch.
        chunks = []
        for batch in tqdm(loader, desc=desc):
            with autocast():
                outputs = model(**preprocess_inputs(batch, device))
            chunks.append(outputs.last_hidden_state[:, 0].cpu())
        return chunks

    with torch.no_grad():
        doc_chunks = _embed_batches(db_data_loader, "db_embeds_extraction")
        question_chunks = _embed_batches(qu_data_loader, "qu_embeds_extraction")

    # Concatenating the per-batch tensors along axis 0 yields one row per item.
    docs_embeds_faiss = torch.cat(doc_chunks)
    questions_embeds_faiss = torch.cat(question_chunks)

    if do_normalization:
        docs_embeds_faiss = F.normalize(docs_embeds_faiss, dim=-1)
        questions_embeds_faiss = F.normalize(questions_embeds_faiss, dim=-1)
    return docs_embeds_faiss.numpy(), questions_embeds_faiss.numpy()
def filter_ref_parts(ref_dict, filter_parts):
    """Remove from every ref the whitespace-separated tokens that match a filter.

    Args:
        ref_dict: Mapping of key -> iterable of ref strings.
        filter_parts: Iterable of substrings; a token is dropped when it
            contains at least one of them.

    Returns:
        A new dict with the same keys, each mapped to a list of refs rebuilt
        from the remaining tokens joined by single spaces.
    """
    def _is_kept(token):
        return not any(part in token for part in filter_parts)

    cleaned = {}
    for key, refs in ref_dict.items():
        cleaned[key] = [" ".join(filter(_is_kept, ref.split())) for ref in refs]
    return cleaned
def get_final_metrics(pred, true, categories, top_k_values,
                      metrics_func, metrics_func_params):
    """Compute metrics for every (top_k, category) pair, scaled to percent.

    Args:
        pred: Predictions, passed to ``get_exact_ctg_data``.
        true: Ground truth, passed to ``get_exact_ctg_data``.
        categories: Iterable of category tags to evaluate.
        top_k_values: Iterable of cut-off values.
        metrics_func: Callable invoked as ``metrics_func(ctg_pred, ctg_true,
            top_k, **metrics_func_params)``; must return a dict of
            metric name -> number.
        metrics_func_params: Extra keyword arguments for ``metrics_func``.

    Returns:
        A nested dict ``{top_k: {category: {metric: value}}}`` where every
        value is multiplied by 100 and rounded to 6 decimals. The innermost
        dicts are the ones returned by ``metrics_func``, updated in place.
    """
    def _to_percent(scores):
        # Rescale in place so the dict object from metrics_func is reused.
        for name in scores.keys():
            scores[name] = round(scores[name] * 100, 6)
        return scores

    metrics = {}
    for top_k in top_k_values:
        per_category = {}
        for ctg in categories:
            ctg_pred, ctg_true = get_exact_ctg_data(pred, true, ctg)
            per_category[ctg] = _to_percent(
                metrics_func(ctg_pred, ctg_true, top_k, **metrics_func_params)
            )
        metrics[top_k] = per_category
    return metrics
def get_exact_ctg_data(pred_in, true_in, ctg):
    """Restrict predictions and ground truth to refs of one category.

    Args:
        pred_in: Mapping whose values are iterables of predicted refs.
        true_in: Mapping whose values are iterables of true refs.
        ctg: Category substring, or ``"all"`` to disable filtering.

    Returns:
        ``(pred_in, true_in)`` unchanged when ``ctg == "all"``. Otherwise a
        pair of new dicts, both keyed by the keys of ``true_in``, holding
        only the refs that contain ``ctg``. Entries of the two inputs are
        paired by iteration order, not by key.
    """
    if ctg == "all":
        return pred_in, true_in

    out_pred, out_true = {}, {}
    for (idx, true_refs), pred_refs in zip(true_in.items(), pred_in.values()):
        out_pred[idx] = [ref for ref in pred_refs if ctg in ref]
        out_true[idx] = [ref for ref in true_refs if ctg in ref]
    return out_pred, out_true
def print_metrics(metrics, ref_categories):
    """Print a tab-separated table of metrics, one row per (category, top_k).

    Args:
        metrics: Nested dict ``{top_k: {category: {metric_tag: value}}}``.
            Must be non-empty; the header is taken from the first category
            of the first ``top_k`` entry.
        ref_categories: Mapping of category tag -> short label for the rows.
            Categories absent from ``metrics`` produce no row.

    Raises:
        IndexError: If ``metrics`` or its first entry is empty.
    """
    first_top_k = list(metrics.keys())[0]
    first_scores = list(metrics[first_top_k].values())[0]
    # Header shows metric names without the "@k" suffix.
    header = [tag.split('@')[0] for tag in first_scores]
    print('\t', *header, sep='\t')

    for ctg, ctg_short in ref_categories.items():
        for top_k, per_category in metrics.items():
            if ctg not in per_category:
                continue
            cells = ["{:.3f}".format(value).zfill(6)
                     for value in per_category[ctg].values()]
            print(f"{ctg_short}@{top_k}", *cells, sep='\t\t')
|