Spaces:
Runtime error
Runtime error
File size: 950 Bytes
2b2ab2a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from audiocraft.models import MusicGen
from mlserver import MLModel
from mlserver.types import InferenceRequest, InferenceResponse, ResponseOutput
# Pretrained MusicGen checkpoint identifier handed to MusicGen.get_pretrained
# when the runtime loads.
# NOTE(review): newer audiocraft releases deprecate short names such as
# "small" in favour of full ids like "facebook/musicgen-small" -- confirm
# against the pinned audiocraft version.
MODEL: str = "small"
class MusicGenServer(MLModel):
    """MLServer runtime that generates music with an audiocraft MusicGen model."""

    async def load(self):
        """Load the pretrained MusicGen checkpoint named by ``MODEL``.

        Uses CUDA when a GPU is visible and falls back to CPU otherwise, so
        the server can still start on hosts without a GPU.

        Returns:
            True, which MLServer uses as the model's readiness status.
        """
        import torch  # audiocraft depends on torch, so it is always installed

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = MusicGen.get_pretrained(MODEL, device=device)
        # MLServer stores load()'s return value as the readiness flag; the
        # original implicitly returned None, leaving the model not-ready.
        return True

    async def predict(self, request: InferenceRequest) -> InferenceResponse:
        """Generate audio from text prompts and return the first clip.

        ``request.inputs[0].data`` is expected to be a mapping with:
          * ``"prompts"`` -- a list of text descriptions (a single string is
            also accepted and treated as a one-prompt batch);
          * ``"duration"`` -- requested clip length in seconds.

        Returns:
            An InferenceResponse for model ``"music_model"`` with one FLOAT32
            output named ``"new_music"``. Its ``data`` is the flattened
            waveform of the first generated clip and its ``shape`` is that
            clip's shape. Assumes audiocraft's ``generate`` returns a
            (batch, channels, samples) tensor, so the shape is presumably
            (channels, samples).
        """
        payload = request.inputs[0].data
        prompts = payload["prompts"]
        if isinstance(prompts, str):
            # A bare string would otherwise be iterated into a batch of
            # single-character prompts.
            prompts = [prompts]

        seconds = float(payload["duration"])
        # NOTE(review): requests shorter than 2 s are bumped to 5 s, while
        # exactly 2 s is honoured -- confirm this threshold/default is intended.
        duration = 5 if seconds < 2 else seconds
        self.model.set_generation_params(duration=duration)
        # NOTE(review): generate() is a synchronous call inside an async
        # handler, so it blocks the event loop for the whole generation.
        wav = self.model.generate(prompts)

        # Only the first prompt's clip is returned. Flatten the whole clip
        # (not just channel 0) so the data length always matches the declared
        # shape; for single-channel output this is identical to before.
        clip = wav[0].cpu()
        response_output = ResponseOutput(
            name="new_music",
            shape=list(clip.shape),
            datatype="FLOAT32",
            data=clip.flatten().tolist(),
        )
        return InferenceResponse(model_name="music_model", outputs=[response_output])
|