File size: 5,153 Bytes
c0a3632 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from typing import Dict, List, Any
from transformers import BertForQuestionAnswering, BertTokenizer
import torch
# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# def print_tokens_with_ids(tokenizer, input_ids):
# # BERT only needs the token IDs, but for the purpose of inspecting the
# # tokenizer's behavior, let's also get the token strings and display them.
# tokens = tokenizer.convert_ids_to_tokens(input_ids)
# # For each token and its id...
# for token, id in zip(tokens, input_ids):
# # If this is the [SEP] token, add some space around it to make it stand out.
# if id == tokenizer.sep_token_id:
# print('')
# # Print the token string and its ID in two columns.
# print('{:<12} {:>6,}'.format(token, id))
# if id == tokenizer.sep_token_id:
# print('')
def get_segment_ids_aka_token_type_ids(tokenizer, input_ids):
    """Build BERT ``token_type_ids`` for a question/context pair.

    Every token up to and including the first ``[SEP]`` belongs to segment A
    (the question) and is marked 0; everything after it is segment B (the
    context) and is marked 1.

    Args:
        tokenizer: tokenizer exposing ``sep_token_id``.
        input_ids: encoded ``[CLS] question [SEP] context [SEP]`` id list.

    Returns:
        A list of 0/1 ints, one per entry in ``input_ids``.

    Raises:
        ValueError: if ``input_ids`` contains no ``[SEP]`` token.
    """
    # Position of the first [SEP]; segment A includes it.
    first_sep = input_ids.index(tokenizer.sep_token_id)
    segment_ids = [0 if position <= first_sep else 1
                   for position in range(len(input_ids))]
    # Internal invariant: one segment id per input token.
    assert len(segment_ids) == len(input_ids), \
        'There should be a segment_id for every input token.'
    return segment_ids
def to_model(
    model: "BertForQuestionAnswering",
    input_ids,
    segment_ids
) -> tuple:
    """Run one QA forward pass and return the start/end logits.

    Args:
        model: a question-answering model called as
            ``model(input_ids, token_type_ids=...)`` whose output exposes
            ``start_logits`` and ``end_logits``.
        input_ids: token ids for a single example (list of ints).
        segment_ids: token_type_ids for the same example (list of 0/1 ints).

    Returns:
        Tuple ``(start_logits, end_logits)``; each a tensor with a leading
        batch dimension of 1.
    """
    # Wrap in a list to add the batch dimension the model expects.
    token_tensor = torch.tensor([input_ids])
    segment_tensor = torch.tensor([segment_ids])
    # Bug fix: the handler moves the model to `device` (possibly CUDA) at
    # load time, but these tensors were created on CPU, which raises a
    # device-mismatch error under GPU. Move the inputs to the model's own
    # device when it exposes one (transformers models have a `.device`
    # property); plain callables without it are left untouched.
    model_device = getattr(model, "device", None)
    if model_device is not None:
        token_tensor = token_tensor.to(model_device)
        segment_tensor = segment_tensor.to(model_device)
    output = model(token_tensor, token_type_ids=segment_tensor)
    return output.start_logits, output.end_logits
def get_answer(
    start_scores,
    end_scores,
    input_ids,
    tokenizer: "BertTokenizer"
) -> str:
    """Decode the highest-scoring answer span from QA start/end logits.

    Fixes the issue the original code documented but did not handle: taking
    independent argmaxes can select an end token *before* the start token.
    The end position is now the best-scoring index at or after the start,
    so the span is always well-formed. When the unconstrained end argmax
    already falls at/after the start, the result is unchanged.

    Args:
        start_scores: start logits, shape (1, seq_len) (anything flattenable).
        end_scores: end logits, same shape as ``start_scores``.
        input_ids: the token ids that produced the logits.
        tokenizer: tokenizer providing ``convert_ids_to_tokens``.

    Returns:
        The answer text, with WordPiece subword pieces ("##...") merged back
        into whole words.
    """
    start_logits = start_scores.flatten()
    end_logits = end_scores.flatten()
    # Best start position.
    answer_start = int(torch.argmax(start_logits))
    # Best end position, constrained to end >= start.
    answer_end = answer_start + int(torch.argmax(end_logits[answer_start:]))
    # Get the string versions of the input tokens.
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    # Start with the first token, then append the rest; a '##' prefix marks
    # a WordPiece subword that is rejoined to the previous token without a
    # space.
    answer = tokens[answer_start]
    for i in range(answer_start + 1, answer_end + 1):
        if tokens[i].startswith('##'):
            answer += tokens[i][2:]
        else:
            answer += ' ' + tokens[i]
    return answer
# def resonstruct_words(tokens, answer_start, answer_end):
# '''reconstruct any words that got broken down into subwords.
# '''
# # Start with the first token.
# answer = tokens[answer_start]
# # Select the remaining answer tokens and join them with whitespace.
# for i in range(answer_start + 1, answer_end + 1):
# # If it's a subword token, then recombine it with the previous token.
# if tokens[i][0:2] == '##':
# answer += tokens[i][2:]
# # Otherwise, add a space then the token.
# else:
# answer += ' ' + tokens[i]
# print('Answer: "' + answer + '"')
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for extractive question
    answering with a fine-tuned BERT model.

    Loads the model/tokenizer from ``path`` at construction time and answers
    ``{"question": ..., "context": ...}`` payloads on call.
    """

    def __init__(self, path=""):
        # Load the QA model onto the module-level `device` (CUDA when
        # available, else CPU); the tokenizer is device-independent.
        self.model = BertForQuestionAnswering.from_pretrained(path).to(device)
        self.tokenizer = BertTokenizer.from_pretrained(path)

    def __call__(
        self,
        data: Dict[str, str | bytes]
    ) -> str:
        """Run extractive QA on one request payload.

        Args:
            data: request payload, expected to hold "question" and "context"
                string entries.

        Returns:
            The answer span extracted from the context, as a string.
        """
        # NOTE(review): when a key is missing, `pop` falls back to the whole
        # dict, which is then fed to the tokenizer as text — presumably this
        # should raise a clear error instead; confirm intended behavior.
        question = data.pop("question", data)
        context = data.pop("context", data)
        # Encode both texts as one "[CLS] question [SEP] context [SEP]" sequence.
        input_ids = self.tokenizer.encode(question, context)
        # token_type_ids: 0s for the question segment, 1s for the context.
        segment_ids = get_segment_ids_aka_token_type_ids(
            self.tokenizer,
            input_ids
        )
        # run prediction
        with torch.inference_mode():
            start_scores, end_scores = to_model(
                self.model,
                input_ids,
                segment_ids
            )
        # Decode the highest-scoring span back into text.
        answer = get_answer(
            start_scores,
            end_scores,
            input_ids,
            self.tokenizer
        )
        return answer
|