| | """ |
| | Custom Handler for Hugging Face Inference Endpoints |
| | Model: IbrahimSalah/Arabic-TTS-Spark |
| | Repository: https://huggingface.co/IbrahimSalah/Arabic-TTS-Spark |
| | |
| | This handler provides Text-to-Speech inference for Arabic with: |
| | - Voice cloning (with reference audio) |
| | - Controllable TTS (with gender, pitch, speed parameters) |
| | """ |
| |
|
| | import base64 |
| | import io |
| | import logging |
| | import os |
| | import tempfile |
| | from pathlib import Path |
| | from typing import Any, Dict, Optional |
| |
|
| | import numpy as np |
| | import soundfile as sf |
| | import torch |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class EndpointHandler:
    """
    Hugging Face Inference Endpoints handler for Arabic-TTS-Spark.

    Supports two modes:
    1. Voice Cloning: Provide reference audio to clone the voice
    2. Controllable TTS: Specify gender, pitch, and speed parameters
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler by loading the model and processor.

        Args:
            path: Path to the model directory (provided by HF Inference
                Endpoints). Falls back to the Hub repo id when empty.
        """
        # Imported lazily so the module can be imported without transformers
        # installed (e.g. by tooling that only inspects the handler).
        from transformers import AutoModel, AutoProcessor

        self.device = self._get_device()
        logger.info("Initializing Arabic-TTS-Spark on device: %s", self.device)

        model_path = path if path else "IbrahimSalah/Arabic-TTS-Spark"
        logger.info("Loading model from: %s", model_path)

        # trust_remote_code is required: the Spark-TTS processor/model classes
        # live in the model repository, not in the transformers library.
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True
        )

        # bfloat16 on CUDA halves memory; CPU/MPS stay in float32 since
        # bf16 support there is inconsistent.
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if self.device.type == "cuda" else torch.float32
        )
        self.model = self.model.to(self.device).eval()

        # The processor needs a handle to the model for audio token
        # encoding/decoding during __call__ / decode.
        self.processor.link_model(self.model)

        # Optional bundled reference clip. NOTE(review): this is recorded but
        # never read by __call__ — callers must supply their own prompt audio.
        self.default_reference_path = Path(model_path) / "reference.wav"
        if not self.default_reference_path.exists():
            self.default_reference_path = Path(path) / "reference.wav" if path else None

        logger.info("Model loaded successfully")

    def _get_device(self) -> torch.device:
        """Determine the best available device (CUDA > MPS > CPU)."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # hasattr guard keeps this working on older torch builds without MPS.
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _decode_audio_base64(self, audio_base64: str) -> tuple:
        """
        Decode base64 audio to a numpy array.

        Args:
            audio_base64: Base64 encoded audio data (any format soundfile reads)

        Returns:
            Tuple of (audio_data, sample_rate)
        """
        audio_bytes = base64.b64decode(audio_base64)
        audio_buffer = io.BytesIO(audio_bytes)
        audio_data, sample_rate = sf.read(audio_buffer)
        return audio_data, sample_rate

    def _encode_audio_base64(self, audio_data: np.ndarray, sample_rate: int) -> str:
        """
        Encode an audio numpy array to a base64 WAV string.

        Args:
            audio_data: Audio waveform as numpy array
            sample_rate: Sample rate of the audio

        Returns:
            Base64 encoded WAV audio string
        """
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
        audio_buffer.seek(0)
        return base64.b64encode(audio_buffer.read()).decode('utf-8')

    def _validate_inputs(self, data: Dict[str, Any]) -> tuple:
        """
        Validate and extract inputs from request data.

        Args:
            data: Request data dictionary

        Returns:
            Tuple of (text, parameters, mode) where mode is
            "voice_cloning" or "controllable"

        Raises:
            ValueError: If no input text is provided.
        """
        text = data.get("inputs", "")
        if not text:
            raise ValueError("No input text provided. Use 'inputs' field.")

        parameters = data.get("parameters", {})

        # Presence of prompt audio selects cloning; otherwise explicit
        # gender/pitch/speed (or defaults) selects controllable TTS.
        has_audio = "prompt_audio_base64" in parameters or "prompt_audio" in parameters
        has_control = all(k in parameters for k in ["gender", "pitch", "speed"])

        if has_audio:
            mode = "voice_cloning"
        elif has_control:
            mode = "controllable"
        else:
            # No usable parameters: fall back to a neutral male voice.
            mode = "controllable"
            parameters.setdefault("gender", "male")
            parameters.setdefault("pitch", "moderate")
            parameters.setdefault("speed", "moderate")

        return text, parameters, mode

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process inference request.

        Args:
            data: Request data with the following structure:
                {
                    "inputs": "Arabic text with diacritics",
                    "parameters": {
                        # For voice cloning:
                        "prompt_audio_base64": "<base64-wav>",  # or "prompt_audio"
                        "prompt_text": "reference transcript",

                        # For controllable TTS:
                        "gender": "male" or "female",
                        "pitch": "very_low", "low", "moderate", "high", "very_high",
                        "speed": "very_low", "low", "moderate", "high", "very_high",

                        # Generation parameters (optional):
                        "temperature": 0.8,
                        "max_new_tokens": 3000,
                        "top_p": 0.95,
                        "top_k": 50
                    }
                }

        Returns:
            Dictionary with:
                {
                    "audio": "<base64-encoded-wav>",
                    "sampling_rate": 16000
                }
            or, on failure, {"error": "...", "error_type": "..."}.
        """
        try:
            text, parameters, mode = self._validate_inputs(data)
            logger.info("Processing request - Mode: %s, Text length: %d", mode, len(text))

            # Generation hyperparameters with sensible TTS defaults.
            temperature = parameters.get("temperature", 0.8)
            max_new_tokens = parameters.get("max_new_tokens", 3000)
            top_p = parameters.get("top_p", 0.95)
            top_k = parameters.get("top_k", 50)

            if mode == "voice_cloning":
                audio_base64 = parameters.get("prompt_audio_base64") or parameters.get("prompt_audio")
                prompt_text = parameters.get("prompt_text", "")

                # Fail fast with a clear message instead of letting
                # base64.b64decode(None) raise an opaque TypeError.
                if not audio_base64:
                    raise ValueError(
                        "Voice cloning mode requires non-empty 'prompt_audio_base64' "
                        "(or 'prompt_audio') in parameters."
                    )

                # The processor expects a file path, so stage the decoded
                # reference audio in a temp WAV file.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                    audio_data, prompt_sample_rate = self._decode_audio_base64(audio_base64)
                    # Preserve the clip's native sample rate: writing with a
                    # hard-coded rate would mislabel the audio and distort the
                    # cloned voice's pitch and speed.
                    sf.write(tmp_file.name, audio_data, prompt_sample_rate)
                    tmp_audio_path = tmp_file.name

                try:
                    inputs = self.processor(
                        text=text,
                        prompt_speech_path=tmp_audio_path,
                        prompt_text=prompt_text if prompt_text else None,
                        return_tensors="pt"
                    )
                finally:
                    # Always remove the temp file, even if the processor raises.
                    os.unlink(tmp_audio_path)
            else:
                gender = parameters.get("gender", "male")
                pitch = parameters.get("pitch", "moderate")
                speed = parameters.get("speed", "moderate")

                valid_genders = ["male", "female"]
                valid_levels = ["very_low", "low", "moderate", "high", "very_high"]

                if gender not in valid_genders:
                    raise ValueError(f"Invalid gender: {gender}. Must be one of {valid_genders}")
                if pitch not in valid_levels:
                    raise ValueError(f"Invalid pitch: {pitch}. Must be one of {valid_levels}")
                if speed not in valid_levels:
                    raise ValueError(f"Invalid speed: {speed}. Must be one of {valid_levels}")

                inputs = self.processor(
                    text=text,
                    gender=gender,
                    pitch=pitch,
                    speed=speed,
                    return_tensors="pt"
                )

            # Move tensors to the model's device; leave non-tensor entries as-is.
            inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                      for k, v in inputs.items()}

            # Remember prompt length so decode can strip it from the output.
            input_ids_len = inputs["input_ids"].shape[1]

            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask"),
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    # Some tokenizers have no pad token; fall back to EOS.
                    pad_token_id=self.processor.tokenizer.pad_token_id or self.processor.tokenizer.eos_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                )

            # Present only in voice-cloning mode; None is accepted by decode.
            global_tokens = inputs.get("global_token_ids_prompt")
            output = self.processor.decode(
                generated_ids=output_ids,
                global_token_ids_prompt=global_tokens,
                input_ids_len=input_ids_len
            )

            audio_data = output["audio"]
            sampling_rate = output["sampling_rate"]

            if audio_data is None or len(audio_data) == 0:
                raise RuntimeError("Model generated empty audio output")

            audio_base64 = self._encode_audio_base64(audio_data, sampling_rate)

            logger.info("Generated audio: %d samples at %dHz", len(audio_data), sampling_rate)

            return {
                "audio": audio_base64,
                "sampling_rate": sampling_rate
            }

        except Exception as e:
            # Boundary handler: report the failure to the client rather than
            # crashing the endpoint; logger.exception records the traceback.
            logger.exception("Inference error: %s", e)
            return {
                "error": str(e),
                "error_type": type(e).__name__
            }
| |
|