File size: 950 Bytes
2b2ab2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from audiocraft.models import MusicGen
from mlserver import MLModel
from mlserver.types import InferenceRequest, InferenceResponse, ResponseOutput

# Name of the pretrained checkpoint handed to MusicGen.get_pretrained() at load time.
MODEL = "small"

class MusicGenServer(MLModel):
    """MLServer custom runtime that turns text prompts into audio with MusicGen."""

    async def load(self) -> bool:
        """Load the pretrained MusicGen checkpoint onto the CUDA device.

        Returns:
            ``True`` once the checkpoint is loaded. MLServer uses the value
            returned by ``load`` as the model's readiness status, so falling
            off the end (returning ``None``) would leave the model flagged as
            not ready.
        """
        self.model = MusicGen.get_pretrained(MODEL, device="cuda")
        return True

    async def predict(self, request: InferenceRequest) -> InferenceResponse:
        """Generate audio for the prompts carried by the first request input.

        The first input's ``data`` is read as a mapping with two keys:
        ``"prompts"`` (forwarded unchanged to ``MusicGen.generate``) and
        ``"duration"`` (requested length; values below 2 are replaced by 5).

        Args:
            request: Inference request whose first input holds the payload
                described above.

        Returns:
            An ``InferenceResponse`` with a single ``FLOAT32`` output named
            ``"new_music"`` containing the samples at index ``[0, 0]`` of the
            generated tensor.

        Raises:
            IndexError: If the request has no inputs.
            KeyError: If ``"prompts"`` or ``"duration"`` is missing
                (assuming the payload is a plain dict).
        """
        payload = request.inputs[0].data
        prompts = payload["prompts"]
        seconds = payload["duration"]

        # Requests shorter than 2 fall back to a 5-unit clip instead of
        # being passed through as-is.
        duration = 5 if seconds < 2 else seconds

        self.model.set_generation_params(duration=duration)
        wav = self.model.generate(prompts)

        # NOTE(review): presumably `wav` is (batch, channels, samples) -- confirm
        # against audiocraft. Only index [0, 0] is returned; any additional
        # batch items or channels are dropped.
        first_item = wav[0]
        # The advertised shape is taken from wav[0] while the payload is
        # wav[0, 0]; the element counts agree only when wav[0]'s leading
        # dimension is 1.
        shape = list(first_item.shape)

        response_output = ResponseOutput(
            name="new_music",
            shape=shape,
            datatype="FLOAT32",
            data=wav[0, 0].cpu().tolist(),
        )

        return InferenceResponse(model_name="music_model", outputs=[response_output])