# Copyright (c) 2022, Lawrence Livermore National Security, LLC. 
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964

# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception

import sys
import json
from math import ceil

import torch
import numpy as np
from torch import tensor
from torch.nn.functional import log_softmax
from torch.distributions.categorical import Categorical
from transformers import T5Tokenizer, T5ForConditionalGeneration

# load UnifiedQA onto device
model_name = "allenai/unifiedqa-v2-t5-large-1363200"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

def get_inputs(contexts_json, ranked_contexts_json):
	with open(contexts_json, 'rt') as fp:
		contexts = json.load(fp)

	with open(ranked_contexts_json, 'rt') as fp:
		ranked_contexts = json.load(fp)

	question_id = list(ranked_contexts.keys())[0]
	# assert len(ranked_contexts) == 1, f'JSON should only have 1 question but found {len(ranked_contexts)}: {list(ranked_contexts)}'
	question = ranked_contexts[question_id]['text']
	context_ids_sorted = ranked_contexts[question_id]['ranks']
	context_scores = ranked_contexts[question_id]['scores']
	contexts = [contexts[context_id]['text'] for context_id in context_ids_sorted]

	# returns the question (str) and its contexts (sequence)
	return question, contexts, context_scores
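
# Illustrative sketch of the JSON layout get_inputs() assumes, inferred from the lookups above
# (field names and nesting are read off the code, not a documented schema):
#
#   contexts_json:        {"<context_id>": {"text": "..."}, ...}
#   ranked_contexts_json: {"<question_id>": {"text": "<question>",
#                                            "ranks": ["<context_id>", ...],   # best first
#                                            "scores": [<float>, ...]}}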

def get_tokens(text, tokenizer, max_tokens):
	return tokenizer.encode_plus(text, return_tensors='pt', max_length=max_tokens, padding='max_length', truncation=True)['input_ids']

def prepare_inputs(tokenizer, max_tokens, context, question):
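	# UnifiedQA's documented input format joins the question and context with a literal "\n"
	# separator, hence the escaped backslash in the f-string below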
	input_str = f'{question} \\n {context}'
	inputs = get_tokens(input_str, tokenizer, max_tokens)
	return inputs

def get_outputs(model, tokenizer, input_tokens, max_tokens):
	output_dict = model.generate(input_tokens, output_scores=True, return_dict_in_generate=True, max_length=max_tokens)
	pred_tokens = output_dict['sequences'].squeeze().tolist()

	# initialize metrics
	logit_entropy = []
	sentence_probs = []

	# accumulate metrics over logit_sequence
	logit_sequence = output_dict['scores'][:-1] # discard end token
	for logit in logit_sequence:
		log_probs = log_softmax(logit, dim=-1)

		# update metrics
		logit_entropy.append(Categorical(log_probs.exp()).entropy())
		sentence_probs.append(log_probs.max())

	# finish metrics calculation
	logit_entropy = tensor(logit_entropy)
	sentence_probs = tensor(sentence_probs)
	entropy = logit_entropy.mean()
	sentence_std = 0 if len(logit_sequence) == 1 else sentence_probs.std(unbiased=True).exp()

	# use entropy * sentence_std as uncertainty
	uncertainty = (entropy * sentence_std).item()

	# convert answer tokens to str
	pred_str = tokenizer.decode(pred_tokens, skip_special_tokens=True).lower()

	return pred_str, uncertainty
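
# Summary of the uncertainty metric computed above (a reading of the code, not an externally
# documented formula): it is the mean per-step entropy of the decoder's token distribution,
# scaled by exp(std) of the per-step maximum log-probabilities, so larger values indicate a
# less confident generation.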

# k_percent: fraction of contexts to use; the resulting count k is clamped to the range [min_k, max_k]
# min_k: minimum number of contexts to use, if possible. Setting this too small reduces recall
# max_k: maximum number of contexts to use. Setting this too large reduces precision
# uncertainty_thresh: recommended values are 2, 3, 4, and 5; the lower the threshold, the more aggressive the filtering
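# Worked example (illustrative): with 200 ranked contexts and k_percent=0.1, min_k=10, max_k=25,
# k = min(max(ceil(0.1 * 200), 10), 25) = 20, so only the 20 highest-ranked contexts are queried.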
def run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=0.1, min_k=10, max_k=25, uncertainty_thresh=3):
	k = min(max(ceil(k_percent * len(contexts)), min_k), max_k)
	contexts = contexts[:k]
	context_scores = context_scores[:k]

	# iterate through top-k contexts
	answers = []
	uncertainty = []
	for context in contexts:
		input_tokens = prepare_inputs(tokenizer, 512, context, question).to(device)
		pred_str, uncertainty_1 = get_outputs(model, tokenizer, input_tokens, 512)
		answers.append(pred_str)
		uncertainty.append(uncertainty_1)

	# contexts = np.array(contexts)
	# answers = np.array(answers)
	# uncertainty = np.array(uncertainty)

	# sort by uncertainty, ascending order
	# order = np.argsort(uncertainty)
	# contexts = contexts[order]
	# answers = answers[order]
	# uncertainty = uncertainty[order]

	# init lists for threshed answers
	# weak_contexts = []
	# weak_answers = []
	# weak_uncertainty = []

	# filter by uncertainty
	# if len(answers) > min_k:
		# weak = np.argwhere(uncertainty > uncertainty_thresh) # exceeds threshold
		# weak_contexts = contexts[weak].tolist()
		# weak_answers = answers[weak].tolist()
		# weak_uncertainty = uncertainty[weak].tolist()

		# strong = np.argwhere(uncertainty <= uncertainty_thresh) # within threshold
		# contexts = contexts[strong]
		# answers = answers[strong]
		# uncertainty = uncertainty[strong]

	# contexts = contexts.tolist()
	# answers = answers.tolist()
	# uncertainty = uncertainty.tolist()

	# return {'contexts': contexts, 'answers': answers, 'uncertainty': uncertainty}, \
	# 	{'contexts': weak_contexts, 'answers': weak_answers, 'uncertainty': weak_uncertainty}

	return {'contexts': contexts, 'answers': answers, 'context_scores': context_scores, 'uncertainty': uncertainty}

def get_qa_results(contexts_json, ranked_contexts_json, topk):

	# extract question and contexts from json
	question, contexts, context_scores = get_inputs(contexts_json, ranked_contexts_json)

	# infer answers
	with torch.inference_mode(True):
		# strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
		qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)

	return qa_results

def get_qa_results_in_memory(contexts, ranked_contexts, topk):

	question_id = list(ranked_contexts.keys())[0]
	# assert len(ranked_contexts) == 1, f'JSON should only have 1 question but found {len(ranked_contexts)}: {list(ranked_contexts)}'
	question = ranked_contexts[question_id]['text']
	context_ids_sorted = ranked_contexts[question_id]['ranks']
	context_scores = ranked_contexts[question_id]['scores']
	contexts = [contexts[context_id]['text'] for context_id in context_ids_sorted]

	# infer answers
	with torch.inference_mode(True):
		# strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
		qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)

	return qa_results

def load_custom_model(finetuned_model_path):
	global tokenizer
	global model

	# load the fine-tuned UnifiedQA checkpoint onto device
	tokenizer = T5Tokenizer.from_pretrained(finetuned_model_path)
	model = T5ForConditionalGeneration.from_pretrained(finetuned_model_path)
	model.to(device)
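
# Example (illustrative, with a hypothetical path): calling
#   load_custom_model('/path/to/finetuned-unifiedqa')
# rebinds the module-level tokenizer/model, so subsequent get_qa_results* calls use the fine-tuned weights.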

def get_qa_results_in_memory_finetuned_unifiedqa(question, context_scores, contexts, topk):
    
	# infer answers
	with torch.inference_mode(True):
		# strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
		qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)

	return qa_results
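
if __name__ == '__main__':
	# Illustrative command-line entry point, not part of the original module: it assumes the two
	# positional arguments are the contexts JSON and ranked-contexts JSON consumed by
	# get_qa_results(), plus an optional top-k (argument order and the default are assumptions).
	contexts_json_path = sys.argv[1]
	ranked_contexts_json_path = sys.argv[2]
	topk = int(sys.argv[3]) if len(sys.argv) > 3 else 25
	qa_results = get_qa_results(contexts_json_path, ranked_contexts_json_path, topk)
	print(json.dumps(qa_results, indent=2))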