# This file contains the inference code for loading and running the
# closed-book and open-book QA models
import os
import csv
import glob
import gzip
import string
import sys
from typing import List, Tuple, Dict
import re
import numpy as np
import unicodedata
import torch
from torch import Tensor as T
from torch import nn
from models import init_biencoder_components
from Options_inf import setup_args_gpu, print_args, set_encoder_params_from_state
from Faiss_Indexers_inf import DenseIndexer, DenseFlatIndexer
from Data_utils_inf import Tensorizer
from Model_utils_inf import load_states_from_checkpoint, get_model_obj
from transformers import T5ForConditionalGeneration, AutoTokenizer
import time
from wordsegment import load, segment

# wordsegment needs its corpus statistics loaded once before segment() is used.
load()

SEGMENTER_CACHE = {}
# Maps (post-processed clue, lower-cased fill) -> reranker score.
RERANKER_CACHE = {}

# Trailing-period variants stripped by post_process_clue, tried in this order.
_TRAILING_PERIOD_SUFFIXES = ('. .', ' ..', '..', '.')

# Abbreviation expansions applied, in this exact order, by preprocess_clue_fn.
# Order matters: e.g. ": Abbr." must be handled before the bare "Abbr." rule.
_ABBREVIATION_SUBS = tuple(
    (re.compile(pattern), replacement)
    for pattern, replacement in (
        (r"e\.g\.|for ex\.", "for example"),
        (
            r": [Aa]bbreviat\.|: [Aa]bbrev\.|: [Aa]bbrv\.|: [Aa]bbrv|: [Aa]bbr\.|: [Aa]bbr",
            " abbreviation",
        ),
        (r"abbr\.|abbrv\.", "abbreviation"),
        (r"Abbr\.|Abbrv\.", "Abbreviation"),
        (r"\(anag\.\)|\(anag\)", "(anagram)"),
        (r"org\.", "organization"),
        (r"Org\.", "Organization"),
        (r"Grp\.|Gp\.", "Group"),
        (r"grp\.|gp\.", "group"),
        (r": Sp\.", " (Spanish)"),
        (r"\(Sp\.\)|Sp\.", "(Spanish)"),
        (r"Ave\.", "Avenue"),
        (r"Sch\.", "School"),
        (r"sch\.", "school"),
        (r"Agcy\.", "Agency"),
        (r"agcy\.", "agency"),
        (r"Co\.", "Company"),
        (r"co\.", "company"),
        (r"No\.", "Number"),
        (r"no\.", "number"),
        (r": [Vv]ar\.", " variable"),
        (r"Subj\.", "Subject"),
        (r"subj\.", "subject"),
        (r"Subjs\.", "Subjects"),
        (r"subjs\.", "subjects"),
    )
)


def setup_closedbook(model_path, ans_tsv_path, dense_embd_path, process_id, model_type):
    """Build a closed-book ``DPRForCrossword`` (``retrievalmodel=False``).

    All arguments are forwarded unchanged to the ``DPRForCrossword``
    constructor.
    """
    dpr = DPRForCrossword(
        model_path,
        ans_tsv_path,
        dense_embd_path,
        retrievalmodel=False,
        process_id=process_id,
        model_type=model_type,
    )
    return dpr


def setup_t5_reranker(reranker_path, reranker_model_type='t5-small'):
    """Load the T5 reranker and its tokenizer.

    The tokenizer is loaded from ``reranker_model_type`` and the weights from
    ``reranker_path``. The model is put in eval mode and moved to CUDA when
    available, otherwise CPU.

    Returns:
        ``(model, tokenizer)``
    """
    tokenizer = AutoTokenizer.from_pretrained(reranker_model_type)
    model = T5ForConditionalGeneration.from_pretrained(reranker_path)
    model.eval().to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    return model, tokenizer


def post_process_clue(clue):
    """Preprocess ``clue`` and strip one trailing period variant.

    The suffixes ``'. .'``, ``' ..'``, ``'..'`` and ``'.'`` are tried in that
    order and only the first match is removed. An empty preprocessed clue is
    returned as-is (the previous ``clue[-1]`` lookup raised ``IndexError``).
    """
    clue = preprocess_clue_fn(clue)
    for suffix in _TRAILING_PERIOD_SUFFIXES:
        if clue.endswith(suffix):
            return clue[:-len(suffix)]
    return clue


