pbotsaris commited on
Commit
da50f54
1 Parent(s): edf8016

fixed to device for inputs

Browse files
Files changed (1) hide show
  1. handler.py +16 -25
handler.py CHANGED
@@ -81,39 +81,30 @@ class EndpointHandler:
81
  data (:dict:):
82
  The payload with the text prompt and generation parameters.
83
 
84
- Returns: wav file in bytes
85
  """
86
 
87
- # inputs = data.pop("inputs", data)
88
- # params = data.pop("parameters", None)
89
 
90
- # inputs = self.processor(
91
- # text=[inputs],
92
- # padding=True,
93
- # return_tensors="pt"
94
- # ).to('cuda')
95
 
96
- # params = create_params(params, self.model.config.audio_encoder.frame_rate)
97
 
98
- # with torch.cuda.amp.autocast():
99
- # outputs = self.model.generate(**inputs, **params)
100
 
101
- # pred = outputs[0, 0].cpu().numpy()
102
- # sr = 32000
103
 
104
- # try:
105
- # sr = self.model.config.audio_encoder.sampling_rate
 
106
 
107
- # except:
108
- # sr = 32000
109
-
110
- # wav_buffer = io.BytesIO()
111
- # wavfile.write(wav_buffer, rate=sr, data=pred)
112
- # wav_data = wav_buffer.getvalue()
113
-
114
- # base64_encoded_wav = base64.b64encode(wav_data).decode('utf-8')
115
-
116
- base64_encoded_wav = sine_to_base64()
117
  return [{"audio": base64_encoded_wav}]
118
 
119
 
 
81
  data (:dict:):
82
  The payload with the text prompt and generation parameters.
83
 
 
84
  """
85
 
86
+ inputs = data.pop("inputs", data)
87
+ params = data.pop("parameters", None)
88
 
89
+ inputs = self.processor(
90
+ text=[inputs],
91
+ padding=True,
92
+ return_tensors="pt"
93
+ ).to('cuda')
94
 
95
+ params = create_params(params, self.model.config.audio_encoder.frame_rate)
96
 
97
+ with torch.cuda.amp.autocast():
98
+ outputs = self.model.generate(**inputs.to('cuda'), do_sample=True, guidance_scale=3, max_new_tokens=256)
99
 
100
+ pred = outputs[0, 0].cpu().numpy()
101
+ sr = self.model.config.audio_encoder.sampling_rate
102
 
103
+ wav_buffer = io.BytesIO()
104
+ wavfile.write(wav_buffer, rate=sr, data=pred)
105
+ wav_data = wav_buffer.getvalue()
106
 
107
+ base64_encoded_wav = base64.b64encode(wav_data).decode('utf-8')
 
 
 
 
 
 
 
 
 
108
  return [{"audio": base64_encoded_wav}]
109
 
110