import logging
from typing import Any, Dict, List

from audiocraft.data.audio import audio_write
from audiocraft.data.audio_utils import convert_audio
from audiocraft.models import MusicGen


class EndpointHandler:
    """Inference handler that turns text prompts into audio with MusicGen."""

    def __init__(self, path: str = ""):
        """Load the pretrained MusicGen model once per handler instance.

        Args:
            path: Model directory supplied by the serving runtime. It is
                currently unused; the checkpoint name below is fixed.
        """
        # NOTE(review): ``path`` is ignored. Confirm that "musicgen-medium"
        # resolves in the deployed audiocraft version (some versions expect
        # "facebook/musicgen-medium" or the short alias "medium").
        self.model = MusicGen.get_pretrained("musicgen-medium")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for the prompt(s) in ``data``.

        Args:
            data: Request payload. It is modified in place: the ``"inputs"``
                and ``"parameters"`` keys are popped.

                - ``"inputs"``: a prompt string, a list of prompt strings, or
                  a mapping of keyword arguments forwarded to
                  ``MusicGen.generate`` (e.g. ``{"descriptions": [...]}``).
                  If absent, the rest of the payload itself is used as that
                  keyword-argument mapping.
                - ``"parameters"`` (optional): mapping forwarded to
                  ``MusicGen.set_generation_params``.

        Returns:
            A one-element list ``[{"generated_audio": samples}]``. ``samples``
            is the first generated output converted to nested Python lists.
            Only the first output is returned even when several prompts are
            given.
        """
        logging.info("data: %s", data)
        inputs = data.pop("inputs", data)
        # "parameters" is optional. Fall back to an empty mapping so the
        # ``**`` unpacking below does not raise TypeError on None. Calling
        # set_generation_params with no kwargs presumably restores
        # audiocraft's defaults, so one request's settings do not carry over
        # to the next.
        parameters = data.pop("parameters", None) or {}

        self.model.set_generation_params(**parameters)
        outputs = self.model.generate(**self._to_generate_kwargs(inputs))

        # Move the first generated output to host memory and make it
        # JSON-serialisable.
        prediction = outputs[0].cpu().numpy().tolist()
        return [{"generated_audio": prediction}]

    @staticmethod
    def _to_generate_kwargs(inputs: Any) -> Dict[str, Any]:
        """Normalise the ``inputs`` payload into kwargs for ``MusicGen.generate``."""
        if isinstance(inputs, str):
            # A single bare prompt: generate() expects a list of descriptions.
            return {"descriptions": [inputs]}
        if isinstance(inputs, (list, tuple)):
            return {"descriptions": list(inputs)}
        # Already a kwargs mapping (the original, still-supported shape).
        return dict(inputs)