musicgen
Inference Endpoints
reneepc committed on
Commit
498944d
1 Parent(s): fd048dd

Add initial custom handler

Browse files
Files changed (1) hide show
  1. handler.py +28 -0
handler.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoProcessor, MusicgenForConditionalGeneration
3
+ import torch
4
+
5
class EndpointHandler:
    """Custom Inference Endpoints handler for MusicGen text-to-music generation."""

    def __init__(self, path=""):
        # Load the processor (tokenizer/feature extractor) and the MusicGen
        # model from the endpoint's model repository path.
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(path)
        # Fix: the original moved only the *inputs* to CUDA while the model
        # stayed on CPU (device mismatch at generate() time), and crashed
        # outright on CPU-only hardware. Detect the device once and move the
        # model to it.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio from a text prompt.

        Args:
            data: request payload. "inputs" holds the text prompt; the
                optional "parameters" dict is forwarded verbatim to
                ``model.generate()``.

        Returns:
            A single-element list: ``[{"generated_audio": <nested list of
            floats>}]`` (first sequence of the generated batch).
        """
        text_input = data.pop("inputs", data)
        # Treat a missing "parameters" entry as "no extra generate kwargs".
        parameters = data.pop("parameters", None) or {}

        inputs = self.processor(
            text=[text_input],
            return_tensors="pt",
            padding=True,
        ).to(self.device)

        # Mixed-precision autocast is only applicable on CUDA.
        if self.device == "cuda":
            with torch.autocast("cuda"):
                outputs = self.model.generate(**inputs, **parameters)
        else:
            outputs = self.model.generate(**inputs, **parameters)

        prediction = outputs[0].cpu().numpy().tolist()

        return [{"generated_audio": prediction}]