# NOTE: removed non-Python artifact lines ("Spaces:", "Running", "Running")
# left over from log/paste extraction; they were not part of the source.
# This file contains the inference code for loading and running the closed-book and open-book QA models | |
import os | |
import csv | |
import glob | |
import gzip | |
import string | |
import sys | |
from typing import List, Tuple, Dict | |
import re | |
import numpy as np | |
import unicodedata | |
import torch | |
from torch import Tensor as T | |
from torch import nn | |
from models import init_biencoder_components | |
from Options_inf import setup_args_gpu, print_args, set_encoder_params_from_state | |
from Faiss_Indexers_inf import DenseIndexer, DenseFlatIndexer | |
from Data_utils_inf import Tensorizer | |
from Model_utils_inf import load_states_from_checkpoint, get_model_obj | |
from transformers import T5ForConditionalGeneration, AutoTokenizer | |
import time | |
from wordsegment import load, segment | |
# Populate wordsegment's unigram/bigram frequency tables once at import time;
# segment() (used by the t5-small reranker path) requires this call first.
load()
# Memoization caches shared across calls.
# SEGMENTER_CACHE appears unused in this chunk — presumably filled elsewhere; verify.
SEGMENTER_CACHE = {}
# RERANKER_CACHE maps a (clue, fill) request to its reranker log-probability.
RERANKER_CACHE = {}
def setup_closedbook(model_path, ans_tsv_path, dense_embd_path, process_id, model_type):
    """Construct the closed-book QA model.

    Wraps DPRForCrossword with retrievalmodel=False, i.e. the model answers
    clues directly from the indexed answer embeddings rather than retrieving
    wikipedia passages.
    """
    return DPRForCrossword(
        model_path,
        ans_tsv_path,
        dense_embd_path,
        retrievalmodel=False,
        process_id=process_id,
        model_type=model_type,
    )
