from typing import Any, Dict, List

import array
import base64
import io
import wave

import numpy as np
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration


def _float_to_pcm16(samples: np.ndarray) -> np.ndarray:
    """Convert float audio samples to little-endian 16-bit PCM.

    Samples are scaled by 32767, clipped to the int16 range and truncated
    toward zero. NaNs (possible under float16 generation) become silence
    rather than raising during conversion.
    """
    scaled = np.nan_to_num(samples, nan=0.0) * 32767
    # WAV data is little-endian by specification, so force "<i2" instead of
    # relying on the host's native byte order.
    return np.clip(scaled, -32768, 32767).astype("<i2")


def _encode_wav(pcm: np.ndarray, sampling_rate: int) -> bytes:
    """Wrap mono 16-bit PCM samples in an in-memory WAV container."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 2 bytes per sample -> 16-bit PCM
        wf.setframerate(sampling_rate)
        wf.writeframes(pcm.tobytes())
    # wave only closes file objects it opened itself, so the buffer is still
    # readable after the writer has finalized the header.
    return buffer.getvalue()


class EndpointHandler:
    """Inference handler that turns a text prompt into audio with MusicGen.

    The generated clip is returned as a base64-encoded mono 16-bit PCM WAV.
    """

    def __init__(self, path=""):
        """Load the processor and the MusicGen model.

        Args:
            path: Directory or model id passed to ``from_pretrained``.

        The model is loaded in float16 and moved to the ``"cuda"`` device, so a
        CUDA GPU is required.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        ).to("cuda")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for the text prompt contained in *data*.

        Args:
            data: Request payload. ``data["inputs"]`` is the text prompt; if that
                key is absent, the payload itself is used as the prompt.
                ``data["parameters"]`` (optional) is a dict of keyword arguments
                forwarded to ``model.generate``; they override the defaults
                ``do_sample=False`` and ``max_new_tokens=400``.

        Returns:
            A one-element list ``[{"sampling_rate": ..., "audio": ...}]`` where
            ``audio`` is a base64-encoded mono 16-bit PCM WAV file.
        """
        prompt = data.pop("inputs", data)
        parameters = data.pop("parameters", None) or {}

        model_inputs = self.processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        ).to("cuda")

        # Previously hard-coded settings stay the defaults; the caller's
        # "parameters" take precedence.
        generate_kwargs = {"do_sample": False, "max_new_tokens": 400, **parameters}
        with torch.autocast("cuda"):
            outputs = self.model.generate(**model_inputs, **generate_kwargs)

        # The output rate is a property of the model's audio codec config.
        sampling_rate = self.model.config.audio_encoder.sampling_rate

        # First batch item, first channel. Upcast to float64 so the scaling
        # matches the original Python-float arithmetic.
        samples = outputs[0, 0].cpu().numpy().astype(np.float64)
        wav_bytes = _encode_wav(_float_to_pcm16(samples), sampling_rate)
        audio_base64 = base64.b64encode(wav_bytes).decode("utf-8")

        return [{"sampling_rate": sampling_rate, "audio": audio_base64}]