File size: 1,990 Bytes
18694be acb1db0 18694be acb1db0 ee4f8f7 ceb1cc0 01d132b 18694be fb92df5 18694be 224fe62 3ce0400 18694be 3ce0400 18694be acb1db0 18694be da50f54 18694be da50f54 e3598b0 da50f54 3ce0400 18694be 3ce0400 ceb1cc0 e3598b0 18694be da50f54 3ce0400 ceb1cc0 da50f54 18694be da50f54 e3598b0 fb92df5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
from typing import Dict, List, Any
from scipy.io import wavfile
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
import io
import base64
def create_params(params, fr):
    """Build the keyword arguments passed to the model's ``generate`` call.

    Starts from a fixed set of defaults (``do_sample``, ``guidance_scale``,
    ``max_new_tokens``) and overrides them with matching keys from ``params``.
    Keys in ``params`` that are not one of the defaults are ignored.

    Args:
        params: Optional mapping of user-supplied generation parameters. May
            contain ``duration`` (presumably in seconds -- confirm with the
            caller), which is converted into ``max_new_tokens``.
        fr: Number of tokens generated per unit of ``duration``; the caller
            passes the audio encoder's frame rate.

    Returns:
        A new dict with the keys ``do_sample``, ``guidance_scale`` and
        ``max_new_tokens``.
    """
    out = {
        "do_sample": True,
        "guidance_scale": 3,
        "max_new_tokens": 256,
    }
    if params is None:
        return out

    duration_given = "duration" in params
    if duration_given:
        # A fractional duration (e.g. 2.5) would otherwise produce a float
        # token budget; a token count has to be an integer.
        out["max_new_tokens"] = round(params["duration"] * fr)

    for key, value in params.items():
        if key not in out:
            continue
        # An explicit duration takes precedence over a raw max_new_tokens.
        if duration_given and key == "max_new_tokens":
            continue
        out[key] = value
    return out
class EndpointHandler:
    """Callable handler that turns a text prompt into base64-encoded WAV audio.

    The processor and the MusicGen model are loaded once at construction time
    and reused for every call.
    """

    def __init__(self, path="pbotsaris/musicgen-small"):
        """Load the processor and model from ``path`` and place the model on a device.

        Args:
            path: Model identifier or local directory handed to
                ``from_pretrained``.
        """
        # Prefer the first CUDA device, but fall back to the CPU so the
        # handler can still be constructed on hosts without a GPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(path)
        self.model.to(self.device)  # type: ignore

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for the prompt contained in ``data``.

        Args:
            data: The payload with the text prompt under ``"inputs"`` and
                optional generation parameters under ``"parameters"``. Both
                keys are popped from the dict. If ``"inputs"`` is missing,
                the payload itself is used as the prompt.

        Returns:
            A single-element list holding a dict with ``"audio"`` (the WAV
            file content, base64-encoded as a UTF-8 string) and ``"sr"``
            (the sampling rate used to write the WAV).
        """
        prompt = data.pop("inputs", data)
        user_params = data.pop("parameters", None)

        model_inputs = self.processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        )

        audio_config = self.model.config.audio_encoder  # type: ignore
        generate_kwargs = create_params(user_params, audio_config.frame_rate)

        # Inputs must live on the same device as the model.
        outputs = self.model.generate(  # type: ignore
            **model_inputs.to(self.device), **generate_kwargs
        )

        # Keep only element [0, 0] of the output; presumably the first batch
        # item and first audio channel -- confirm against the model's docs.
        pred = outputs[0, 0].cpu().numpy()
        sr = audio_config.sampling_rate

        # Write the WAV into memory rather than to disk, then base64-encode it
        # so it can be returned as a string.
        wav_buffer = io.BytesIO()
        wavfile.write(wav_buffer, rate=sr, data=pred)
        base64_encoded_wav = base64.b64encode(wav_buffer.getvalue()).decode("utf-8")

        return [{"audio": base64_encoded_wav, "sr": sr}]
if __name__ == "__main__":
    # When run as a script, only construct the handler with its default model
    # path; no request is issued.
    handler = EndpointHandler()
|