from typing import Any

import torch
from transformers import BertForQuestionAnswering, BertTokenizer

def get_segment_ids_aka_token_type_ids(tokenizer, input_ids):
    # Search the input_ids for the first instance of the `[SEP]` token.
    sep_index = input_ids.index(tokenizer.sep_token_id)
    # The number of segment A tokens includes the [SEP] token itself.
    num_seg_a = sep_index + 1
    # The remainder are segment B.
    num_seg_b = len(input_ids) - num_seg_a
    # Construct the list of 0s and 1s.
    segment_ids = [0]*num_seg_a + [1]*num_seg_b
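    # Example (hypothetical input): for the 8-token sequence
    #   [CLS] who won [SEP] bob won . [SEP]
    # the first [SEP] sits at index 3, so num_seg_a = 4, num_seg_b = 4,
    # and segment_ids == [0, 0, 0, 0, 1, 1, 1, 1].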
    # There should be a segment_id for every input token.
    assert len(segment_ids) == len(input_ids), \
        'There should be a segment_id for every input token.'
    return segment_ids

def to_model(
    model: BertForQuestionAnswering,
    input_ids,
    segment_ids
) -> tuple:
    # Run the input through the model. The segment IDs (token_type_ids)
    # differentiate the question from the answer_text.
    output = model(
        torch.tensor([input_ids]),  # The tokens representing our input text.
        token_type_ids=torch.tensor([segment_ids])
    )
    # start_logits and end_logits each have shape [1, sequence_length].
    return output.start_logits, output.end_logits

def get_answer(
    start_scores,
    end_scores,
    input_ids,
    tokenizer: BertTokenizer
) -> str:
    '''Side note:
    - It's a little naive to pick the highest scores for start and end.
      What if the model predicts an end token that comes before the start token?
    - The correct implementation is to pick the highest total score for
      which end >= start.
    '''
    # Find the tokens with the highest `start` and `end` scores.
    answer_start = torch.argmax(start_scores)
    answer_end = torch.argmax(end_scores)

    # Get the string versions of the input tokens.
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    # Start with the first token.
    answer = tokens[answer_start]
    # Select the remaining answer tokens and join them with whitespace,
    # recombining any words that were split into subword tokens.
    for i in range(answer_start + 1, answer_end + 1):
        # If it's a subword token, recombine it with the previous token.
        if tokens[i][0:2] == '##':
            answer += tokens[i][2:]
        # Otherwise, add a space and then the token.
        else:
            answer += ' ' + tokens[i]
    return answer
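
# A minimal sketch (not called by the handler) of the "correct" selection
# described in the docstring above: score every (start, end) pair with
# end >= start and keep the highest-scoring one. The max_answer_len cap is
# an added assumption, not part of the original code.
def get_best_span(start_scores, end_scores, max_answer_len: int = 30) -> tuple:
    start = start_scores.squeeze(0)
    end = end_scores.squeeze(0)
    best_score = float('-inf')
    best_span = (0, 0)
    for s in range(len(start)):
        for e in range(s, min(s + max_answer_len, len(end))):
            score = (start[s] + end[e]).item()
            if score > best_score:
                best_score = score
                best_span = (s, e)
    return best_span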


class EndpointHandler:
    def __init__(self, path=""):
        self.model = BertForQuestionAnswering.from_pretrained(path)
        self.tokenizer = BertTokenizer.from_pretrained(path)

    def __call__(self, data: dict[str, Any]) -> str:
        """
        Args:
            data (:obj:):
                includes the context and question
        """
        if 'inputs' not in data:
            raise ValueError('no inputs key in data')

        inputs = data.pop("inputs", data)
        question = inputs.pop("question", False)
        context = inputs.pop("context", False)

        # Either field missing makes the request unusable, so reject both cases.
        if question is False or context is False:
            raise ValueError(
                f'No question and/or context: question: {question} - context: {context}')

        # encode() produces: [CLS] <question tokens> [SEP] <context tokens> [SEP]
        input_ids = self.tokenizer.encode(question, context)

        segment_ids = get_segment_ids_aka_token_type_ids(
            self.tokenizer,
            input_ids
        )
        # Run the prediction without tracking gradients.
        with torch.inference_mode():
            start_scores, end_scores = to_model(
                self.model,
                input_ids,
                segment_ids
            )
            answer = get_answer(
                start_scores,
                end_scores,
                input_ids,
                self.tokenizer
            )
        return answer
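
# Minimal local smoke test (hypothetical inputs), mirroring the payload shape
# an Inference Endpoint would send to the handler. The model name below is an
# assumption; any BERT checkpoint fine-tuned for question answering works.
if __name__ == "__main__":
    handler = EndpointHandler("bert-large-uncased-whole-word-masking-finetuned-squad")
    payload = {
        "inputs": {
            "question": "What does the model predict for each token?",
            "context": (
                "For question answering, BERT predicts a start logit and an "
                "end logit for every token in the input sequence."
            ),
        }
    }
    print(handler(payload))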