from audiocraft.models import MusicGen
from mlserver import MLModel
from mlserver.types import InferenceRequest, InferenceResponse, ResponseOutput

# Name of the pretrained MusicGen checkpoint passed to MusicGen.get_pretrained.
MODEL = "small"

# Requested durations (seconds) below _MIN_DURATION are replaced by
# _FALLBACK_DURATION before generation.
_MIN_DURATION = 2
_FALLBACK_DURATION = 5


class MusicGenServer(MLModel):
    """MLServer runtime that generates audio from text prompts with MusicGen."""

    async def load(self) -> bool:
        """Load the pretrained MusicGen checkpoint onto the CUDA device.

        Returns True, as MLServer's custom-runtime contract expects; some
        MLServer versions use this return value as the model's ready flag.
        """
        self.model = MusicGen.get_pretrained(MODEL, device="cuda")
        return True

    async def predict(self, request: InferenceRequest) -> InferenceResponse:
        """Generate audio for the prompts carried in the first request input.

        Reads ``request.inputs[0].data["prompts"]`` and
        ``request.inputs[0].data["duration"]``. NOTE(review): this assumes the
        client sends ``data`` as a mapping with those keys; confirm against
        the client's encoding.

        Returns a single ``"new_music"`` FP32 output holding the waveform
        generated for the first prompt only, flattened row-major, with
        ``shape`` set to that waveform's dimensions.
        """
        payload = request.inputs[0].data
        prompts = payload["prompts"]
        # MusicGen.generate expects a list of descriptions; a bare string
        # would otherwise be iterated character by character.
        if isinstance(prompts, str):
            prompts = [prompts]

        seconds = payload["duration"]
        duration = _FALLBACK_DURATION if seconds < _MIN_DURATION else seconds

        # NOTE(review): set_generation_params/generate are synchronous calls
        # made directly inside this coroutine, so the event loop is blocked
        # for the whole generation.
        self.model.set_generation_params(duration=duration)
        wav = self.model.generate(prompts)

        # Only the first prompt's audio is returned. Flatten every channel of
        # it so the number of values always matches the advertised shape
        # (for a mono model this is exactly channel 0).
        first = wav[0].cpu()
        response_output = ResponseOutput(
            name="new_music",
            shape=list(first.shape),
            # Open Inference Protocol datatype name for 32-bit floats.
            datatype="FP32",
            data=first.flatten().tolist(),
        )
        return InferenceResponse(model_name="music_model", outputs=[response_output])