ramonpzg committed on
Commit 2b2ab2a
1 Parent(s): 929fd9d

Create serve-model.py file

Files changed (1)
  1. serve-model.py +31 -0
serve-model.py ADDED
@@ -0,0 +1,31 @@
+ from audiocraft.models import MusicGen
+ from mlserver import MLModel
+ from mlserver.types import InferenceRequest, InferenceResponse, ResponseOutput
+
+ MODEL = "small"
+
+ class MusicGenServer(MLModel):
+
+     async def load(self):
+         self.model = MusicGen.get_pretrained(MODEL, device="cuda")
+
+
+     async def predict(self, request: InferenceRequest) -> InferenceResponse:
+
+         prompts = request.inputs[0].data["prompts"]
+         seconds = request.inputs[0].data["duration"]
+         duration = 5 if seconds < 2 else seconds
+
+         self.model.set_generation_params(duration=duration)
+         wav = self.model.generate(prompts)
+
+         shape = list(wav[0].shape)
+
+         response_output = ResponseOutput(
+             name="new_music",
+             shape=shape,
+             datatype="FLOAT32",
+             data=wav[0, 0].cpu().tolist(),
+         )
+
+         return InferenceResponse(model_name="music_model", outputs=[response_output])
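For reference, the predict() hook above reads a prompts list and a duration value out of the first input tensor's data, so a client has to send both keys in that single input. Below is a minimal sketch of what such a call could look like against the V2 inference endpoint MLServer exposes, assuming the server runs locally on the default HTTP port 8080 and the runtime is registered under the name music_model via a model-settings.json that is not part of this commit. The input name, shape, and datatype are placeholders, and the example assumes the installed mlserver version accepts a JSON object in the data field, as the handler's dictionary lookups imply.

import requests

# Hypothetical client call; the path follows the V2 inference protocol
# that MLServer serves (POST /v2/models/<name>/infer).
payload = {
    "inputs": [
        {
            "name": "music_inputs",   # placeholder; predict() only looks at inputs[0]
            "shape": [1],             # placeholder shape
            "datatype": "BYTES",      # placeholder datatype
            "data": {
                "prompts": ["lo-fi beat with a soft piano melody"],  # passed to MusicGen.generate()
                "duration": 10,       # seconds; the handler bumps anything under 2 up to 5
            },
        }
    ]
}

resp = requests.post("http://localhost:8080/v2/models/music_model/infer", json=payload)
resp.raise_for_status()

# The single output ("new_music") carries the generated waveform as a flat
# list of FLOAT32 samples.
waveform = resp.json()["outputs"][0]["data"]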