# DocumentQA/UnifiedQA/demo_QA.py
# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
import sys
import json
from math import ceil
import torch
import numpy as np
from torch import tensor
from torch.nn.functional import log_softmax
from torch.distributions.categorical import Categorical
from transformers import T5Tokenizer, T5ForConditionalGeneration

# load UnifiedQA onto device
model_name = "allenai/unifiedqa-v2-t5-large-1363200"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
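
# Expected layout of the two JSON inputs, inferred from how get_inputs reads them below:
#   contexts_json:        {context_id: {'text': '...'}, ...}
#   ranked_contexts_json: {question_id: {'text': question, 'ranks': [context_id, ...], 'scores': [float, ...]}}
# 'ranks' is assumed to be ordered from most to least relevant, with 'scores' aligned to it.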
def get_inputs(contexts_json, ranked_contexts_json):
    with open(contexts_json, 'rt') as fp:
        contexts = json.load(fp)
    with open(ranked_contexts_json, 'rt') as fp:
        ranked_contexts = json.load(fp)
    question_id = list(ranked_contexts.keys())[0]
    # assert len(ranked_contexts) == 1, f'JSON should only have 1 question but found {len(ranked_contexts)}'
    question = ranked_contexts[question_id]['text']
    context_ids_sorted = ranked_contexts[question_id]['ranks']
    context_scores = ranked_contexts[question_id]['scores']
    contexts = [contexts[context_id]['text'] for context_id in context_ids_sorted]
    # returns the question (str), its contexts (list of str), and the retrieval scores
    return question, contexts, context_scores

def get_tokens(text, tokenizer, max_tokens):
    # tokenize to a fixed-length tensor of input ids, padded or truncated to max_tokens
    return tokenizer.encode_plus(text, return_tensors='pt', max_length=max_tokens, padding='max_length', truncation=True)['input_ids']

def prepare_inputs(tokenizer, max_tokens, context, question):
    # UnifiedQA takes "question \n context" as a single string; the escaped backslash keeps the "\n" literal
    input_str = f'{question} \\n {context}'
    inputs = get_tokens(input_str, tokenizer, max_tokens)
    return inputs

def get_outputs(model, tokenizer, input_tokens, max_tokens):
    output_dict = model.generate(input_tokens, output_scores=True, return_dict_in_generate=True, max_length=max_tokens)
    pred_tokens = output_dict['sequences'].squeeze().tolist()
    # initialize metrics
    logit_entropy = []
    sentence_probs = []
    # accumulate metrics over the logit sequence
    logit_sequence = output_dict['scores'][:-1]  # discard end token
    for logit in logit_sequence:
        log_probs = log_softmax(logit, dim=-1)
        # update metrics: per-step entropy and the log-probability of the chosen (argmax) token
        logit_entropy.append(Categorical(log_probs.exp()).entropy())
        sentence_probs.append(log_probs.max())
    # finish metrics calculation
    logit_entropy = tensor(logit_entropy)
    sentence_probs = tensor(sentence_probs)
    entropy = logit_entropy.mean()
    sentence_std = 0 if len(logit_sequence) == 1 else sentence_probs.std(unbiased=True).exp()
    # use entropy * sentence_std as uncertainty (lower means more confident)
    uncertainty = (entropy * sentence_std).item()
    # convert answer tokens to str
    pred_str = tokenizer.decode(pred_tokens, skip_special_tokens=True).lower()
    return pred_str, uncertainty

# k_percent: fraction of the contexts to use; the resulting k is clamped to [min_k, max_k]
# min_k: minimum number of contexts to use, if possible. Setting this too small reduces recall
# max_k: maximum number of contexts to use. Setting this too big reduces precision
# recommended uncertainty thresholds are 2, 3, 4, and 5; the lower the threshold, the more aggressive the filtering
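# For example, with 200 contexts and the defaults k_percent=0.1, min_k=10, max_k=25:
#   k = min(max(ceil(0.1 * 200), 10), 25) = min(max(20, 10), 25) = 20 contexts are kept.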
def run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=0.1, min_k=10, max_k=25, uncertainty_thresh=3):
    k = min(max(ceil(k_percent * len(contexts)), min_k), max_k)
    contexts = contexts[:k]
    context_scores = context_scores[:k]
    # iterate through top-k contexts, answering the question once per context
    answers = []
    uncertainty = []
    for context in contexts:
        input_tokens = prepare_inputs(tokenizer, 512, context, question).to(device)
        pred_str, uncertainty_1 = get_outputs(model, tokenizer, input_tokens, 512)
        answers.append(pred_str)
        uncertainty.append(uncertainty_1)
    # optional uncertainty-based filtering, kept here for reference but currently disabled:
    # contexts = np.array(contexts)
    # answers = np.array(answers)
    # uncertainty = np.array(uncertainty)
    # sort by uncertainty, ascending order
    # order = np.argsort(uncertainty)
    # contexts = contexts[order]
    # answers = answers[order]
    # uncertainty = uncertainty[order]
    # init lists for threshed answers
    # weak_contexts = []
    # weak_answers = []
    # weak_uncertainty = []
    # filter by uncertainty
    # if len(answers) > min_k:
    #     weak = np.argwhere(uncertainty > uncertainty_thresh)  # exceeds threshold
    #     weak_contexts = contexts[weak].tolist()
    #     weak_answers = answers[weak].tolist()
    #     weak_uncertainty = uncertainty[weak].tolist()
    #     strong = np.argwhere(uncertainty <= uncertainty_thresh)  # within threshold
    #     contexts = contexts[strong]
    #     answers = answers[strong]
    #     uncertainty = uncertainty[strong]
    # contexts = contexts.tolist()
    # answers = answers.tolist()
    # uncertainty = uncertainty.tolist()
    # return {'contexts': contexts, 'answers': answers, 'uncertainty': uncertainty}, \
    #        {'contexts': weak_contexts, 'answers': weak_answers, 'uncertainty': weak_uncertainty}
    return {'contexts': contexts, 'answers': answers, 'context_scores': context_scores, 'uncertainty': uncertainty}
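
# Convenience sketch, not part of the original demo: run_model returns parallel lists, so the
# single most confident answer can be picked by taking the minimum uncertainty (lower = more confident).
def pick_best_answer(qa_results):
    best = min(range(len(qa_results['answers'])), key=lambda i: qa_results['uncertainty'][i])
    return qa_results['answers'][best], qa_results['uncertainty'][best]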

def get_qa_results(contexts_json, ranked_contexts_json, topk):
    # extract question and contexts from json
    question, contexts, context_scores = get_inputs(contexts_json, ranked_contexts_json)
    # infer answers
    with torch.inference_mode(True):
        # strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
        qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)
    return qa_results

def get_qa_results_in_memory(contexts, ranked_contexts, topk):
    # same as get_qa_results, but takes the already-parsed dicts instead of JSON paths
    question_id = list(ranked_contexts.keys())[0]
    # assert len(ranked_contexts) == 1, f'JSON should only have 1 question but found {len(ranked_contexts)}'
    question = ranked_contexts[question_id]['text']
    context_ids_sorted = ranked_contexts[question_id]['ranks']
    context_scores = ranked_contexts[question_id]['scores']
    contexts = [contexts[context_id]['text'] for context_id in context_ids_sorted]
    # infer answers
    with torch.inference_mode(True):
        # strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
        qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)
    return qa_results

def load_custom_model(finetuned_model_path):
    # swap in a fine-tuned UnifiedQA checkpoint, replacing the module-level tokenizer and model
    global tokenizer
    global model
    tokenizer = T5Tokenizer.from_pretrained(finetuned_model_path)
    model = T5ForConditionalGeneration.from_pretrained(finetuned_model_path)
    model.to(device)

def get_qa_results_in_memory_finetuned_unifiedqa(question, context_scores, contexts, topk):
    # infer answers
    with torch.inference_mode(True):
        # strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
        qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)
    return qa_results
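
# Minimal command-line sketch, assuming two JSON files laid out as described above get_inputs;
# the paths are supplied by the caller and are not shipped with this demo.
if __name__ == '__main__':
    contexts_path, ranked_contexts_path = sys.argv[1], sys.argv[2]
    results = get_qa_results(contexts_path, ranked_contexts_path, topk=5)
    for answer, score in zip(results['answers'], results['uncertainty']):
        print(f'{score:.3f}\t{answer}')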