jed-tiotuico commited on
Commit
a49ac49
1 Parent(s): 2c7c440

fixed arg error

Browse files
Files changed (1) hide show
  1. handler.py +5 -3
handler.py CHANGED
@@ -165,7 +165,7 @@ class EndpointHandler:
165
  # create inference pipeline
166
 
167
 
168
- def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
169
  """
170
  Args:
171
  data (:obj:):
@@ -175,7 +175,9 @@ class EndpointHandler:
175
  - "label": A string representing what the label/class is. There can be multiple labels.
176
  - "score": A score between 0 and 1 describing how confident the model is for this label/class.
177
  """
178
-
 
 
179
  config = {
180
  "vocab_size": vocab_size,
181
  "embedding_dim": embedding_dim,
@@ -201,7 +203,7 @@ class EndpointHandler:
201
  for i in range(10):
202
  error_rate, generated_text = generate_text_bpe(
203
  model,
204
- start_string=data,
205
  generation_length=seq_length,
206
  temperature=1.2,
207
  top_k=20,
 
165
  # create inference pipeline
166
 
167
 
168
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
169
  """
170
  Args:
171
  data (:obj:):
 
175
  - "label": A string representing what the label/class is. There can be multiple labels.
176
  - "score": A score between 0 and 1 describing how confident the model is for this label/class.
177
  """
178
+ print("data", data)
179
+ inputs = data.pop("inputs", data)
180
+ start_string = inputs[0]
181
  config = {
182
  "vocab_size": vocab_size,
183
  "embedding_dim": embedding_dim,
 
203
  for i in range(10):
204
  error_rate, generated_text = generate_text_bpe(
205
  model,
206
+ start_string=start_string,
207
  generation_length=seq_length,
208
  temperature=1.2,
209
  top_k=20,