musicgen-small / handler.py
pbotsaris's picture
removed float16 types from tensor
224fe62
raw
history blame contribute delete
No virus
1.99 kB
from typing import Dict, List, Any
from scipy.io import wavfile
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
import io
import base64
def create_params(params, fr):
    """Build the keyword arguments forwarded to ``model.generate``.

    Starts from a fixed set of defaults and overlays the caller's values.
    Only keys that already exist in the defaults are honoured; any other
    entry in ``params`` is ignored.

    Args:
        params: Mapping of request parameters, or ``None`` to use the
            defaults unchanged. A ``'duration'`` entry is multiplied by
            ``fr`` to obtain ``max_new_tokens`` and takes precedence over
            an explicit ``'max_new_tokens'`` entry.
        fr: Multiplier that converts ``duration`` into a token count (the
            caller passes the audio encoder's frame rate).

    Returns:
        A new dict with the keys ``do_sample``, ``guidance_scale`` and
        ``max_new_tokens``.

    Raises:
        ValueError: If ``duration`` is a string that is not numeric.
        TypeError: If ``duration`` cannot be converted to a number.
    """
    out = {
        "do_sample": True,
        "guidance_scale": 3,
        "max_new_tokens": 256,
    }

    if params is None:
        return out

    has_duration = 'duration' in params
    if has_duration:
        # A token count must be an integer. JSON payloads can carry a
        # fractional duration (2.5) or a numeric string ("5"); without the
        # coercion the product would be a float or a repeated string.
        # round() avoids losing a token to float error (0.29 * 100 -> 28.99…).
        out['max_new_tokens'] = int(round(float(params['duration']) * fr))

    for key, value in params.items():
        if key not in out:
            continue
        # duration wins over an explicit max_new_tokens in the same request
        if has_duration and key == 'max_new_tokens':
            continue
        out[key] = value

    return out
class EndpointHandler:
    """Endpoint handler that turns a text prompt into base64-encoded WAV audio."""

    def __init__(self, path="pbotsaris/musicgen-small"):
        """Load the processor and model from ``path`` and place the model on a device.

        Args:
            path: Model identifier or directory handed to ``from_pretrained``.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(path)
        # Resolve the device once so the model and the inputs always agree,
        # and fall back to CPU instead of crashing when CUDA is unavailable.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)  # type: ignore

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for one request.

        Args:
            data (:dict:):
                The payload with the text prompt and generation parameters.
                ``"inputs"`` holds the prompt (the whole payload is used when
                the key is absent) and ``"parameters"`` optionally holds the
                generation settings. Both keys are popped from ``data``.

        Returns:
            A one-element list holding a dict with ``"audio"`` (the WAV file
            bytes, base64-encoded as a UTF-8 string) and ``"sr"`` (the
            sampling rate read from the model's audio encoder config).
        """
        prompt = data.pop("inputs", data)
        params = data.pop("parameters", None)

        encoded = self.processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        )

        audio_cfg = self.model.config.audio_encoder  # type: ignore
        gen_kwargs = create_params(params, audio_cfg.frame_rate)

        outputs = self.model.generate(**encoded.to(self.device), **gen_kwargs)  # type: ignore

        # Only one prompt is submitted, so take index 0 on the first two axes.
        # NOTE(review): presumably (batch, channel, samples) — confirm.
        pred = outputs[0, 0].cpu().numpy()
        sr = audio_cfg.sampling_rate

        # Write the WAV into memory so it can be returned in the response.
        wav_buffer = io.BytesIO()
        wavfile.write(wav_buffer, rate=sr, data=pred)
        base64_encoded_wav = base64.b64encode(wav_buffer.getvalue()).decode('utf-8')

        return [{"audio": base64_encoded_wav, "sr": sr}]
if __name__ == "__main__":
    # Running the module directly constructs the handler with its default
    # path, which loads the processor and the model.
    handler = EndpointHandler()