Spaces:
Runtime error
Runtime error
from audiocraft.models import MusicGen | |
from mlserver import MLModel | |
from mlserver.types import InferenceRequest, InferenceResponse, ResponseOutput | |
# Checkpoint identifier handed to MusicGen.get_pretrained when the runtime loads.
# NOTE(review): newer audiocraft releases may treat short aliases like this one
# as deprecated in favour of full Hub ids (e.g. "facebook/musicgen-small"),
# while very old releases may accept only the alias — confirm against the
# installed audiocraft version before changing it.
MODEL: str = "small"
class MusicGenServer(MLModel):
    """MLServer runtime that turns text prompts into audio with MusicGen.

    The MusicGen checkpoint named by ``MODEL`` is loaded onto CUDA in
    :meth:`load`; each :meth:`predict` call generates audio for the request's
    prompts and returns the first generated clip as one ``FLOAT32`` tensor.
    """

    # Requested durations (seconds) below this threshold are not honoured...
    _MIN_DURATION = 2
    # ...and are replaced by this fixed duration (seconds) instead.
    _FALLBACK_DURATION = 5

    async def load(self) -> bool:
        """Load the pretrained MusicGen checkpoint named by ``MODEL`` onto CUDA.

        Returns:
            ``True``. MLServer uses the value returned by ``load()`` as the
            model's readiness status, so returning nothing would leave the
            model reported as not ready.
        """
        self.model = MusicGen.get_pretrained(MODEL, device="cuda")
        return True

    async def predict(self, request: InferenceRequest) -> InferenceResponse:
        """Generate music for the prompts carried in ``request``.

        Args:
            request: V2 inference request whose first input's ``data`` is
                indexed by the keys ``"prompts"`` and ``"duration"``.
                NOTE(review): presumably a JSON object such as
                ``{"prompts": [...], "duration": 8}`` — confirm against clients.

        Returns:
            A response for model ``"music_model"`` with a single ``FLOAT32``
            output named ``"new_music"``: the first generated clip, flattened,
            with ``shape`` describing that clip.
        """
        payload = request.inputs[0].data
        prompts = payload["prompts"]
        seconds = payload["duration"]

        # Too-short requests fall back to a fixed default rather than being
        # clamped up to the threshold.
        # NOTE(review): confirm the fallback (instead of a clamp) is intended.
        if seconds < self._MIN_DURATION:
            duration = self._FALLBACK_DURATION
        else:
            duration = seconds
        self.model.set_generation_params(duration=duration)
        wav = self.model.generate(prompts)

        # Only the first generated clip is returned, even when several prompts
        # are sent. wav is presumably (batch, channels, samples), making
        # ``clip`` (channels, samples) — TODO confirm for the installed version.
        # Flattening every channel keeps len(data) equal to the product of
        # ``shape`` for multi-channel checkpoints too; for a mono checkpoint it
        # yields exactly channel 0.
        clip = wav[0]
        response_output = ResponseOutput(
            name="new_music",
            shape=list(clip.shape),
            datatype="FLOAT32",
            data=clip.cpu().flatten().tolist(),
        )
        # NOTE(review): model_name is hard-coded rather than taken from the
        # model's configured name — confirm it matches the deployment.
        return InferenceResponse(model_name="music_model", outputs=[response_output])