# musicgen-small / handler.py
import base64
import io
import wave
from typing import Any, Dict, List

import numpy as np
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

class EndpointHandler:
    def __init__(self, path=""):
        # Load the processor and model from the repository path; keep the model
        # on the GPU in half precision to reduce memory use.
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        ).to("cuda")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (:obj:`dict`):
                The payload, with a text prompt under "inputs" and optional
                generation parameters under "parameters".
        Return:
            A list with one dict holding the sampling rate and the generated
            audio as a base64-encoded 16-bit mono WAV file.
        """
        # Extract the prompt and any generation parameters from the payload.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None) or {}

        # Tokenize the prompt and move the tensors to the GPU.
        inputs = self.processor(
            text=[inputs],
            padding=True,
            return_tensors="pt",
        ).to("cuda")

        # Generate audio; user-supplied parameters override the defaults.
        generation_kwargs = {"do_sample": True, "guidance_scale": 3, "max_new_tokens": 400}
        generation_kwargs.update(parameters)
        with torch.autocast("cuda"):
            audio_values = self.model.generate(**inputs, **generation_kwargs)

        # Postprocess: take the first batch item and channel, then scale the
        # float waveform to 16-bit PCM, clipping to stay within range.
        sampling_rate = self.model.config.audio_encoder.sampling_rate
        waveform = audio_values[0, 0].cpu().float().numpy()
        audio_samples = np.clip(waveform * 32000, -32000, 32000).astype(np.int16)

        # Write a WAV file into an in-memory buffer.
        audio_io = io.BytesIO()
        with wave.open(audio_io, "wb") as wf:
            wf.setnchannels(1)      # mono
            wf.setsampwidth(2)      # 2 bytes per sample for 16-bit PCM
            wf.setframerate(sampling_rate)
            wf.writeframes(audio_samples.tobytes())

        audio_base64 = base64.b64encode(audio_io.getvalue()).decode("utf-8")
        return [{"sampling_rate": sampling_rate, "audio": audio_base64}]
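

# A minimal local smoke test, not part of the endpoint contract. It assumes a
# CUDA GPU is available and that a model is reachable at the path below (here
# the public "facebook/musicgen-small" hub id, an assumption; on Inference
# Endpoints the platform passes the local snapshot path instead). It decodes
# the base64 WAV in the response and writes it to disk.
if __name__ == "__main__":
    handler = EndpointHandler(path="facebook/musicgen-small")
    result = handler(
        {
            "inputs": "lo-fi hip hop beat with mellow piano",
            "parameters": {"max_new_tokens": 256},
        }
    )
    with open("sample.wav", "wb") as f:
        f.write(base64.b64decode(result[0]["audio"]))
    print(f"wrote sample.wav at {result[0]['sampling_rate']} Hz")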