# This file contains the inference code for loading and running the closed-book and open-book QA models
import os
import csv
import glob
import gzip
import string
import sys
from typing import List, Tuple, Dict
import re
import numpy as np
import unicodedata
import torch
from torch import Tensor as T
from torch import nn
from models import init_biencoder_components
from Options_inf import setup_args_gpu, print_args, set_encoder_params_from_state
from Faiss_Indexers_inf import DenseIndexer, DenseFlatIndexer
from Data_utils_inf import Tensorizer
from Model_utils_inf import load_states_from_checkpoint, get_model_obj
from transformers import T5ForConditionalGeneration, AutoTokenizer
import time
from wordsegment import load, segment
load()
SEGMENTER_CACHE = {}  # module-level cache for word-segmentation results
RERANKER_CACHE = {}   # module-level cache of T5 reranker scores, keyed by clue + candidate fill
def setup_closedbook(model_path, ans_tsv_path, dense_embd_path, process_id, model_type):
dpr = DPRForCrossword(
model_path,
ans_tsv_path,
dense_embd_path,
retrievalmodel = False,
process_id=process_id,
model_type = model_type
)
return dpr
def setup_t5_reranker(reranker_path, reranker_model_type = 't5-small'):
tokenizer = AutoTokenizer.from_pretrained(reranker_model_type)
model = T5ForConditionalGeneration.from_pretrained(reranker_path)
model.eval().to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
return model, tokenizer
def post_process_clue(clue):
    # normalize the clue, then strip a trailing ellipsis/period artifact if present
    clue = preprocess_clue_fn(clue)
    for suffix in ('. .', ' ..', '..', '.'):
        if clue.endswith(suffix):
            clue = clue[:-len(suffix)]
            break
    return clue
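# Example inputs/outputs for post_process_clue (trailing-period stripping):
#   post_process_clue("Et cetera.")   -> "Et cetera"
#   post_process_clue("And so on ..") -> "And so on"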
def t5_reranker_score_with_clue(model, tokenizer, model_type, clues, possibly_ungrammatical_fills):
global RERANKER_CACHE
results = []
device = model.device
    fills = possibly_ungrammatical_fills.copy()
    if model_type == 't5-small':
        # for t5-small, split each concatenated crossword fill into words
        # (e.g. "NEWYORK" -> "new york") before scoring
        fills = [" ".join(segment(answer.lower())) for answer in possibly_ungrammatical_fills]
    for clue, possibly_ungrammatical_fill in zip(clues, fills):
        # ByT5 fills may come back in mixed case; lowercase them for the reranker
        if not possibly_ungrammatical_fill.islower():
            possibly_ungrammatical_fill = possibly_ungrammatical_fill.lower()
        clue = post_process_clue(clue)
        cache_key = clue + possibly_ungrammatical_fill
        if cache_key in RERANKER_CACHE:
            results.append(RERANKER_CACHE[cache_key])
            continue
        with torch.inference_mode():
            # move the input tensors to the model's device (GPU if available)
            inputs = tokenizer(["Q: " + clue], return_tensors='pt')['input_ids'].to(device)
            labels = tokenizer([possibly_ungrammatical_fill], return_tensors='pt')['input_ids'].to(device)
            outputs = model(inputs, labels=labels)
            # outputs.loss is the mean per-token cross-entropy; scale by the answer
            # length to get a sequence-level log-probability for reranking
            answer_length = labels.shape[1]
            logprob = -outputs.loss.item() * answer_length
            results.append(logprob)
            RERANKER_CACHE[cache_key] = logprob
return results
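# The scores returned above are length-scaled log-probabilities of the fill given
# the clue (higher = better). Hypothetical call, reusing the model/tokenizer pair
# returned by setup_t5_reranker:
#   scores = t5_reranker_score_with_clue(model, tokenizer, 't5-small',
#                                        ["Capital of France"], ["PARIS"])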
def preprocess_clue_fn(clue):
clue = str(clue)
# https://stackoverflow.com/questions/517923/what-is-the-best-way-to-remove-accents-normalize-in-a-python-unicode-string
clue = ''.join(c for c in unicodedata.normalize('NFD', clue) if unicodedata.category(c) != 'Mn')
clue = re.sub("\x17|\x18|\x93|\x94|“|”|''|\"\"", "\"", clue)
clue = re.sub("\x85|…", "...", clue)
clue = re.sub("\x91|\x92|‘|’", "'", clue)
clue = re.sub("‚", ",", clue)
clue = re.sub("—|–", "-", clue)
clue = re.sub("¢", " cents", clue)
clue = re.sub("¿|¡|^;|\{|\}", "", clue)
clue = re.sub("÷", "division", clue)
clue = re.sub("°", " degrees", clue)
euro = re.search("^£[0-9]+(,*[0-9]*){0,}| £[0-9]+(,*[0-9]*){0,}", clue)
if euro:
num = clue[:euro.end()]
rest_clue = clue[euro.end():]
clue = num + " Euros" + rest_clue
clue = re.sub(", Euros", " Euros", clue)
clue = re.sub("Euros [Mm]illion", "million Euros", clue)
clue = re.sub("Euros [Bb]illion", "billion Euros", clue)
clue = re.sub("Euros[Kk]", "K Euros", clue)
clue = re.sub(" K Euros", "K Euros", clue)
clue = re.sub("£", "", clue)
clue = re.sub(" *\(\d{1,},*\)$| *\(\d{1,},* \d{1,}\)$", "", clue)
clue = re.sub("&", "&", clue)
clue = re.sub("&lt;", "<", clue)
clue = re.sub("&gt;", ">", clue)
clue = re.sub("e\.g\.|for ex\.", "for example", clue)
clue = re.sub(": [Aa]bbreviat\.|: [Aa]bbrev\.|: [Aa]bbrv\.|: [Aa]bbrv|: [Aa]bbr\.|: [Aa]bbr", " abbreviation", clue)
clue = re.sub("abbr\.|abbrv\.", "abbreviation", clue)
clue = re.sub("Abbr\.|Abbrv\.", "Abbreviation", clue)
clue = re.sub("\(anag\.\)|\(anag\)", "(anagram)", clue)
clue = re.sub("org\.", "organization", clue)
clue = re.sub("Org\.", "Organization", clue)
clue = re.sub("Grp\.|Gp\.", "Group", clue)
clue = re.sub("grp\.|gp\.", "group", clue)
clue = re.sub(": Sp\.", " (Spanish)", clue)
clue = re.sub("\(Sp\.\)|Sp\.", "(Spanish)", clue)
clue = re.sub("Ave\.", "Avenue", clue)
clue = re.sub("Sch\.", "School", clue)
clue = re.sub("sch\.", "school", clue)
clue = re.sub("Agcy\.", "Agency", clue)
clue = re.sub("agcy\.", "agency", clue)
clue = re.sub("Co\.", "Company", clue)
clue = re.sub("co\.", "company", clue)
clue = re.sub("No\.", "Number", clue)
clue = re.sub("no\.", "number", clue)
clue = re.sub(": [Vv]ar\.", " variable", clue)
clue = re.sub("Subj\.", "Subject", clue)
clue = re.sub("subj\.", "subject", clue)
clue = re.sub("Subjs\.", "Subjects", clue)
clue = re.sub("subjs\.", "subjects", clue)
theme_clue = re.search("^.+\|[A-Z]{1,}", clue)
if theme_clue:
clue = re.sub("\|", " | ", clue)
if "Partner of" in clue:
clue = re.sub("Partner of", "", clue)
clue = clue + " and ___"
link = re.search("^.+-.+ [Ll]ink$", clue)
if link:
no_link = re.search("^.+-.+ ", clue)
x_y = clue[no_link.start():no_link.end() - 1]
x_y_lst = x_y.split("-")
clue = x_y_lst[0] + " ___ " + x_y_lst[1]
follower = re.search("^.+ [Ff]ollower$", clue)
if follower:
no_follower = re.search("^.+ ", clue)
x = clue[:no_follower.end() - 1]
clue = x + " ___"
preceder = re.search("^.+ [Pp]receder$", clue)
if preceder:
no_preceder = re.search("^.+ ", clue)
x = clue[:no_preceder.end() - 1]
clue = "___ " + x
if re.search("--[^A-Za-z]|--$", clue):
clue = re.sub("--", "__", clue)
if not re.search("_-[A-Za-z]|_-$", clue):
clue = re.sub("_-", "__", clue)
clue = re.sub("_{2,}", "___", clue)
clue = re.sub("\?$", " (wordplay)", clue)
nonverbal = re.search("\[[^0-9]+,* *[^0-9]*\]", clue)
if nonverbal:
clue = re.sub("\[|\]", "", clue)
clue = clue + " (nonverbal)"
if clue[:4] == "\"\"\" " and clue[-4:] == " \"\"\"":
clue = "\"" + clue[4:-4] + "\""
if clue[:4] == "''' " and clue[-4:] == " '''":
clue = "'" + clue[4:-4] + "'"
if clue[:3] == "\"\"\"" and clue[-3:] == "\"\"\"":
clue = "\"" + clue[3:-3] + "\""
if clue[:3] == "'''" and clue[-3:] == "'''":
clue = "'" + clue[3:-3] + "'"
return clue
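# A few illustrative normalizations produced by preprocess_clue_fn:
#   "Grp. of nations"    -> "Group of nations"
#   "Alpha follower"     -> "Alpha ___"
#   "Fruit with a core?" -> "Fruit with a core (wordplay)"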
def answer_clues(dpr, clues, max_answers, output_strings=False):
clues = [preprocess_clue_fn(c.rstrip()) for c in clues]
outputs = dpr.answer_clues_closedbook(clues, max_answers, output_strings=output_strings)
return outputs
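# Hypothetical usage, assuming `dpr` was built with setup_closedbook:
#   answers, scores = answer_clues(dpr, ["Capital of France"], max_answers=100,
#                                  output_strings=True)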
class DenseRetriever(object):
"""
Does passage retrieving over the provided index and question encoder
"""
def __init__(
self,
question_encoder: nn.Module,
batch_size: int,
tensorizer: Tensorizer,
index: DenseIndexer,
device=None,
model_type = 'bert'
):
self.question_encoder = question_encoder
self.batch_size = batch_size
self.tensorizer = tensorizer
self.index = index
self.device = device
self.model_type = model_type
def generate_question_vectors(self, questions: List[str]) -> T:
n = len(questions)
bsz = self.batch_size
query_vectors = []
self.question_encoder.eval()
with torch.no_grad():
for j, batch_start in enumerate(range(0, n, bsz)):
batch_token_tensors = [
self.tensorizer.text_to_tensor(q)
for q in questions[batch_start : batch_start + bsz]
]
q_ids_batch = torch.stack(batch_token_tensors, dim=0).to(self.device)
q_seg_batch = torch.zeros_like(q_ids_batch).to(self.device)
# q_attn_mask = self.tensorizer.get_attn_mask(q_ids_batch)
q_attn_mask = (q_ids_batch != 0)
if self.model_type == 'bert':
_, out, _ = self.question_encoder(q_ids_batch, q_seg_batch, q_attn_mask)
elif self.model_type == 'distilbert':
_, out, _ = self.question_encoder(q_ids_batch, q_attn_mask)
query_vectors.extend(out.cpu().split(1, dim=0))
query_tensor = torch.cat(query_vectors, dim=0)
print("CLUE Vector Shape", query_tensor.shape)
assert query_tensor.size(0) == len(questions)
return query_tensor
    def get_top_docs(self, query_vectors: np.ndarray, top_docs: int = 100) -> List[Tuple[List[object], List[float]]]:
        """
        Retrieves the best-matching passages for a batch of query vectors.
        :param query_vectors: batch of encoded clue/question vectors
        :param top_docs: number of passages to return per query
        :return: a list of (passage ids, scores) tuples, one per query
        """
results = self.index.search_knn(query_vectors, top_docs)
return results
class FakeRetrieverArgs:
    """Used to suppress DPR's built-in argparse so we can supply our own arguments"""
def __init__(self):
self.do_lower_case = False
self.pretrained_model_cfg = None
self.encoder_model_type = None
self.model_file = None
self.projection_dim = 0
self.sequence_length = 512
self.do_fill_lower_case = False
self.desegment_valid_fill = False
self.no_cuda = True
self.local_rank = -1
self.fp16 = False
self.fp16_opt_level = "O1"
class DPRForCrossword(object):
    """DPR-based model: the closed-book clue answerer by default, or the Wikipedia
    passage retriever when retrievalmodel=True"""
def __init__(
self,
model_file,
ctx_file,
encoded_ctx_file,
batch_size = 16,
retrievalmodel=False,
process_id = 0,
model_type = 'bert'
):
self.retrievalmodel = retrievalmodel # am I a wikipedia retrieval model or a closed-book model
args = FakeRetrieverArgs()
args.model_file = model_file
args.ctx_file = ctx_file
args.encoded_ctx_file = encoded_ctx_file
args.batch_size = batch_size
# self.device = torch.device("cuda:"+str(process_id%torch.cuda.device_count()))
self.device = 'cpu'
self.model_type = model_type
setup_args_gpu(args)
saved_state = load_states_from_checkpoint(args.model_file)
set_encoder_params_from_state(saved_state.encoder_params, args)
tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only = True)
question_encoder = encoder.question_model
question_encoder = question_encoder.to(self.device)
question_encoder.eval()
# load weights from the model file
model_to_load = get_model_obj(question_encoder)
prefix_len = len("question_model.")
question_encoder_state = {
key[prefix_len:]: value
for (key, value) in saved_state.model_dict.items()
if key.startswith("question_model.")
}
model_to_load.load_state_dict(question_encoder_state, strict = False)
vector_size = model_to_load.get_out_size()
index = DenseFlatIndexer(vector_size, 50000)
self.retriever = DenseRetriever(
question_encoder,
args.batch_size,
tensorizer,
index,
self.device,
self.model_type
)
# index all passages
embd_file_path = args.encoded_ctx_file
if isinstance(embd_file_path, str):
file_path = embd_file_path
else:
file_path = embd_file_path[0]
self.retriever.index.index_data(file_path)
self.all_passages = self.load_passages(args.ctx_file)
self.fill2id = {}
for key in self.all_passages.keys():
self.fill2id[
"".join(
[
letter
for letter in self.all_passages[key][1].upper()
if letter in string.ascii_uppercase
]
)
] = key
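        # fill2id maps a normalized fill (uppercase, A-Z only) back to its passage id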
# might as well uppercase and remove non-alphas from the fills before we start to save time later
if not retrievalmodel:
temp = {}
for my_id in self.all_passages.keys():
temp[my_id] = "".join([c.upper() for c in self.all_passages[my_id][1] if c.upper() in string.ascii_uppercase])
self.len_all_passages = len(list(self.all_passages.values()))
self.all_passages = temp
@staticmethod
    def load_passages(ctx_file: str) -> Dict[object, Tuple[str, str]]:
        docs = {}
        if isinstance(ctx_file, tuple):
            ctx_file = ctx_file[0]
        # transparently handle both plain and gzip-compressed TSV files
        open_fn = gzip.open if ctx_file.endswith(".gz") else open
        with open_fn(ctx_file, "rt") as tsvfile:
            reader = csv.reader(tsvfile, delimiter="\t")
            # file format: doc_id, doc_text, title
            for row in reader:
                if row[0] != "id":
                    docs[row[0]] = (row[1], row[2])
        return docs
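    # Note: row[2] (the "title" column) is what __init__ later normalizes into the
    # uppercase candidate fill strings used by the closed-book model.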
def answer_clues_closedbook(self, questions, max_answers, output_strings=False):
# assumes clues are preprocessed
assert self.retrievalmodel == False
questions_tensor = self.retriever.generate_question_vectors(questions)
if max_answers > self.len_all_passages:
max_answers = self.len_all_passages
start_time = time.time()
# get top k results
top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), max_answers)
end_time = time.time()
print("\n\nTime taken by FAISS INDEXER: ", end_time - start_time)
if not output_strings:
return top_ids_and_scores
else:
# get the string forms
all_answers = []
all_scores = []
for ans in top_ids_and_scores:
all_answers.append(list(map(self.all_passages.get, ans[0])))
all_scores.append(ans[1])
return all_answers, all_scores
def get_wikipedia_docs(self, questions, max_docs):
# assumes clues are preprocessed
assert self.retrievalmodel
questions_tensor = self.retriever.generate_question_vectors(questions)
        # get top k results; request 2 extra in case of duplicates (filtered below)
top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), max_docs + 2)
all_paragraphs = []
for ans in top_ids_and_scores:
paragraphs = []
for i in range(len(ans[0])):
id_ = ans[0][i]
id_ = id_.replace("wiki:", "")
mydocument = self.all_passages[id_]
if mydocument in paragraphs:
print("woah, duplicate!!!")
continue
paragraphs.append(mydocument)
all_paragraphs.append(paragraphs[0:max_docs])
return all_paragraphs
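# ---------------------------------------------------------------------------
# Minimal end-to-end sketch. The checkpoint/data paths below are placeholders
# (not part of this repository); the model_type values are the ones this file
# already supports ('bert'/'distilbert' encoders, 't5-small' reranker).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # 1) closed-book candidate generation
    dpr = setup_closedbook(
        model_path="checkpoints/dpr_biencoder.bin",      # placeholder path
        ans_tsv_path="data/answers.tsv",                 # placeholder path
        dense_embd_path="data/answer_embeddings_0.pkl",  # placeholder path
        process_id=0,
        model_type="bert",
    )
    clues = ["Capital of France"]
    answers, scores = answer_clues(dpr, clues, max_answers=20, output_strings=True)
    print("Top closed-book candidates:", answers[0][:5])

    # 2) rerank the candidates with the T5 reranker (higher score = better)
    reranker, tokenizer = setup_t5_reranker("checkpoints/t5_reranker")  # placeholder path
    candidates = answers[0]
    rerank_scores = t5_reranker_score_with_clue(
        reranker, tokenizer, "t5-small", clues * len(candidates), candidates
    )
    print("Best candidate after reranking:", candidates[int(np.argmax(rerank_scores))])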