def t5_reranker_score_with_clue(model, tokenizer, model_type, clues, possibly_ungrammatical_fills):
    """Score each (clue, fill) pair with the T5 reranker.

    Args:
        model: seq2seq model called as ``model(input_ids, labels=...)``.
        tokenizer: tokenizer matching ``model``.
        model_type: when ``'t5-small'`` the fills are word-segmented first
            (``wordsegment.segment``) and re-joined with spaces.
        clues: raw clue strings; each is passed through ``post_process_clue``.
        possibly_ungrammatical_fills: candidate answers, aligned with ``clues``.

    Returns:
        A list with one score per pair: ``-loss * label_length``, where
        ``loss`` is the first element of the model output (presumably the mean
        per-token loss, making the score a summed log-probability — confirm
        against the model implementation).

    Scores are memoised in ``RERANKER_CACHE``.
    """
    results = []
    device = model.device

    if model_type == 't5-small':
        fills = [" ".join(segment(answer.lower())) for answer in possibly_ungrammatical_fills]
    else:
        fills = list(possibly_ungrammatical_fills)

    model.eval()
    for clue, fill in zip(clues, fills):
        # possibly here is where the byt5 failed
        fill = fill.lower()
        clue = post_process_clue(clue)

        # Key on the pair rather than on the concatenation: "To" + "phat" and
        # "Top" + "hat" must not share a cache entry.
        cache_key = (clue, fill)
        if cache_key in RERANKER_CACHE:
            results.append(RERANKER_CACHE[cache_key])
            continue

        with torch.inference_mode():
            # Tokenize on CPU, then move the ids to the model's device.
            inputs = tokenizer(["Q: " + clue], return_tensors='pt')['input_ids'].to(device)
            labels = tokenizer([fill], return_tensors='pt')['input_ids'].to(device)
            outputs = model(inputs, labels=labels)
            answer_length = labels.shape[1]
            logprob = -outputs[0].item() * answer_length

        results.append(logprob)
        RERANKER_CACHE[cache_key] = logprob

    return results


