Adapters
Inference Endpoints
JeremyArancio committed
Commit 2487e56
1 parent: 5a04ada

Update handler

Files changed (1): handler.py (+4, -4)
handler.py CHANGED
@@ -23,15 +23,15 @@ class EndpointHandler():
         """
         LOGGER.info(f"Received data: {data}")
         # Get inputs
-        prompt = data.pop("prompt", data)
+        inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", None)
         # Preprocess
-        inputs = self.tokenizer(prompt, return_tensors="pt")
+        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
         # Forward
         if parameters is not None:
-            outputs = self.model.generate(**inputs, **parameters)
+            outputs = self.model.generate(input_ids, **parameters)
         else:
-            outputs = self.model.generate(input_ids)
+            outputs = self.model.generate(input_ids)
         # Postprocess
         prediction = self.tokenizer.decode(outputs[0])
         LOGGER.info(f"Generated text: {prediction}")