from typing import Any, Dict

import torch
from transformers import BertForQuestionAnswering, BertTokenizer


def get_segment_ids_aka_token_type_ids(tokenizer, input_ids):
    """Build the segment IDs (a.k.a. token type IDs): 0 for the question
    segment, 1 for the context segment."""
    # Search the input_ids for the first instance of the `[SEP]` token.
    sep_index = input_ids.index(tokenizer.sep_token_id)

    # The number of segment A tokens includes the [SEP] token itself.
    num_seg_a = sep_index + 1

    # The remainder are segment B.
    num_seg_b = len(input_ids) - num_seg_a

    # Construct the list of 0s and 1s.
    segment_ids = [0] * num_seg_a + [1] * num_seg_b

    # There should be a segment_id for every input token.
    assert len(segment_ids) == len(input_ids), \
        'There should be a segment_id for every input token.'

    return segment_ids


def to_model(
    model: BertForQuestionAnswering,
    input_ids,
    segment_ids
) -> tuple:
    # Run the input through the model. The segment IDs differentiate the
    # question from the answer text.
    output = model(
        torch.tensor([input_ids]),  # The tokens representing our input text.
        token_type_ids=torch.tensor([segment_ids])
    )
    return output.start_logits, output.end_logits


def get_answer(
    start_scores,
    end_scores,
    input_ids,
    tokenizer: BertTokenizer
) -> str:
    '''Side note:
    - It is a little naive to pick the highest scores for start and end
      independently: the model may predict an end token that comes before
      the start token.
    - A more correct implementation picks the span with the highest combined
      score for which end >= start (see the sketch below).
    '''
    # Find the tokens with the highest `start` and `end` scores.
    answer_start = torch.argmax(start_scores).item()
    answer_end = torch.argmax(end_scores).item()

    # Get the string versions of the input tokens.
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Start with the first token.
    answer = tokens[answer_start]

    # Select the remaining answer tokens and join them with whitespace.
    for i in range(answer_start + 1, answer_end + 1):
        # If it's a subword token, recombine it with the previous token.
        if tokens[i][0:2] == '##':
            answer += tokens[i][2:]
        # Otherwise, add a space and then the token.
        else:
            answer += ' ' + tokens[i]

    return answer
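
# The docstring above notes that independent argmaxes can yield end < start.
# The helper below is a minimal sketch of the "highest combined score with
# end >= start" strategy it describes; it is a hypothetical addition, not part
# of the original handler, and the name and max_answer_len default are assumptions.
def get_answer_constrained(
    start_scores,
    end_scores,
    input_ids,
    tokenizer: BertTokenizer,
    max_answer_len: int = 30
) -> str:
    # Drop the batch dimension: logits come back with shape (1, seq_len).
    start_logits = start_scores.squeeze(0)
    end_logits = end_scores.squeeze(0)

    best_score = float('-inf')
    answer_start, answer_end = 0, 0
    # Score every span whose end is at or after its start, capped in length.
    for start in range(len(start_logits)):
        last = min(start + max_answer_len, len(end_logits))
        for end in range(start, last):
            score = start_logits[start].item() + end_logits[end].item()
            if score > best_score:
                best_score = score
                answer_start, answer_end = start, end

    # Recombine subword tokens exactly as get_answer() does.
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    answer = tokens[answer_start]
    for i in range(answer_start + 1, answer_end + 1):
        if tokens[i][0:2] == '##':
            answer += tokens[i][2:]
        else:
            answer += ' ' + tokens[i]
    return answer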


class EndpointHandler:
    def __init__(self, path=""):
        self.model = BertForQuestionAnswering.from_pretrained(path)
        self.tokenizer = BertTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        Args:
            data (:obj:`dict`): the request payload; the "inputs" key holds
                the question and the context.
        """
        if 'inputs' not in data:
            raise ValueError('no "inputs" key in data')

        inputs = data["inputs"]
        question = inputs.get("question")
        context = inputs.get("context")
        if question is None or context is None:
            raise ValueError(
                f'Missing question and/or context: '
                f'question: {question} - context: {context}'
            )

        input_ids = self.tokenizer.encode(question, context)

        segment_ids = get_segment_ids_aka_token_type_ids(
            self.tokenizer, input_ids
        )

        # Run the prediction without tracking gradients.
        with torch.inference_mode():
            start_scores, end_scores = to_model(
                self.model, input_ids, segment_ids
            )

        answer = get_answer(
            start_scores, end_scores, input_ids, self.tokenizer
        )

        return answer
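
# Minimal local smoke test: a sketch only. The checkpoint name below
# ("bert-large-uncased-whole-word-masking-finetuned-squad") is an assumed
# example of a SQuAD-finetuned BERT model, and the payload shape simply
# mirrors what __call__ expects; neither is part of the original file.
if __name__ == "__main__":
    handler = EndpointHandler(
        path="bert-large-uncased-whole-word-masking-finetuned-squad"
    )
    payload = {
        "inputs": {
            "question": "What does BERT stand for?",
            "context": (
                "BERT stands for Bidirectional Encoder Representations "
                "from Transformers."
            ),
        }
    }
    print(handler(payload))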