def setup_t5_reranker(reranker_path, reranker_model_type='t5-small'):
    """Load a fine-tuned T5 reranker and its tokenizer.

    The tokenizer comes from the base model name (reranker_model_type) while
    the weights come from reranker_path. The model is put in eval mode and
    moved to GPU when one is available, otherwise CPU.

    Returns:
        (model, tokenizer) tuple.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    reranker = T5ForConditionalGeneration.from_pretrained(reranker_path)
    reranker.eval()
    reranker.to(device)
    vocab = AutoTokenizer.from_pretrained(reranker_model_type)
    return reranker, vocab
def post_process_clue(clue):
    """Preprocess a clue and strip one trailing ellipsis/period artifact.

    Runs the full preprocess_clue_fn normalization, then removes a single
    trailing '. .', ' ..', '..' or '.' (checked longest-first, one only).

    Fix: the original indexed clue[-1] unconditionally, which raises
    IndexError when preprocessing yields an empty string; str.endswith and
    slicing are safe on any length.
    """
    clue = preprocess_clue_fn(clue)
    if clue[-3:] in ('. .', ' ..'):
        clue = clue[:-3]
    elif clue[-2:] == '..':
        clue = clue[:-2]
    elif clue.endswith('.'):
        clue = clue[:-1]
    return clue
def t5_reranker_score_with_clue(model, tokenizer, model_type, clues, possibly_ungrammatical_fills):
    """Score each (clue, fill) pair with the T5 reranker.

    Args:
        model: T5ForConditionalGeneration reranker (already on its device).
        tokenizer: matching tokenizer.
        model_type: 't5-small' segments fills into words first (that variant
            was trained on space-separated answers); other types (e.g. byt5)
            score the raw fill.
        clues: list of clue strings (post-processed here).
        possibly_ungrammatical_fills: candidate answers, one per clue.

    Returns:
        List of total log-probabilities (per-token loss * answer length,
        negated), one per (clue, fill) pair.

    Fixes vs. original:
      * cache key is now the tuple (clue, fill) — string concatenation was
        ambiguous (("ab","c") and ("a","bc") collided).
      * model.eval() hoisted out of the loop.
      * redundant torch.no_grad() dropped; inference_mode alone disables grad.
    """
    global RERANKER_CACHE
    results = []
    device = model.device
    fills = possibly_ungrammatical_fills.copy()
    if model_type == 't5-small':
        fills = [" ".join(segment(answer.lower())) for answer in possibly_ungrammatical_fills]
    model.eval()
    for clue, possibly_ungrammatical_fill in zip(clues, fills):
        # byt5 was trained on lowercased fills; normalize just in case.
        if not possibly_ungrammatical_fill.islower():
            possibly_ungrammatical_fill = possibly_ungrammatical_fill.lower()
        clue = post_process_clue(clue)
        cache_key = (clue, possibly_ungrammatical_fill)
        if cache_key in RERANKER_CACHE:
            results.append(RERANKER_CACHE[cache_key])
            continue
        with torch.inference_mode():
            inputs = tokenizer(["Q: " + clue], return_tensors='pt')['input_ids'].to(device)
            labels = tokenizer([possibly_ungrammatical_fill], return_tensors='pt')['input_ids'].to(device)
            loss = model(inputs, labels=labels)
            # loss[0] is the mean per-token cross-entropy; multiply by length
            # to get the sequence log-probability.
            answer_length = labels.shape[1]
            logprob = -loss[0].item() * answer_length
        results.append(logprob)
        RERANKER_CACHE[cache_key] = logprob
    return results
def preprocess_clue_fn(clue):
    """Normalize a raw crossword clue into the form the QA models expect.

    Steps (order matters — later substitutions rely on earlier ones):
    accent stripping, quote/dash/ellipsis normalization, HTML-entity
    decoding, abbreviation expansion, crossword-convention rewrites
    ("X follower" -> "X ___", trailing "?" -> "(wordplay)", bracketed
    text -> "(nonverbal)"), and removal of the trailing answer-length hint
    like "(4)".

    Fix vs. original: the entity-decoding lines had been mangled into
    no-ops (e.g. re.sub("&", "&", ...)); restored to decode &amp;/&lt;/&gt;.

    NOTE(review): block nesting was reconstructed from a whitespace-mangled
    source; the £/money rewrites are assumed to all sit inside the `if euro`
    branch — confirm against the original repository.
    """
    clue = str(clue)
    # Strip combining accents.
    # https://stackoverflow.com/questions/517923/what-is-the-best-way-to-remove-accents-normalize-in-a-python-unicode-string
    clue = ''.join(c for c in unicodedata.normalize('NFD', clue) if unicodedata.category(c) != 'Mn')
    # Normalize quotes, ellipses, dashes and odd control characters.
    clue = re.sub("\x17|\x18|\x93|\x94|“|”|''|\"\"", "\"", clue)
    clue = re.sub("\x85|…", "...", clue)
    clue = re.sub("\x91|\x92|‘|’", "'", clue)
    clue = re.sub("‚", ",", clue)
    clue = re.sub("—|–", "-", clue)
    clue = re.sub("¢", " cents", clue)
    clue = re.sub(r"¿|¡|^;|\{|\}", "", clue)
    clue = re.sub("÷", "division", clue)
    clue = re.sub("°", " degrees", clue)
    # Money amounts: rewrite "£5,000" as "5,000 Euros" (the training data's
    # convention, despite the pound sign), then tidy magnitude words.
    euro = re.search(r"^£[0-9]+(,*[0-9]*){0,}| £[0-9]+(,*[0-9]*){0,}", clue)
    if euro:
        num = clue[:euro.end()]
        rest_clue = clue[euro.end():]
        clue = num + " Euros" + rest_clue
        clue = re.sub(", Euros", " Euros", clue)
        clue = re.sub("Euros [Mm]illion", "million Euros", clue)
        clue = re.sub("Euros [Bb]illion", "billion Euros", clue)
        clue = re.sub("Euros[Kk]", "K Euros", clue)
        clue = re.sub(" K Euros", "K Euros", clue)
        clue = re.sub("£", "", clue)
    # Drop the trailing answer-length hint, e.g. "(4)" or "(3, 5)".
    clue = re.sub(r" *\(\d{1,},*\)$| *\(\d{1,},* \d{1,}\)$", "", clue)
    # Decode HTML entities (restored — these were mangled into no-ops).
    clue = re.sub("&amp;", "&", clue)
    clue = re.sub("&lt;", "<", clue)
    clue = re.sub("&gt;", ">", clue)
    # Expand common abbreviations.
    clue = re.sub(r"e\.g\.|for ex\.", "for example", clue)
    clue = re.sub(r": [Aa]bbreviat\.|: [Aa]bbrev\.|: [Aa]bbrv\.|: [Aa]bbrv|: [Aa]bbr\.|: [Aa]bbr", " abbreviation", clue)
    clue = re.sub(r"abbr\.|abbrv\.", "abbreviation", clue)
    clue = re.sub(r"Abbr\.|Abbrv\.", "Abbreviation", clue)
    clue = re.sub(r"\(anag\.\)|\(anag\)", "(anagram)", clue)
    clue = re.sub(r"org\.", "organization", clue)
    clue = re.sub(r"Org\.", "Organization", clue)
    clue = re.sub(r"Grp\.|Gp\.", "Group", clue)
    clue = re.sub(r"grp\.|gp\.", "group", clue)
    clue = re.sub(r": Sp\.", " (Spanish)", clue)
    clue = re.sub(r"\(Sp\.\)|Sp\.", "(Spanish)", clue)
    clue = re.sub(r"Ave\.", "Avenue", clue)
    clue = re.sub(r"Sch\.", "School", clue)
    clue = re.sub(r"sch\.", "school", clue)
    clue = re.sub(r"Agcy\.", "Agency", clue)
    clue = re.sub(r"agcy\.", "agency", clue)
    clue = re.sub(r"Co\.", "Company", clue)
    clue = re.sub(r"co\.", "company", clue)
    clue = re.sub(r"No\.", "Number", clue)
    clue = re.sub(r"no\.", "number", clue)
    clue = re.sub(r": [Vv]ar\.", " variable", clue)
    clue = re.sub(r"Subj\.", "Subject", clue)
    clue = re.sub(r"subj\.", "subject", clue)
    clue = re.sub(r"Subjs\.", "Subjects", clue)
    clue = re.sub(r"subjs\.", "subjects", clue)
    # Theme clues arrive as "clue|ANSWERHINT"; pad the pipe for readability.
    theme_clue = re.search(r"^.+\|[A-Z]{1,}", clue)
    if theme_clue:
        clue = re.sub(r"\|", " | ", clue)
    # Crossword conventions -> fill-in-the-blank phrasing.
    if "Partner of" in clue:
        clue = re.sub("Partner of", "", clue)
        clue = clue + " and ___"
    link = re.search("^.+-.+ [Ll]ink$", clue)
    if link:
        # "X-Y link" -> "X ___ Y"
        no_link = re.search("^.+-.+ ", clue)
        x_y = clue[no_link.start():no_link.end() - 1]
        x_y_lst = x_y.split("-")
        clue = x_y_lst[0] + " ___ " + x_y_lst[1]
    follower = re.search("^.+ [Ff]ollower$", clue)
    if follower:
        # "X follower" -> "X ___"
        no_follower = re.search("^.+ ", clue)
        x = clue[:no_follower.end() - 1]
        clue = x + " ___"
    preceder = re.search("^.+ [Pp]receder$", clue)
    if preceder:
        # "X preceder" -> "___ X"
        no_preceder = re.search("^.+ ", clue)
        x = clue[:no_preceder.end() - 1]
        clue = "___ " + x
    # Normalize blank markers ("--", "_-") into the canonical "___".
    if re.search("--[^A-Za-z]|--$", clue):
        clue = re.sub("--", "__", clue)
    if not re.search("_-[A-Za-z]|_-$", clue):
        clue = re.sub("_-", "__", clue)
    clue = re.sub("_{2,}", "___", clue)
    # A trailing "?" marks a pun; make that explicit.
    clue = re.sub(r"\?$", " (wordplay)", clue)
    # Bracketed non-digit text is a nonverbal cue, e.g. "[sigh]".
    nonverbal = re.search(r"\[[^0-9]+,* *[^0-9]*\]", clue)
    if nonverbal:
        clue = re.sub(r"\[|\]", "", clue)
        clue = clue + " (nonverbal)"
    # Collapse tripled quotes produced by earlier normalization.
    if clue[:4] == "\"\"\" " and clue[-4:] == " \"\"\"":
        clue = "\"" + clue[4:-4] + "\""
    if clue[:4] == "''' " and clue[-4:] == " '''":
        clue = "'" + clue[4:-4] + "'"
    if clue[:3] == "\"\"\"" and clue[-3:] == "\"\"\"":
        clue = "\"" + clue[3:-3] + "\""
    if clue[:3] == "'''" and clue[-3:] == "'''":
        clue = "'" + clue[3:-3] + "'"
    return clue
def answer_clues(dpr, clues, max_answers, output_strings=False):
    """Normalize a batch of clues and answer them with the closed-book model.

    Args:
        dpr: a DPRForCrossword instance built with retrievalmodel=False.
        clues: raw clue strings (trailing whitespace is stripped here).
        max_answers: number of candidates to return per clue.
        output_strings: forwarded to answer_clues_closedbook.
    """
    cleaned = [preprocess_clue_fn(raw.rstrip()) for raw in clues]
    return dpr.answer_clues_closedbook(cleaned, max_answers, output_strings=output_strings)
class DenseRetriever(object):
    """
    Does passage retrieving over the provided index and question encoder.

    Fixes vs. original: an unsupported model_type now raises ValueError
    instead of a confusing NameError on `out`; the get_top_docs annotation
    uses np.ndarray (np.array is a function, not a type); project-type
    annotations are quoted forward references so the class can be imported
    in isolation.
    """

    def __init__(
        self,
        question_encoder: nn.Module,
        batch_size: int,
        tensorizer: "Tensorizer",
        index: "DenseIndexer",
        device=None,
        model_type='bert',
    ):
        self.question_encoder = question_encoder  # trained question-side encoder
        self.batch_size = batch_size
        self.tensorizer = tensorizer  # provides text_to_tensor()
        self.index = index  # provides search_knn()
        self.device = device
        self.model_type = model_type  # 'bert' or 'distilbert'

    def generate_question_vectors(self, questions: List[str]) -> T:
        """Encode questions into dense query vectors, in batches of batch_size.

        Returns a (len(questions), dim) CPU tensor.
        """
        n = len(questions)
        bsz = self.batch_size
        query_vectors = []
        self.question_encoder.eval()
        with torch.no_grad():
            for batch_start in range(0, n, bsz):
                batch_token_tensors = [
                    self.tensorizer.text_to_tensor(q)
                    for q in questions[batch_start: batch_start + bsz]
                ]
                q_ids_batch = torch.stack(batch_token_tensors, dim=0).to(self.device)
                # zeros_like already lives on q_ids_batch's device.
                q_seg_batch = torch.zeros_like(q_ids_batch)
                # Attend to non-zero token ids — assumes pad id == 0 (TODO confirm
                # against the tensorizer).
                q_attn_mask = (q_ids_batch != 0)
                if self.model_type == 'bert':
                    _, out, _ = self.question_encoder(q_ids_batch, q_seg_batch, q_attn_mask)
                elif self.model_type == 'distilbert':
                    # DistilBERT takes no token-type (segment) ids.
                    _, out, _ = self.question_encoder(q_ids_batch, q_attn_mask)
                else:
                    raise ValueError(f"Unsupported model_type: {self.model_type!r}")
                query_vectors.extend(out.cpu().split(1, dim=0))
        query_tensor = torch.cat(query_vectors, dim=0)
        print("CLUE Vector Shape", query_tensor.shape)
        assert query_tensor.size(0) == len(questions)
        return query_tensor

    def get_top_docs(self, query_vectors: np.ndarray, top_docs: int = 100) -> List[Tuple[List[object], List[float]]]:
        """
        Does the retrieval of the best matching passages given the query vectors batch
        :param query_vectors:
        :param top_docs:
        :return:
        """
        return self.index.search_knn(query_vectors, top_docs)
class FakeRetrieverArgs:
    """Stand-in for DPR's argparse namespace, suppressing its own CLI parsing.

    Carries exactly the attributes DPR's setup code reads, with
    inference-friendly defaults (CPU-only, no fp16, 512-token sequences).
    """

    def __init__(self):
        defaults = {
            "do_lower_case": False,
            "pretrained_model_cfg": None,
            "encoder_model_type": None,
            "model_file": None,
            "projection_dim": 0,
            "sequence_length": 512,
            "do_fill_lower_case": False,
            "desegment_valid_fill": False,
            "no_cuda": True,
            "local_rank": -1,
            "fp16": False,
            "fp16_opt_level": "O1",
        }
        for name, value in defaults.items():
            setattr(self, name, value)
class DPRForCrossword(object):
    """Closedbook model for Crossword clue answering.

    Loads the question encoder from a DPR biencoder checkpoint, indexes
    pre-computed answer embeddings with FAISS, and answers clues by nearest-
    neighbor search. With retrievalmodel=True it instead retrieves wikipedia
    passages.

    Fix vs. original: load_passages was declared without `self` but invoked
    as self.load_passages(...), which always raised TypeError; it is now a
    @staticmethod.
    """

    def __init__(
        self,
        model_file,
        ctx_file,
        encoded_ctx_file,
        batch_size=16,
        retrievalmodel=False,
        process_id=0,
        model_type='bert',
    ):
        # am I a wikipedia retrieval model or a closed-book model
        self.retrievalmodel = retrievalmodel
        args = FakeRetrieverArgs()
        args.model_file = model_file
        args.ctx_file = ctx_file
        args.encoded_ctx_file = encoded_ctx_file
        args.batch_size = batch_size
        # NOTE(review): GPU placement is disabled; everything runs on CPU.
        # self.device = torch.device("cuda:" + str(process_id % torch.cuda.device_count()))
        self.device = 'cpu'
        self.model_type = model_type
        setup_args_gpu(args)
        saved_state = load_states_from_checkpoint(args.model_file)
        set_encoder_params_from_state(saved_state.encoder_params, args)
        tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)
        question_encoder = encoder.question_model.to(self.device)
        question_encoder.eval()
        # Copy the question-encoder weights out of the biencoder checkpoint.
        model_to_load = get_model_obj(question_encoder)
        prefix = "question_model."
        question_encoder_state = {
            key[len(prefix):]: value
            for key, value in saved_state.model_dict.items()
            if key.startswith(prefix)
        }
        model_to_load.load_state_dict(question_encoder_state, strict=False)
        vector_size = model_to_load.get_out_size()
        index = DenseFlatIndexer(vector_size, 50000)
        self.retriever = DenseRetriever(
            question_encoder,
            args.batch_size,
            tensorizer,
            index,
            self.device,
            self.model_type,
        )
        # Index all pre-computed passage embeddings (single file or first of a list).
        embd_file_path = args.encoded_ctx_file
        file_path = embd_file_path if isinstance(embd_file_path, str) else embd_file_path[0]
        self.retriever.index.index_data(file_path)
        self.all_passages = self.load_passages(args.ctx_file)
        # Map each fill (uppercased, A-Z only) back to its passage id.
        self.fill2id = {}
        for key, passage in self.all_passages.items():
            fill = "".join(
                letter for letter in passage[1].upper()
                if letter in string.ascii_uppercase
            )
            self.fill2id[fill] = key
        # might as well uppercase and remove non-alphas from the fills before
        # we start, to save time later
        if not retrievalmodel:
            temp = {}
            for my_id, passage in self.all_passages.items():
                temp[my_id] = "".join(c.upper() for c in passage[1] if c.upper() in string.ascii_uppercase)
            self.len_all_passages = len(self.all_passages)
            self.all_passages = temp

    @staticmethod
    def load_passages(ctx_file: str) -> Dict[object, Tuple[str, str]]:
        """Read a (possibly gzipped) TSV of passages into {doc_id: (text, title)}.

        Accepts a path string or a 1-tuple containing one; rows whose first
        column is the literal header "id" are skipped.
        """
        docs = {}
        if isinstance(ctx_file, tuple):
            ctx_file = ctx_file[0]
        # gzip.open in "rt" mode mirrors plain open's text mode.
        opener = gzip.open if ctx_file.endswith(".gz") else open
        with opener(ctx_file, "rt") as tsvfile:
            reader = csv.reader(tsvfile, delimiter="\t")
            # file format: doc_id, doc_text, title
            for row in reader:
                if row[0] != "id":
                    docs[row[0]] = (row[1], row[2])
        return docs

    def answer_clues_closedbook(self, questions, max_answers, output_strings=False):
        """Return the top candidate answers for each (already preprocessed) clue.

        With output_strings=False returns FAISS (ids, scores) pairs; with
        True returns (all_answers, all_scores) as parallel lists of lists.
        """
        # assumes clues are preprocessed
        assert not self.retrievalmodel
        questions_tensor = self.retriever.generate_question_vectors(questions)
        # Cannot return more candidates than indexed fills.
        max_answers = min(max_answers, self.len_all_passages)
        start_time = time.time()
        # get top k results
        top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), max_answers)
        end_time = time.time()
        print("\n\nTime taken by FAISS INDEXER: ", end_time - start_time)
        if not output_strings:
            return top_ids_and_scores
        # Resolve ids to the normalized fill strings.
        all_answers = []
        all_scores = []
        for ids, scores in top_ids_and_scores:
            all_answers.append([self.all_passages.get(i) for i in ids])
            all_scores.append(scores)
        return all_answers, all_scores

    def get_wikipedia_docs(self, questions, max_docs):
        """Retrieve up to max_docs wikipedia passages per (preprocessed) clue."""
        # assumes clues are preprocessed
        assert self.retrievalmodel
        questions_tensor = self.retriever.generate_question_vectors(questions)
        # get top k results; add 2 in case duplicates are dropped below
        top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), max_docs + 2)
        all_paragraphs = []
        for ids, _scores in top_ids_and_scores:
            paragraphs = []
            for raw_id in ids:
                mydocument = self.all_passages[raw_id.replace("wiki:", "")]
                if mydocument in paragraphs:
                    print("woah, duplicate!!!")
                    continue
                paragraphs.append(mydocument)
            all_paragraphs.append(paragraphs[0:max_docs])
        return all_paragraphs