music_mlserver_demo / serve-model.py
ramonpzg's picture
Create serve-model.py file
2b2ab2a
from audiocraft.models import MusicGen
from mlserver import MLModel
from mlserver.types import InferenceRequest, InferenceResponse, ResponseOutput
# Pretrained checkpoint name handed to MusicGen.get_pretrained() when the
# runtime loads its model.
MODEL = "small"
class MusicGenServer(MLModel):
    """MLServer runtime that serves a pretrained MusicGen model."""

    async def load(self) -> bool:
        """Load the MusicGen checkpoint named by ``MODEL`` on the CUDA device.

        Returns:
            ``True`` once the model object has been created, so that MLServer
            can treat the runtime as ready.
        """
        self.model = MusicGen.get_pretrained(MODEL, device="cuda")
        # MLServer's custom-runtime contract has ``load`` report success with
        # a boolean; falling off the end would hand back ``None`` instead.
        return True

    async def predict(self, request: InferenceRequest) -> InferenceResponse:
        """Generate audio for the prompts carried by the first request input.

        The first input's ``data`` is indexed with two keys:

        * ``"prompts"`` -- forwarded unchanged to ``MusicGen.generate``.
        * ``"duration"`` -- requested clip length; any value below 2 is
          replaced by 5 before being passed to ``set_generation_params``.

        Returns:
            An ``InferenceResponse`` named ``"music_model"`` with a single
            ``FLOAT32`` output called ``"new_music"``.
        """
        payload = request.inputs[0].data
        prompts = payload["prompts"]
        seconds = payload["duration"]

        # Requests asking for fewer than 2 fall back to a 5-unit clip.
        duration = 5 if seconds < 2 else seconds
        self.model.set_generation_params(duration=duration)

        wav = self.model.generate(prompts)

        # Only the first generated item is returned: the shape comes from
        # ``wav[0]`` and the payload from ``wav[0, 0]``.
        # NOTE(review): assumes ``wav`` is (batch, channels, samples) with a
        # single channel, so shape and flattened data agree -- confirm
        # against the audiocraft version in use.
        shape = list(wav[0].shape)
        response_output = ResponseOutput(
            name="new_music",
            shape=shape,
            datatype="FLOAT32",
            data=wav[0, 0].cpu().tolist(),
        )
        return InferenceResponse(
            model_name="music_model",
            outputs=[response_output],
        )