from typing import Any, Dict, List

from audiocraft.models import AudioGen

# Optional: uncomment together with the ``audio_write`` lines noted in
# ``EndpointHandler.__call__`` to also save each generated clip to disk.
# from audiocraft.data.audio import audio_write

# Pretrained checkpoint loaded when no model path is supplied.
DEFAULT_MODEL = "facebook/audiogen-medium"
# Clip length in seconds used when a request does not specify ``duration``.
DEFAULT_DURATION = 5


class EndpointHandler:
    """Text-to-audio inference handler backed by an AudioCraft ``AudioGen`` model.

    Each call takes a payload of the form
    ``{"inputs": <prompt or list of prompts>, "parameters": {"duration": ...}}``
    and returns the generated waveforms converted to nested Python lists so
    they can be JSON-serialized.
    """

    def __init__(self, path: str = "", default_duration: float = DEFAULT_DURATION):
        """Load the pretrained model.

        Args:
            path: Model directory or name passed to ``AudioGen.get_pretrained``.
                An empty string falls back to ``DEFAULT_MODEL``.
            default_duration: Clip length in seconds applied to every request
                that does not pass ``parameters["duration"]``.
        """
        self.model = AudioGen.get_pretrained(path or DEFAULT_MODEL)
        self.default_duration = default_duration
        self.model.set_generation_params(duration=default_duration)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate audio for one or more text prompts.

        Args:
            data: Request payload. ``data["inputs"]`` is a prompt string or a
                non-empty list of prompt strings; if the key is missing,
                ``data`` itself is treated as the inputs (and rejected unless it
                is a valid prompt value). ``data["parameters"]`` is an optional
                dict; only its ``"duration"`` entry (seconds, > 0) is read.

        Returns:
            ``{"generated_audio": [...], "sample_rate": int}``. Each entry of
            ``generated_audio`` is one generated clip converted with
            ``tolist()``, in the same order as the prompts; ``sample_rate`` is
            the model's output sample rate, needed to decode the samples.

        Raises:
            ValueError: if the prompts are not a string / list of strings, or
                the requested duration is not a positive finite number.
        """
        # Read rather than pop so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        # ``or {}`` also covers an explicit JSON ``null`` for "parameters".
        parameters = data.get("parameters") or {}

        descriptions = self._to_descriptions(inputs)

        # Apply the duration on every request: set_generation_params mutates
        # the shared model, so a duration from an earlier request would
        # otherwise leak into later requests that did not ask for one.
        duration = self._to_duration(parameters.get("duration", self.default_duration))
        self.model.set_generation_params(duration=duration)

        wav = self.model.generate(descriptions)

        # Optional: also save each clip to disk (requires the audio_write import):
        # for idx, one_wav in enumerate(wav):
        #     audio_write(f"{idx}", one_wav.cpu(), self.model.sample_rate,
        #                 strategy="loudness", loudness_compressor=True)

        # Tensors are not JSON-serializable: move each clip to CPU and convert
        # it to nested Python lists.
        predictions = [one_wav.cpu().numpy().tolist() for one_wav in wav]
        return {"generated_audio": predictions, "sample_rate": self.model.sample_rate}

    @staticmethod
    def _to_descriptions(inputs: Any) -> List[str]:
        """Normalize the request prompt(s) into the list of strings ``generate`` expects."""
        if isinstance(inputs, str):
            return [inputs]
        if (
            isinstance(inputs, (list, tuple))
            and inputs
            and all(isinstance(text, str) for text in inputs)
        ):
            return list(inputs)
        raise ValueError(
            "'inputs' must be a prompt string or a non-empty list of prompt "
            f"strings, got {type(inputs).__name__}"
        )

    @staticmethod
    def _to_duration(value: Any) -> float:
        """Validate a requested clip duration and return it in seconds."""
        try:
            duration = float(value)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"'duration' must be a number of seconds, got {value!r}"
            ) from err
        # The chained comparison also rejects NaN and infinity.
        if not 0 < duration < float("inf"):
            raise ValueError(f"'duration' must be a positive finite number, got {value!r}")
        return duration