# wiki-chat / QuestionAnswer.py
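"""Extractive question answering helper for wiki-chat.

Given a {'question': ..., 'context': ...} dict, a Hugging Face QA model, and
its (fast) tokenizer, QuestionAnswer splits long contexts into overlapping
chunks and returns the top-scoring answer spans with their logit scores.
"""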
import torch
import numpy as np


class QuestionAnswer:

    def __init__(self, data, model, tokenizer, torch_device):
        self.max_length = 384   # maximum length of a tokenized feature
        self.doc_stride = 128   # overlap between consecutive context chunks
        self.tokenizer = tokenizer
        self.model = model
        self.data = data        # expects {'question': ..., 'context': ...}
        self.torch_device = torch_device

        self.output = None
        self.features = None
        self.results = None

    def get_output_from_model(self):
        # Tokenize the question/context pair, letting long contexts overflow
        # into overlapping chunks, then run a forward pass without gradients.
        tokenized_data = self.tokenizer(
            self.data['question'],
            self.data['context'],
            truncation='only_second',
            max_length=self.max_length,
            stride=self.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding='max_length',
            return_tensors='pt'
        ).to(self.torch_device)

        with torch.no_grad():
            output = self.model(
                input_ids=tokenized_data['input_ids'],
                attention_mask=tokenized_data['attention_mask']
            )
        return output

    def prepare_features(self, example):
        # Re-tokenize without tensors to keep per-feature offset mappings,
        # which map token indices back to character spans in the context.
        # return_offsets_mapping requires a fast (Rust-backed) tokenizer.
        tokenized_example = self.tokenizer(
            example['question'],
            example['context'],
            truncation='only_second',
            max_length=self.max_length,
            stride=self.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding='max_length',
        )

        for i in range(len(tokenized_example['input_ids'])):
            # sequence_ids is 0 for question tokens, 1 for context tokens,
            # and None for special tokens.
            sequence_ids = tokenized_example.sequence_ids(i)
            context_index = 1

            # Null out offsets that do not belong to the context so that
            # post-processing can skip spans falling outside it.
            tokenized_example['offset_mapping'][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_example['offset_mapping'][i])
            ]
        return tokenized_example

    def postprocess_qa_predictions(self, data, features, raw_predictions, top_n_answers=5, max_answer_length=30):
        all_start_logits, all_end_logits = raw_predictions.start_logits, raw_predictions.end_logits

        results = []
        context = data['context']

        for i in range(len(features['input_ids'])):
            start_logits = all_start_logits[i].cpu().numpy()
            end_logits = all_end_logits[i].cpu().numpy()
            offset_mapping = features['offset_mapping'][i]

            # Indices of the top_n_answers highest logits, best first.
            start_indices = np.argsort(start_logits)[-1: -top_n_answers - 1: -1].tolist()
            end_indices = np.argsort(end_logits)[-1: -top_n_answers - 1: -1].tolist()

            for start_index in start_indices:
                for end_index in end_indices:
                    # Skip candidates that fall outside the context, are
                    # inverted, or exceed the maximum answer length.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                        or end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    score = start_logits[start_index] + end_logits[end_index]
                    results.append(
                        {
                            # Round the summed logits to three significant figures.
                            'score': float(f'{score:.3g}'),
                            'text': context[start_char: end_char]
                        }
                    )

        results = sorted(results, key=lambda x: x['score'], reverse=True)[:top_n_answers]
        return results

    def get_results(self):
        # Run the full pipeline: model forward pass, feature preparation,
        # then span extraction and ranking.
        self.output = self.get_output_from_model()
        self.features = self.prepare_features(self.data)
        self.results = self.postprocess_qa_predictions(self.data, self.features, self.output)
        return self.results
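

# Example usage: a minimal sketch, not part of the original file. It assumes a
# transformers QA checkpoint; the model name below is illustrative, and any
# SQuAD-style extractive QA model with a fast tokenizer should work.
if __name__ == '__main__':
    from transformers import AutoTokenizer, AutoModelForQuestionAnswering

    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_name = 'distilbert-base-cased-distilled-squad'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(model_name).to(torch_device)
    model.eval()

    data = {
        'question': 'Where is the Eiffel Tower located?',
        'context': 'The Eiffel Tower is a wrought-iron lattice tower '
                   'on the Champ de Mars in Paris, France.'
    }
    qa = QuestionAnswer(data, model, tokenizer, torch_device)
    for answer in qa.get_results():
        print(answer['score'], answer['text'])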