H-H-E committed on
Commit
7a8a532
1 Parent(s): 30769ce

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +15 -15
handler.py CHANGED
@@ -1,28 +1,28 @@
1
  from typing import Dict, List, Any
2
- from transformers import pipeline
3
- import scipy.io.wavfile
4
-
5
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
- self.synthesiser = pipeline("text-generation", "suno/bark") # Attempt to create pipeline
 
9
 
10
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
11
- text_prompt = data.get("inputs")
12
- if not text_prompt:
13
- raise ValueError("Missing required 'inputs' field in request data.")
14
-
15
  try:
16
- print(self.synthesiser)
17
- speech = self.synthesiser(text_prompt, forward_params={"do_sample": True})
18
- print(speech)
19
- audio_data = speech["audio"] # Assuming audio is in a NumPy array
20
- sampling_rate = speech["sampling_rate"]
 
 
 
 
 
21
 
22
- # Return audio data as a byte string (adjust format as needed)
23
  audio_bytes = audio_data.tobytes()
24
  return {"audio": audio_bytes, "sampling_rate": sampling_rate}
25
 
26
  except Exception as e:
27
- # Handle potential errors with model loading or usage
28
  return {"error": str(e)}
 
1
  from typing import Dict, List, Any
2
+ from transformers import AutoProcessor, AutoModel
3
+ import scipy.io.wavfile # Assuming WAV output format
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
+ self.processor = AutoProcessor.from_pretrained("suno/bark")
8
+ self.model = AutoModel.from_pretrained("suno/bark")
9
 
10
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
 
 
 
11
  try:
12
+ text_prompt = data.get("inputs")
13
+ if not text_prompt:
14
+ raise ValueError("Missing required 'inputs' field in request data.")
15
+
16
+ inputs = self.processor(text=[text_prompt], return_tensors="pt")
17
+ speech_values = self.model.generate(**inputs, do_sample=True)
18
+
19
+ # Assuming model returns audio as NumPy array
20
+ audio_data = speech_values[0].numpy()
21
+ sampling_rate = 22050 # Adjust as needed based on model documentation
22
 
23
+ # Return audio data as a byte string
24
  audio_bytes = audio_data.tobytes()
25
  return {"audio": audio_bytes, "sampling_rate": sampling_rate}
26
 
27
  except Exception as e:
 
28
  return {"error": str(e)}