Adapters
Inference Endpoints
JeremyArancio committed on
Commit
adf79f2
1 Parent(s): 88e1248

Update handler

Browse files
Files changed (1) hide show
  1. handler.py +6 -4
handler.py CHANGED
@@ -24,16 +24,18 @@ class EndpointHandler():
24
  """
25
  LOGGER.info(f"Received data: {data}")
26
  # Get inputs
27
- prompt = data.pop("prompt", data)
28
  parameters = data.pop("parameters", None)
 
 
29
  # Preprocess
30
- input = self.tokenizer(prompt, return_tensors="pt")
31
  # Forward
32
  LOGGER.info(f"Start generation.")
33
  if parameters is not None:
34
- output = self.model.generate(**input, **parameters)
35
  else:
36
- output = self.model.generate(**input)
37
  # Postprocess
38
  prediction = self.tokenizer.decode(output[0])
39
  LOGGER.info(f"Generated text: {prediction}")
 
24
  """
25
  LOGGER.info(f"Received data: {data}")
26
  # Get inputs
27
+ prompt = data.pop("prompt", None)
28
  parameters = data.pop("parameters", None)
29
+ if prompt is None:
30
+ raise ValueError("Missing prompt.")
31
  # Preprocess
32
+ inputs = self.tokenizer(prompt, return_tensors="pt")
33
  # Forward
34
  LOGGER.info(f"Start generation.")
35
  if parameters is not None:
36
+ output = self.model.generate(**inputs, **parameters)
37
  else:
38
+ output = self.model.generate(**inputs)
39
  # Postprocess
40
  prediction = self.tokenizer.decode(output[0])
41
  LOGGER.info(f"Generated text: {prediction}")