def preprocess_clue_fn(clue):
    """Normalise a raw crossword clue into the text form fed to the models.

    Steps, in order: strip combining accents, normalise quotes/dashes/symbols,
    rewrite ``£`` amounts, drop trailing enumerations such as ``(5)`` or
    ``(3, 4)``, decode basic HTML entities, expand abbreviations, rewrite
    "link"/"follower"/"preceder"/"Partner of" clues into fill-in-the-blank
    form, normalise blanks to ``___``, tag ``?`` clues as ``(wordplay)`` and
    bracketed clues as ``(nonverbal)``, and collapse triple quotes.

    Args:
        clue: the clue; converted with ``str()`` first.

    Returns:
        The normalised clue string (may be empty).
    """
    clue = str(clue)

    # https://stackoverflow.com/questions/517923/what-is-the-best-way-to-remove-accents-normalize-in-a-python-unicode-string
    clue = ''.join(
        c for c in unicodedata.normalize('NFD', clue) if unicodedata.category(c) != 'Mn'
    )

    clue = re.sub("\x17|\x18|\x93|\x94|“|”|''|\"\"", "\"", clue)
    clue = re.sub("\x85|…", "...", clue)
    clue = re.sub("\x91|\x92|‘|’", "'", clue)

    # The pattern is the low-9 quotation mark, not an ASCII comma.
    clue = re.sub("‚", ",", clue)
    clue = re.sub("—|–", "-", clue)
    clue = re.sub("¢", " cents", clue)
    clue = re.sub(r"¿|¡|^;|\{|\}", "", clue)
    clue = re.sub("÷", "division", clue)
    clue = re.sub("°", " degrees", clue)

    # NOTE: "£" is rewritten as "Euros" (sic); kept as-is because changing the
    # wording would change the text the models are queried with.
    euro = re.search("^£[0-9]+(,*[0-9]*){0,}| £[0-9]+(,*[0-9]*){0,}", clue)
    if euro:
        num = clue[:euro.end()]
        rest_clue = clue[euro.end():]
        clue = num + " Euros" + rest_clue
        clue = re.sub(", Euros", " Euros", clue)
        clue = re.sub("Euros [Mm]illion", "million Euros", clue)
        clue = re.sub("Euros [Bb]illion", "billion Euros", clue)
        clue = re.sub("Euros[Kk]", "K Euros", clue)
        clue = re.sub(" K Euros", "K Euros", clue)
        # NOTE(review): assumed to belong to the "if euro" branch — the
        # original indentation was lost; confirm.
        clue = re.sub("£", "", clue)

    # Drop a trailing answer-length enumeration, e.g. " (5)" or " (3, 4)".
    clue = re.sub(r" *\(\d{1,},*\)$| *\(\d{1,},* \d{1,}\)$", "", clue)

    # Decode basic HTML entities. The previous patterns replaced "&", "<" and
    # ">" with themselves (no-ops), presumably because the entity names were
    # unescaped by an earlier tool. "&amp;" goes first, as before.
    clue = clue.replace("&amp;", "&")
    clue = clue.replace("&lt;", "<")
    clue = clue.replace("&gt;", ">")

    for pattern, replacement in _ABBREVIATION_SUBS:
        clue = pattern.sub(replacement, clue)

    theme_clue = re.search(r"^.+\|[A-Z]{1,}", clue)
    if theme_clue:
        clue = re.sub(r"\|", " | ", clue)

    if "Partner of" in clue:
        clue = re.sub("Partner of", "", clue)
        clue = clue + " and ___"

    # "X-Y link" -> "X ___ Y"
    link = re.search("^.+-.+ [Ll]ink$", clue)
    if link:
        no_link = re.search("^.+-.+ ", clue)
        x_y = clue[no_link.start():no_link.end() - 1]
        x_y_lst = x_y.split("-")
        clue = x_y_lst[0] + " ___ " + x_y_lst[1]

    # "X follower" -> "X ___"
    follower = re.search("^.+ [Ff]ollower$", clue)
    if follower:
        no_follower = re.search("^.+ ", clue)
        x = clue[:no_follower.end() - 1]
        clue = x + " ___"

    # "X preceder" -> "___ X"
    preceder = re.search("^.+ [Pp]receder$", clue)
    if preceder:
        no_preceder = re.search("^.+ ", clue)
        x = clue[:no_preceder.end() - 1]
        clue = "___ " + x

    # Turn dash-style blanks into underscores, then normalise every blank to
    # exactly three underscores.
    if re.search("--[^A-Za-z]|--$", clue):
        clue = re.sub("--", "__", clue)
    if not re.search("_-[A-Za-z]|_-$", clue):
        clue = re.sub("_-", "__", clue)
    # NOTE(review): assumed unconditional — the original indentation was lost.
    clue = re.sub("_{2,}", "___", clue)

    clue = re.sub(r"\?$", " (wordplay)", clue)

    nonverbal = re.search(r"\[[^0-9]+,* *[^0-9]*\]", clue)
    if nonverbal:
        clue = re.sub(r"\[|\]", "", clue)
        clue = clue + " (nonverbal)"

    # Collapse triple-quoted clues (with or without inner padding) to a single
    # pair of quotes.
    if clue[:4] == '""" ' and clue[-4:] == ' """':
        clue = '"' + clue[4:-4] + '"'
    if clue[:4] == "''' " and clue[-4:] == " '''":
        clue = "'" + clue[4:-4] + "'"
    if clue[:3] == '"""' and clue[-3:] == '"""':
        clue = '"' + clue[3:-3] + '"'
    if clue[:3] == "'''" and clue[-3:] == "'''":
        clue = "'" + clue[3:-3] + "'"

    return clue


def answer_clues(dpr, clues, max_answers, output_strings=False):
    """Preprocess ``clues`` and answer them with the closed-book model.

    Each clue is right-stripped and passed through ``preprocess_clue_fn``
    before ``dpr.answer_clues_closedbook`` is called; its result is returned
    unchanged.
    """
    clues = [preprocess_clue_fn(c.rstrip()) for c in clues]
    outputs = dpr.answer_clues_closedbook(clues, max_answers, output_strings=output_strings)
    return outputs


