skroed commited on
Commit
c2293d7
1 Parent(s): 6eb9b90

Fix: endpoint passing speaker.

Browse files
Files changed (2) hide show
  1. handler.py +14 -10
  2. requirements.txt +0 -2
handler.py CHANGED
@@ -7,9 +7,9 @@ from transformers import AutoModel, AutoProcessor
7
  class EndpointHandler:
8
  def __init__(self, path=""):
9
  # load model and processor from path
10
- self.processor = AutoProcessor.from_pretrained(path)
11
  self.model = AutoModel.from_pretrained(
12
- path,
13
  ).to("cuda")
14
 
15
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
@@ -20,14 +20,18 @@ class EndpointHandler:
20
  """
21
  # process input
22
  text = data.pop("inputs", data)
23
- parameters = data.get("parameters", None)
24
-
25
- # preprocess
26
- inputs = self.processor(
27
- text=[text],
28
- return_tensors="pt",
29
- voice_preset=parameters.get("voice_preset", None),
30
- ).to("cuda")
 
 
 
 
31
 
32
  with torch.autocast("cuda"):
33
  outputs = self.model.generate(**inputs)
 
7
  class EndpointHandler:
8
  def __init__(self, path=""):
9
  # load model and processor from path
10
+ self.processor = AutoProcessor.from_pretrained("suno/bark")
11
  self.model = AutoModel.from_pretrained(
12
+ "suno/bark",
13
  ).to("cuda")
14
 
15
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
 
20
  """
21
  # process input
22
  text = data.pop("inputs", data)
23
+ voice_preset = data.get("voice_preset", None)
24
+ if voice_preset:
25
+ inputs = self.processor(
26
+ text=[text],
27
+ return_tensors="pt",
28
+ voice_preset=voice_preset,
29
+ ).to("cuda")
30
+ else:
31
+ inputs = self.processor(
32
+ text=[text],
33
+ return_tensors="pt",
34
+ ).to("cuda")
35
 
36
  with torch.autocast("cuda"):
37
  outputs = self.model.generate(**inputs)
requirements.txt DELETED
@@ -1,2 +0,0 @@
1
- transformers==4.34.1
2
- accelerate>=0.23.0