Vaibhav Srivastav commited on
Commit
d319451
1 Parent(s): 6388023
Files changed (1) hide show
  1. handler.py +4 -4
handler.py CHANGED
@@ -22,15 +22,15 @@ class EndpointHandler:
22
  inputs = self.processor(
23
  text=[inputs],
24
  padding=True,
25
- return_tensors="pt",)
26
 
27
  # pass inputs with all kwargs in data
28
  if parameters is not None:
29
- outputs = self.model.generate(inputs, max_new_tokens=256, **parameters)
30
  else:
31
- outputs = self.model.generate(inputs, max_new_tokens=256)
32
 
33
  # postprocess the prediction
34
- prediction = outputs[0].numpy()
35
 
36
  return [{"generated_audio": prediction}]
 
22
  inputs = self.processor(
23
  text=[inputs],
24
  padding=True,
25
+ return_tensors="pt",).to("cuda")
26
 
27
  # pass inputs with all kwargs in data
28
  if parameters is not None:
29
+ outputs = self.model.generate(**inputs, max_new_tokens=256, **parameters)
30
  else:
31
+ outputs = self.model.generate(**inputs, max_new_tokens=256)
32
 
33
  # postprocess the prediction
34
+ prediction = outputs[0].cpu().numpy()
35
 
36
  return [{"generated_audio": prediction}]