class DenseRetriever(object):
    """
    Does passage retrieving over the provided index and question encoder
    """

    # Encoder call signatures that generate_question_vectors knows how to use.
    _SUPPORTED_MODEL_TYPES = ('bert', 'distilbert')

    def __init__(
        self,
        question_encoder: nn.Module,
        batch_size: int,
        tensorizer: Tensorizer,
        index: DenseIndexer,
        device=None,
        model_type='bert',
    ):
        self.question_encoder = question_encoder
        self.batch_size = batch_size
        self.tensorizer = tensorizer
        self.index = index
        self.device = device
        self.model_type = model_type

    def generate_question_vectors(self, questions: List[str]) -> T:
        """Encode ``questions`` in batches of ``self.batch_size``.

        Returns:
            A CPU tensor with one row per question.

        Raises:
            ValueError: if ``self.model_type`` is not ``'bert'`` or
                ``'distilbert'`` (previously this surfaced as a ``NameError``
                on the unbound encoder output).
        """
        if self.model_type not in self._SUPPORTED_MODEL_TYPES:
            raise ValueError(
                "Unsupported model_type %r; expected one of %r"
                % (self.model_type, self._SUPPORTED_MODEL_TYPES)
            )

        n = len(questions)
        bsz = self.batch_size
        query_vectors = []

        self.question_encoder.eval()

        with torch.no_grad():
            for batch_start in range(0, n, bsz):
                batch_token_tensors = [
                    self.tensorizer.text_to_tensor(q)
                    for q in questions[batch_start:batch_start + bsz]
                ]
                q_ids_batch = torch.stack(batch_token_tensors, dim=0).to(self.device)
                q_seg_batch = torch.zeros_like(q_ids_batch).to(self.device)
                # q_attn_mask = self.tensorizer.get_attn_mask(q_ids_batch)
                # Treats token id 0 as padding — presumably the tokenizer's
                # pad id; confirm against the tensorizer.
                q_attn_mask = (q_ids_batch != 0)

                if self.model_type == 'bert':
                    _, out, _ = self.question_encoder(q_ids_batch, q_seg_batch, q_attn_mask)
                else:  # 'distilbert': the encoder takes no segment ids
                    _, out, _ = self.question_encoder(q_ids_batch, q_attn_mask)

                query_vectors.extend(out.cpu().split(1, dim=0))

        query_tensor = torch.cat(query_vectors, dim=0)
        print("CLUE Vector Shape", query_tensor.shape)
        assert query_tensor.size(0) == len(questions)
        return query_tensor

    def get_top_docs(
        self, query_vectors: np.ndarray, top_docs: int = 100
    ) -> List[Tuple[List[object], List[float]]]:
        """
        Does the retrieval of the best matching passages given the query vectors batch
        :param query_vectors: batch of query vectors passed to the index
        :param top_docs: number of neighbours requested per query
        :return: whatever ``self.index.search_knn`` returns
        """
        results = self.index.search_knn(query_vectors, top_docs)
        return results


class FakeRetrieverArgs:
    """Used to suppress the existing argparse inside DPR so we can have our own argparse"""

    def __init__(self):
        self.do_lower_case = False
        self.pretrained_model_cfg = None
        self.encoder_model_type = None
        self.model_file = None
        self.projection_dim = 0
        self.sequence_length = 512
        self.do_fill_lower_case = False
        self.desegment_valid_fill = False
        self.no_cuda = True
        self.local_rank = -1
        self.fp16 = False
        self.fp16_opt_level = "O1"


