# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception

import sys
import json
from math import ceil

import torch
import numpy as np
from torch import tensor
from torch.nn.functional import log_softmax
from torch.distributions.categorical import Categorical
from transformers import T5Tokenizer, T5ForConditionalGeneration

# load UnifiedQA onto device
model_name = "allenai/unifiedqa-v2-t5-large-1363200"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)


def get_inputs(contexts_json, ranked_contexts_json):
    """Load a question and its ranked contexts from two JSON files.

    Args:
        contexts_json: path to a JSON object mapping context ids to records
            that carry a ``'text'`` field.
        ranked_contexts_json: path to a JSON object keyed by question id. Only
            the first key is used; its value must carry ``'text'`` (the
            question), ``'ranks'`` (context ids, in ranked order) and
            ``'scores'``.

    Returns:
        ``(question, contexts, context_scores)`` where ``contexts`` is the list
        of context texts in ``'ranks'`` order.

    Raises:
        IndexError: if ``ranked_contexts_json`` holds an empty object.
        KeyError: if a required field or a ranked context id is missing.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(contexts_json, 'rt', encoding='utf-8') as fp:
        contexts = json.load(fp)
    with open(ranked_contexts_json, 'rt', encoding='utf-8') as fp:
        ranked_contexts = json.load(fp)

    # Only the first question in the file is answered.
    question_id = list(ranked_contexts)[0]
    entry = ranked_contexts[question_id]
    question = entry['text']
    context_ids_sorted = entry['ranks']
    context_scores = entry['scores']
    contexts = [contexts[context_id]['text'] for context_id in context_ids_sorted]
    return question, contexts, context_scores


def get_tokens(text, tokenizer, max_tokens):
    """Encode ``text`` as a PyTorch id tensor padded/truncated to ``max_tokens``."""
    return tokenizer.encode_plus(text, return_tensors='pt', max_length=max_tokens,
                                 padding='max_length', truncation=True)['input_ids']


def prepare_inputs(tokenizer, max_tokens, context, question):
    """Build and tokenize the model input for one (question, context) pair.

    The question and context are joined by a literal backslash-``n`` (not a
    newline character), the separator used by UnifiedQA-style inputs.
    """
    input_str = f'{question} \\n {context}'
    inputs = get_tokens(input_str, tokenizer, max_tokens)
    return inputs


def get_outputs(model, tokenizer, input_tokens, max_tokens):
    """Generate an answer for ``input_tokens`` and score its uncertainty.

    The uncertainty is ``mean(step entropy) * exp(std(step top log-prob))``,
    computed over every generation step except the last one. Special cases:

    * if exactly one step remains, the std factor is 0, so uncertainty is 0.0;
    * if no step remains, uncertainty is ``nan``.

    Args:
        model: a seq2seq model exposing ``generate``.
        tokenizer: tokenizer used to decode the generated ids.
        input_tokens: encoded input; assumed to hold a single sequence (batch
            size 1), as produced by ``prepare_inputs``.
        max_tokens: ``max_length`` passed to ``generate``.

    Returns:
        ``(pred_str, uncertainty)``: the lower-cased decoded answer (special
        tokens skipped) and a Python float.
    """
    output_dict = model.generate(input_tokens, output_scores=True,
                                 return_dict_in_generate=True, max_length=max_tokens)
    pred_tokens = output_dict['sequences'].squeeze().tolist()

    # Drop the last step, presumably the end token.
    # NOTE(review): if generation stopped at max_length, this drops a real token instead.
    logit_sequence = output_dict['scores'][:-1]

    step_entropies = []
    step_top_log_probs = []
    for logit in logit_sequence:
        log_probs = log_softmax(logit, dim=-1)
        # Passing logits avoids an exp -> log round trip inside Categorical.
        step_entropies.append(Categorical(logits=log_probs).entropy())
        step_top_log_probs.append(log_probs.max())

    if not logit_sequence:
        # Mean/std of an empty sequence are undefined; this matches the NaN the
        # empty-tensor reductions produced before, without torch warnings.
        uncertainty = float('nan')
    else:
        # Stack on-device (no per-element host sync); cast to float32 so
        # half-precision models are reduced at the same precision as before.
        entropy = torch.stack(step_entropies).float().mean()
        if len(logit_sequence) == 1:
            # The sample std of a single value is undefined; treat it as 0.
            sentence_std = 0
        else:
            sentence_std = torch.stack(step_top_log_probs).float().std(unbiased=True).exp()
        uncertainty = (entropy * sentence_std).item()

    pred_str = tokenizer.decode(pred_tokens, skip_special_tokens=True).lower()
    return pred_str, uncertainty


# run_model parameters:
# k_percent: fraction of the ranked contexts to use (0.1 == 10%); the resulting
#   count is raised to at least min_k and then capped at max_k
# min_k: minimum number of contexts to use, if possible. Setting this too small reduces recall
# max_k: maximum number of contexts to use. Setting this too big reduces precision
# recommended uncertainty thresholds are 2, 3, 4, and 5.
# The lower the threshold, the more aggressive the filtering (uncertainty
# filtering is currently disabled; see run_model).
def run_model(model, tokenizer, device, question, contexts, context_scores,
              k_percent=0.1, min_k=10, max_k=25, uncertainty_thresh=3,
              max_tokens=512):
    """Answer ``question`` against each of the top-k ranked contexts.

    ``k = min(max(ceil(k_percent * len(contexts)), min_k), max_k)``. Slicing
    then caps the count at ``len(contexts)``.

    Args:
        model, tokenizer, device: seq2seq model, its tokenizer, and the device
            the encoded inputs are moved to.
        question: question text.
        contexts: context texts, assumed to be in ranked order.
        context_scores: scores aligned with ``contexts``.
        k_percent, min_k, max_k: control how many contexts are used (see above).
        uncertainty_thresh: currently unused. Uncertainty-based filtering of
            answers is disabled; the parameter is kept for API compatibility.
        max_tokens: length the encoded input is padded/truncated to, and the
            ``max_length`` used for generation.

    Returns:
        dict with keys ``'contexts'``, ``'answers'``, ``'context_scores'`` and
        ``'uncertainty'``. All values are aligned with the first k contexts:
        the contexts and scores are slices of the inputs, and the answers and
        uncertainties are lists.
    """
    k = min(max(ceil(k_percent * len(contexts)), min_k), max_k)
    contexts = contexts[:k]
    context_scores = context_scores[:k]

    answers = []
    uncertainty = []
    for context in contexts:
        input_tokens = prepare_inputs(tokenizer, max_tokens, context, question).to(device)
        pred_str, context_uncertainty = get_outputs(model, tokenizer, input_tokens, max_tokens)
        answers.append(pred_str)
        uncertainty.append(context_uncertainty)

    return {'contexts': contexts, 'answers': answers,
            'context_scores': context_scores, 'uncertainty': uncertainty}


def _parse_ranked_contexts(contexts, ranked_contexts):
    """Return (question, context texts in rank order, scores) for the first question key."""
    question_id = list(ranked_contexts)[0]  # IndexError on an empty mapping
    entry = ranked_contexts[question_id]
    question = entry['text']
    context_ids_sorted = entry['ranks']
    context_scores = entry['scores']
    texts = [contexts[context_id]['text'] for context_id in context_ids_sorted]
    return question, texts, context_scores


def _answer_top_k(question, contexts, context_scores, topk):
    """Answer with the current module-level model on at most ``topk`` contexts, under inference mode."""
    # Globals are read at call time, so a model swapped in by load_custom_model is used.
    with torch.inference_mode():
        return run_model(model, tokenizer, device, question, contexts,
                         context_scores, k_percent=1.0, min_k=1, max_k=topk)


def get_qa_results(contexts_json, ranked_contexts_json, topk):
    """Load a question and its ranked contexts from JSON files and answer against the top ``topk``."""
    question, contexts, context_scores = get_inputs(contexts_json, ranked_contexts_json)
    return _answer_top_k(question, contexts, context_scores, topk)


def get_qa_results_in_memory(contexts, ranked_contexts, topk):
    """Like ``get_qa_results`` but takes the already-parsed JSON mappings.

    Args:
        contexts: mapping of context id -> record with a ``'text'`` field.
        ranked_contexts: mapping keyed by question id; only the first key is
            used, and its value must carry ``'text'``, ``'ranks'`` and ``'scores'``.
        topk: maximum number of ranked contexts to answer against.
    """
    question, context_texts, context_scores = _parse_ranked_contexts(contexts, ranked_contexts)
    return _answer_top_k(question, context_texts, context_scores, topk)


def load_custom_model(finetuned_model_path):
    """Replace the module-level tokenizer and model with ones loaded from ``finetuned_model_path``.

    Both are loaded before either global is rebound, so a failed load leaves
    the previous tokenizer/model pair intact.
    """
    global tokenizer
    global model
    new_tokenizer = T5Tokenizer.from_pretrained(finetuned_model_path)
    new_model = T5ForConditionalGeneration.from_pretrained(finetuned_model_path)
    tokenizer, model = new_tokenizer, new_model
    # Move the model only after rebinding, so the old model can be released
    # before the new one is placed on the device (same peak memory as before).
    model.to(device)


def get_qa_results_in_memory_finetuned_unifiedqa(question, context_scores, contexts, topk):
    """Answer ``question`` against the top ``topk`` of the given ranked ``contexts``."""
    return _answer_top_k(question, contexts, context_scores, topk)