Alexander Slessor committed
Commit b12986e
Parent: c0a3632

fixed type annotation error
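
The annotation that was removed, Dict[str, str | bytes], uses PEP 604 union syntax, which the interpreter only accepts at runtime from Python 3.10 onward. A minimal sketch of the likely failure and two pre-3.10-safe spellings, assuming the endpoint runs an older interpreter (the Python version is not stated in this commit):

    from typing import Any, Dict, Union

    # On Python <= 3.9 the union inside the subscript is evaluated when the
    # function is defined, so the module fails at import time:
    #
    #     def __call__(self, data: Dict[str, str | bytes]): ...
    #     TypeError: unsupported operand type(s) for |: 'type' and 'type'

    def call_with_union(data: Dict[str, Union[str, bytes]]) -> None:
        """Equivalent, explicit union that older interpreters accept."""

    def call_with_any(data: Dict[str, Any]) -> None:
        """The looser annotation this commit adopts."""

Adding from __future__ import annotations would also defer evaluation of the annotation and sidestep the error, at the cost of applying module-wide.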

Files changed (1)
handler.py +27 -23
handler.py CHANGED
@@ -1,6 +1,7 @@
-from typing import Dict, List, Any
+from typing import Dict, Any
 from transformers import BertForQuestionAnswering, BertTokenizer
 import torch
+# from scipy.special import softmax
 
 # set device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -108,34 +109,37 @@ class EndpointHandler:
 
     def __call__(
         self,
-        data: Dict[str, str | bytes]
+        data: Dict[str, Any]
     ):
         """
         Args:
             data (:obj:):
                 includes the deserialized image file as PIL.Image
         """
-        question = data.pop("question", data)
-        context = data.pop("context", data)
+        try:
+            question = data.pop("question", data)
+            context = data.pop("context", data)
 
-        input_ids = self.tokenizer.encode(question, context)
-        # print('The input has a total of {:} tokens.'.format(len(input_ids)))
+            input_ids = self.tokenizer.encode(question, context)
+            # print('The input has a total of {:} tokens.'.format(len(input_ids)))
 
-        segment_ids = get_segment_ids_aka_token_type_ids(
-            self.tokenizer,
-            input_ids
-        )
-        # run prediction
-        with torch.inference_mode():
-            start_scores, end_scores = to_model(
-                self.model,
-                input_ids,
-                segment_ids
+            segment_ids = get_segment_ids_aka_token_type_ids(
+                self.tokenizer,
+                input_ids
             )
-        answer = get_answer(
-            start_scores,
-            end_scores,
-            input_ids,
-            self.tokenizer
-        )
-        return answer
+            # run prediction
+            with torch.inference_mode():
+                start_scores, end_scores = to_model(
+                    self.model,
+                    input_ids,
+                    segment_ids
+                )
+            answer = get_answer(
+                start_scores,
+                end_scores,
+                input_ids,
+                self.tokenizer
+            )
+            return answer
+        except Exception as e:
+            raise
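
For orientation, a hypothetical local invocation of the handler this diff modifies. The EndpointHandler(path=...) constructor follows the Hugging Face Inference Endpoints custom-handler convention and is assumed here, not shown in the hunk:

    from handler import EndpointHandler

    # Instantiate against a local model directory; the path argument is
    # assumed from the custom-handler convention.
    handler = EndpointHandler(path=".")

    # __call__ pops "question" and "context" from the payload, encodes the
    # pair with the BERT tokenizer, and returns the extracted answer span.
    payload = {
        "question": "Which device does the handler use?",
        "context": "The handler picks CUDA when available and falls back to CPU.",
    }
    print(handler(payload))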