class DPRForCrossword(object):
    """Closedbook model for Crossword clue answering"""

    def __init__(
        self,
        model_file,
        ctx_file,
        encoded_ctx_file,
        batch_size=16,
        retrievalmodel=False,
        process_id=0,
        model_type='bert',
    ):
        """Load the question encoder, index the embeddings and load passages.

        Args:
            model_file: checkpoint read by ``load_states_from_checkpoint``.
            ctx_file: TSV (optionally ``.gz``) of passages, or a tuple whose
                first item is that path.
            encoded_ctx_file: path to the pre-computed embeddings, or a
                sequence whose first item is that path.
            batch_size: question-encoding batch size.
            retrievalmodel: True for the wikipedia retrieval model, False for
                the closed-book model.
            process_id: currently unused (only referenced by the disabled
                per-process CUDA device selection).
            model_type: ``'bert'`` or ``'distilbert'``.
        """
        self.retrievalmodel = retrievalmodel  # am I a wikipedia retrieval model or a closed-book model
        args = FakeRetrieverArgs()
        args.model_file = model_file
        args.ctx_file = ctx_file
        args.encoded_ctx_file = encoded_ctx_file
        args.batch_size = batch_size
        # self.device = torch.device("cuda:"+str(process_id%torch.cuda.device_count()))
        self.device = 'cpu'
        self.model_type = model_type

        setup_args_gpu(args)

        saved_state = load_states_from_checkpoint(args.model_file)
        set_encoder_params_from_state(saved_state.encoder_params, args)

        tensorizer, encoder, _ = init_biencoder_components(
            args.encoder_model_type, args, inference_only=True
        )
        question_encoder = encoder.question_model
        question_encoder = question_encoder.to(self.device)
        question_encoder.eval()

        # load weights from the model file
        model_to_load = get_model_obj(question_encoder)
        prefix = "question_model."
        question_encoder_state = {
            key[len(prefix):]: value
            for key, value in saved_state.model_dict.items()
            if key.startswith(prefix)
        }
        # strict=False: keys missing from / extra in the checkpoint are ignored.
        model_to_load.load_state_dict(question_encoder_state, strict=False)
        vector_size = model_to_load.get_out_size()

        index = DenseFlatIndexer(vector_size, 50000)

        self.retriever = DenseRetriever(
            question_encoder,
            args.batch_size,
            tensorizer,
            index,
            self.device,
            self.model_type,
        )

        # index all passages
        embd_file_path = args.encoded_ctx_file
        if isinstance(embd_file_path, str):
            file_path = embd_file_path
        else:
            file_path = embd_file_path[0]
        self.retriever.index.index_data(file_path)

        self.all_passages = self.load_passages(args.ctx_file)
        # Defined in both modes so the attribute always exists.
        self.len_all_passages = len(self.all_passages)

        # Map the letters-only, upper-cased third column back to the row id.
        self.fill2id = {}
        for key, passage in self.all_passages.items():
            fill = "".join(
                letter for letter in passage[1].upper() if letter in string.ascii_uppercase
            )
            self.fill2id[fill] = key

        # might as well uppercase and remove non-alphas from the fills before we start to save time later
        if not retrievalmodel:
            # NOTE: this filters per character *before* upper-casing, which
            # differs from the fill2id normalisation for characters whose
            # upper-case form is more than one character; kept as-is.
            self.all_passages = {
                my_id: "".join(
                    c.upper() for c in passage[1] if c.upper() in string.ascii_uppercase
                )
                for my_id, passage in self.all_passages.items()
            }

    @staticmethod
    def load_passages(ctx_file: str) -> Dict[object, Tuple[str, str]]:
        """Read a tab-separated passage file (plain or ``.gz``).

        Args:
            ctx_file: path to the file, or a tuple whose first item is it.

        Returns:
            ``{row[0]: (row[1], row[2])}`` for every row whose first column is
            not ``"id"`` (the header). Blank rows are skipped.
        """
        if isinstance(ctx_file, tuple):
            ctx_file = ctx_file[0]

        opener = gzip.open if ctx_file.endswith(".gz") else open
        docs = {}
        with opener(ctx_file, "rt") as tsvfile:
            reader = csv.reader(
                tsvfile,
                delimiter="\t",
            )
            # file format: doc_id, doc_text, title
            for row in reader:
                if not row:
                    continue
                if row[0] != "id":
                    docs[row[0]] = (row[1], row[2])
        return docs

    def answer_clues_closedbook(self, questions, max_answers, output_strings=False):
        """Retrieve the top answers for already-preprocessed clues.

        Args:
            questions: preprocessed clue strings.
            max_answers: number of answers per clue; capped at the number of
                loaded passages.
            output_strings: when False return the raw index results; when True
                return ``(all_answers, all_scores)`` where each answer id has
                been mapped through ``self.all_passages``.
        """
        # assumes clues are preprocessed
        assert not self.retrievalmodel
        questions_tensor = self.retriever.generate_question_vectors(questions)

        max_answers = min(max_answers, self.len_all_passages)

        start_time = time.perf_counter()
        # get top k results
        top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), max_answers)
        end_time = time.perf_counter()
        print("\n\nTime taken by FAISS INDEXER: ", end_time - start_time)

        if not output_strings:
            return top_ids_and_scores

        # get the string forms
        all_answers = []
        all_scores = []
        for ans in top_ids_and_scores:
            all_answers.append(list(map(self.all_passages.get, ans[0])))
            all_scores.append(ans[1])
        return all_answers, all_scores

    def get_wikipedia_docs(self, questions, max_docs):
        """Retrieve up to ``max_docs`` distinct passages per preprocessed clue.

        Returns:
            One list of passages per question, duplicates removed.
        """
        # assumes clues are preprocessed
        assert self.retrievalmodel
        questions_tensor = self.retriever.generate_question_vectors(questions)

        # get top k results. add 2 in case of duplicates (see below)
        top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), max_docs + 2)

        all_paragraphs = []
        for ans in top_ids_and_scores:
            paragraphs = []
            for id_ in ans[0]:
                id_ = id_.replace("wiki:", "")
                mydocument = self.all_passages[id_]
                if mydocument in paragraphs:
                    print("woah, duplicate!!!")
                    continue
                paragraphs.append(mydocument)
            all_paragraphs.append(paragraphs[0:max_docs])
        return all_paragraphs