musicgen-small / handler.py
pbotsaris's picture
added tests and changed handler to respond with a WAV file
acb1db0
raw
history blame
No virus
2.1 kB
from typing import Dict, List, Any
from scipy.io import wavfile
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
import io
def create_params(params, fr):
    """Build the keyword arguments passed to ``model.generate``.

    Starts from fixed defaults (sampling enabled, ``guidance_scale`` 3,
    ``max_new_tokens`` 256) and overrides them with matching keys from
    *params*. Keys that are not among the defaults are ignored.

    A ``duration`` entry in *params* is converted into a token budget
    (``duration * fr``, rounded to an int). It takes precedence over an
    explicit ``max_new_tokens``.

    Args:
        params: optional dict of user overrides; ``None`` means "use defaults".
        fr: tokens generated per second of audio (the handler passes the
            audio encoder's ``frame_rate``).

    Returns:
        A new dict of generation kwargs. *params* is not modified.
    """
    out = {
        "do_sample": True,
        "guidance_scale": 3,
        "max_new_tokens": 256,
    }
    if params is None:
        return out

    # Only whitelisted keys (those that have a default) may be overridden.
    out.update({k: v for k, v in params.items() if k in out})

    if "duration" in params:
        # generate() requires an integer token count. round() avoids both a
        # float argument and truncation artifacts such as
        # 0.58 * 50 == 28.999999999999996. float() also accepts numeric strings.
        out["max_new_tokens"] = round(float(params["duration"]) * fr)
    return out
class EndpointHandler:
    """Text-to-music inference handler for a MusicGen checkpoint.

    The processor and a float16 MusicGen model are loaded onto CUDA once at
    construction. Each call turns a request payload into WAV-file bytes.
    """

    # Used only when the audio encoder config exposes no ``sampling_rate``.
    DEFAULT_SAMPLING_RATE = 32000

    def __init__(self, path="pbotsaris/musicgen-small"):
        """Load the processor and model from *path* via ``from_pretrained``."""
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        )
        self.model.to("cuda")

    def __call__(self, data: Dict[str, Any]) -> bytes:
        """Generate audio for a text prompt and return it as a WAV file.

        Args:
            data: request payload. The prompt is read from ``data["inputs"]``
                (falling back to the payload itself, as the original handler
                did). Optional generation overrides are read from
                ``data["parameters"]``; see ``create_params``. The payload
                is not modified.

        Returns:
            The bytes of a WAV file containing the first generated sample.
        """
        prompt = data.get("inputs", data)
        params = data.get("parameters")

        inputs = self.processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        ).to("cuda")
        generation_kwargs = create_params(
            params, self.model.config.audio_encoder.frame_rate
        )

        # torch.autocast(device_type="cuda") is the non-deprecated spelling
        # of torch.cuda.amp.autocast().
        with torch.autocast(device_type="cuda"):
            outputs = self.model.generate(**inputs, **generation_kwargs)

        sampling_rate = getattr(
            self.model.config.audio_encoder,
            "sampling_rate",
            self.DEFAULT_SAMPLING_RATE,
        )
        return self._to_wav_bytes(outputs[0], sampling_rate)

    @staticmethod
    def _to_wav_bytes(audio: torch.Tensor, sampling_rate: int) -> bytes:
        """Encode one generated waveform as in-memory WAV bytes.

        Args:
            audio: waveform tensor shaped ``(channels, samples)`` or
                ``(samples,)``. transformers documents MusicGen's
                ``generate`` output as ``(batch, channels, samples)``, so the
                caller passes one batch item.
            sampling_rate: sample rate written into the WAV header.

        Returns:
            The WAV file contents (32-bit float PCM).
        """
        # scipy's wavfile.write needs an ndarray of float32/float64/uint8/
        # int16/int32/int64. Neither a Python list nor float16 is accepted.
        samples = audio.detach().float().cpu().numpy()
        if samples.ndim == 2:
            # scipy expects (samples, channels); the model yields (channels, samples).
            samples = samples.T
            if samples.shape[1] == 1:
                samples = samples[:, 0]  # mono: write a plain 1-D signal
        buffer = io.BytesIO()
        # .copy() gives a C-contiguous array after the transpose/slice.
        wavfile.write(buffer, sampling_rate, samples.copy())
        return buffer.getvalue()
if __name__ == "__main__":
    # Manual smoke check: constructing the handler loads the processor and
    # the float16 model onto CUDA. No request is run; the script then exits.
    handler = EndpointHandler()