Spaces:
Runtime error
Runtime error
Create serve-model.py file
Browse files- serve-model.py +31 -0
serve-model.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from audiocraft.models import MusicGen
|
2 |
+
from mlserver import MLModel
|
3 |
+
from mlserver.types import InferenceRequest, InferenceResponse, ResponseOutput
|
4 |
+
|
5 |
+
# Checkpoint identifier handed to MusicGen.get_pretrained() when the server
# loads its model.
MODEL = "small"
|
6 |
+
|
7 |
+
class MusicGenServer(MLModel):
    """MLServer custom runtime that turns text prompts into audio with MusicGen."""

    async def load(self) -> bool:
        """Load the pretrained MusicGen checkpoint named by ``MODEL`` onto CUDA.

        Returns:
            ``True`` once the model has been stored on ``self.model``. MLServer
            uses the return value of ``load()`` as the readiness signal, so the
            method must not fall through and return ``None``.
        """
        self.model = MusicGen.get_pretrained(MODEL, device="cuda")
        return True

    async def predict(self, request: InferenceRequest) -> InferenceResponse:
        """Generate audio for the prompts carried by the first request input.

        The first input's ``data`` is indexed by the keys ``"prompts"`` and
        ``"duration"``.

        Args:
            request: Inference request whose first input carries the prompts
                and the requested duration.

        Returns:
            An ``InferenceResponse`` with a single ``"new_music"`` output that
            holds the samples of ``wav[0, 0]`` as a flat list of floats.
        """
        payload = request.inputs[0].data
        prompts = payload["prompts"]
        seconds = payload["duration"]
        # Requests for fewer than 2 seconds are replaced by a 5 second clip;
        # any other value is passed through unchanged.
        duration = 5 if seconds < 2 else seconds

        self.model.set_generation_params(duration=duration)
        # NOTE(review): generate() is a synchronous call inside a coroutine;
        # it presumably blocks the event loop while running - confirm whether
        # that is acceptable for this deployment.
        wav = self.model.generate(prompts)

        # Presumably wav is (batch, channels, samples) - TODO confirm against
        # audiocraft. Only wav[0, 0] is returned below, so output generated
        # for any additional prompts is discarded.
        shape = list(wav[0].shape)

        response_output = ResponseOutput(
            name="new_music",
            shape=shape,
            # "FP32" is the V2 / Open Inference Protocol identifier for 32-bit
            # floats; "FLOAT32" is not a datatype defined by that protocol.
            datatype="FP32",
            data=wav[0, 0].cpu().tolist(),
        )

        # NOTE(review): model_name is hard-coded; confirm it matches the name
        # this runtime is registered under in the model settings.
        return InferenceResponse(model_name="music_model", outputs=[response_output])
|