File size: 1,148 Bytes
3e55f57
7a8a532
 
f067669
 
3e55f57
7a8a532
 
3e55f57
4f9c202
 
7a8a532
 
 
 
 
 
 
 
 
 
62f977c
7a8a532
4f9c202
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from typing import Dict, List, Any
from transformers import AutoProcessor, AutoModel
import scipy.io.wavfile  # Assuming WAV output format

class EndpointHandler:
    """Text-to-speech request handler wrapping the ``suno/bark`` model.

    Instances are callable: pass a request dict containing an ``"inputs"``
    text prompt and receive the generated waveform as raw bytes together with
    its sampling rate, or an ``{"error": ...}`` dict if anything fails.
    """

    # Fallback output rate, used only when the loaded model does not expose
    # ``generation_config.sample_rate``. Bark's generation config declares
    # 24 kHz audio.
    _DEFAULT_SAMPLING_RATE = 24_000

    def __init__(self, path=""):
        """Load the Bark processor and model from the Hugging Face Hub.

        Args:
            path: Accepted for compatibility with the endpoint-handler
                interface but currently unused; weights are always loaded
                from ``"suno/bark"``. NOTE(review): confirm whether a local
                ``path`` should take precedence.
        """
        self.processor = AutoProcessor.from_pretrained("suno/bark")
        self.model = AutoModel.from_pretrained("suno/bark")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Synthesize speech for the prompt in ``data["inputs"]``.

        Args:
            data: Request payload; its ``"inputs"`` entry is the text prompt.

        Returns:
            ``{"audio": bytes, "sampling_rate": int}`` on success, where
            ``audio`` is the first generated waveform's raw array bytes
            (``ndarray.tobytes()``). Decode it with the array's dtype,
            presumably float32 (TODO confirm). On any failure, including a
            missing or empty prompt, returns ``{"error": <message>}`` instead
            of raising.
        """
        try:
            text_prompt = data.get("inputs")
            if not text_prompt:
                raise ValueError("Missing required 'inputs' field in request data.")

            inputs = self.processor(text=[text_prompt], return_tensors="pt")
            speech_values = self.model.generate(**inputs, do_sample=True)

            # Bring the first waveform to host memory before converting:
            # ``.numpy()`` raises on tensors that live on a GPU.
            audio_data = speech_values[0].cpu().numpy()

            # Take the rate from the model instead of hard-coding it. Bark
            # emits 24 kHz audio, so labelling it 22050 Hz made playback slow
            # and low-pitched.
            generation_config = getattr(self.model, "generation_config", None)
            sampling_rate = (
                getattr(generation_config, "sample_rate", None)
                or self._DEFAULT_SAMPLING_RATE
            )

            # NOTE(review): raw bytes are not JSON-serializable; confirm the
            # serving layer can transport them (or switch to an encoded form).
            audio_bytes = audio_data.tobytes()
            return {"audio": audio_bytes, "sampling_rate": sampling_rate}

        except Exception as e:
            # Request boundary: report the failure to the client rather than
            # crashing the worker.
            return {"error": str(e)}