H-H-E committed on
Commit
4f9c202
1 Parent(s): a1e58ca

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +17 -14
  2. requirements.txt +1 -0
handler.py CHANGED
@@ -1,23 +1,26 @@
1
  from typing import Dict, List, Any
2
  from transformers import pipeline
 
3
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
- self.model = pipeline("text-to-speech", "suno/bark")
8
 
9
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
10
- """
11
- data args:
12
- inputs (:obj: `str`)
13
- date (:obj: `str`)
14
- Return:
15
- A :obj:`list` | `dict`: will be serialized and returned
16
- """
17
- # get inputs
18
- text_prompt = data.pop("inputs", data)
19
 
20
- # run normal prediction
21
- speech_array = self.model(text_prompt,forward_params={"do_sample": True})
22
- return speech_array
 
23
 
 
 
 
 
 
 
 
 
1
  from typing import Dict, List, Any
2
  from transformers import pipeline
3
+ import scipy.io.wavfile
4
 
5
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
+ self.synthesiser = pipeline("text-generation", "suno/bark") # Attempt to create pipeline
9
 
10
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
11
+ text_prompt = data.get("inputs")
12
+ if not text_prompt:
13
+ raise ValueError("Missing required 'inputs' field in request data.")
 
 
 
 
 
 
14
 
15
+ try:
16
+ speech = self.synthesiser(text_prompt, forward_params={"do_sample": True})
17
+ audio_data = speech["audio"] # Assuming audio is in a NumPy array
18
+ sampling_rate = speech["sampling_rate"]
19
 
20
+ # Return audio data as a byte string (adjust format as needed)
21
+ audio_bytes = audio_data.tobytes()
22
+ return {"audio": audio_bytes, "sampling_rate": sampling_rate}
23
+
24
+ except Exception as e:
25
+ # Handle potential errors with model loading or usage
26
+ return {"error": str(e)}
requirements.txt CHANGED
@@ -1 +1,2 @@
1
  transformers
 
 
1
  transformers
2
+ scipy