|
import base64
import io
import logging
import os
from typing import Dict, List, Any

import numpy as np
import torch
from transformers import BarkModel, BarkProcessor
|
|
|
class EndpointHandler:
    """Request handler that turns text into speech with the Bark model.

    The handler is callable: ``handler(data)`` runs ``preprocess`` ->
    ``inference`` -> ``postprocess`` and returns a JSON-serialisable dict.
    The model and processor are loaded lazily by ``setup`` on first use.

    Request shape handled by ``preprocess``::

        {
            "inputs": "text to speak"  # or a list; only element 0 is used
            "parameters": {            # optional
                "speaker_id": ...,     # takes precedence over voice_preset
                "voice_preset": ...,
                "temperature": ...,
            }
        }
    """

    # Sampling temperature used when the request does not supply one.
    DEFAULT_TEMPERATURE = 0.7
    # Sample rate assumed by ``postprocess`` when inference did not report one.
    DEFAULT_SAMPLE_RATE = 24000

    _logger = logging.getLogger(__name__)

    def __init__(self, path=""):
        """Record the model location; nothing is loaded until ``setup`` runs.

        Args:
            path (str, optional): Path to the model directory. Defaults to "".
        """
        self.path = path
        self.model = None
        self.processor = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.initialized = False

    def setup(self, **kwargs):
        """Load the Bark model and processor from ``self.path``.

        The model is moved to ``self.device`` and ``self.initialized`` is set
        once both objects are available.

        Args:
            **kwargs: Accepted for compatibility; currently unused.
        """
        self.model = BarkModel.from_pretrained(self.path)
        self.model.to(self.device)
        self.processor = BarkProcessor.from_pretrained(self.path)

        self.initialized = True
        print(f"Bark model loaded on {self.device}")

    def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Extract the text and generation options from a request.

        Args:
            request (Dict): Request payload with an ``"inputs"`` entry (a
                string, or a list whose first element is used) and an
                optional ``"parameters"`` mapping.

        Returns:
            Dict: May contain ``"text"``, one of ``"speaker_id"`` /
            ``"voice_preset"``, and ``"temperature"``. Keys are only present
            when the request supplied them.
        """
        if not self.initialized:
            self.setup()

        inputs: Dict[str, Any] = {}

        raw_inputs = request.get("inputs")
        if isinstance(raw_inputs, str):
            inputs["text"] = raw_inputs
        elif isinstance(raw_inputs, list) and raw_inputs:
            # Batches are not supported: only the first entry is synthesised.
            # An empty list leaves "text" unset so inference reports the
            # missing text instead of failing with an IndexError here.
            inputs["text"] = raw_inputs[0]

        # `"parameters": null` in the payload must behave like an absent key.
        params = request.get("parameters") or {}

        # speaker_id wins when both are supplied.
        if "speaker_id" in params:
            inputs["speaker_id"] = params["speaker_id"]
        elif "voice_preset" in params:
            inputs["voice_preset"] = params["voice_preset"]

        if "temperature" in params:
            inputs["temperature"] = params["temperature"]

        return inputs

    def inference(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Generate speech for the preprocessed inputs.

        Args:
            inputs (Dict): Output of ``preprocess``.

        Returns:
            Dict: ``{"audio_array": ndarray, "sample_rate": int}`` on success,
            or ``{"error": message}`` when no text was provided.
        """
        text = inputs.get("text", "")
        if not text:
            return {"error": "No text provided for speech generation"}

        # Bark selects a voice through the processor's `voice_preset`, which
        # is turned into the history prompt consumed by `generate`; it is not
        # a `generate` keyword. A request's speaker_id is therefore used as
        # the voice preset name, keeping its precedence over voice_preset.
        voice_preset = inputs.get("speaker_id") or inputs.get("voice_preset") or None

        temperature = inputs.get("temperature")
        if temperature is None:
            temperature = self.DEFAULT_TEMPERATURE

        encoded = self.processor(text, voice_preset=voice_preset).to(self.device)

        with torch.no_grad():
            # The processor returns a mapping (input ids, attention mask and,
            # with a preset, the history prompt); it must be unpacked rather
            # than passed whole as `input_ids`.
            speech_output = self.model.generate(**encoded, temperature=temperature)

        audio_array = speech_output.cpu().numpy().squeeze()

        return {
            "audio_array": audio_array,
            "sample_rate": self.model.generation_config.sample_rate,
        }

    def postprocess(self, inference_output: Dict[str, Any]) -> Dict[str, Any]:
        """Encode the generated audio as a base64 WAV payload.

        Args:
            inference_output (Dict): Output of ``inference``.

        Returns:
            Dict: ``{"audio": base64 str, "sample_rate": int, "format": "wav"}``
            on success, or ``{"error": message}`` if inference failed or the
            audio could not be encoded.
        """
        if "error" in inference_output:
            return {"error": inference_output["error"]}

        audio_array = inference_output.get("audio_array")
        if audio_array is None:
            return {"error": "Error converting audio: no audio data to encode"}
        sample_rate = inference_output.get("sample_rate", self.DEFAULT_SAMPLE_RATE)

        try:
            # Imported lazily so a missing scipy surfaces as an error
            # response rather than breaking handler import.
            import scipy.io.wavfile as wav

            audio_buffer = io.BytesIO()
            wav.write(audio_buffer, sample_rate, audio_array)
            audio_base64 = base64.b64encode(audio_buffer.getvalue()).decode("utf-8")
        except Exception as e:
            return {"error": f"Error converting audio: {str(e)}"}

        return {
            "audio": audio_base64,
            "sample_rate": sample_rate,
            "format": "wav",
        }

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request end to end.

        Args:
            data (Dict): Request data (see the class docstring for the shape).

        Returns:
            Dict: The ``postprocess`` result, or ``{"error": message}`` if any
            stage raised. Failures while loading the model in ``setup`` are
            not caught here and propagate to the caller.
        """
        if not self.initialized:
            self.setup()

        try:
            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            return self.postprocess(outputs)
        except Exception as e:
            # Keep the traceback in the logs; the response only carries the
            # message.
            self._logger.exception("Error processing request")
            return {"error": f"Error processing request: {str(e)}"}
|
|