import io
import os
from typing import Optional, Literal, Dict, Any, List

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import torch
import nltk
from transformers import AutoTokenizer, AutoFeatureExtractor
from parler_tts import ParlerTTSForConditionalGeneration
# --- one-time setup ---
nltk.download("punkt_tab")

DEVICE = (
    "cuda:0" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
TORCH_DTYPE = torch.bfloat16 if DEVICE != "cpu" else torch.float32

# finetuned model only
FINETUNED_REPO_ID = "ai4bharat/indic-parler-tts"
model = ParlerTTSForConditionalGeneration.from_pretrained(
    FINETUNED_REPO_ID, attn_implementation="eager", torch_dtype=TORCH_DTYPE
).to(DEVICE)

# tokenizers / feature extractor
# NOTE: the base repo id provides the prompt tokenizer & feature extractor
BASE_REPO_FOR_TOK = "ai4bharat/indic-parler-tts-pretrained"
tokenizer = AutoTokenizer.from_pretrained(BASE_REPO_FOR_TOK)
description_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
feature_extractor = AutoFeatureExtractor.from_pretrained(BASE_REPO_FOR_TOK)
SAMPLE_RATE = feature_extractor.sampling_rate
# --- FastAPI app ---
app = FastAPI(title="Indic Parler-TTS (finetuned) API", version="1.0.0")

# Optional default voice descriptions per language
DEFAULT_DESCRIPTIONS: Dict[str, str] = {
    "english": (
        "A calm, neutral male voice speaks natural English at a moderate pace. "
        "Very clear audio with no background noise."
    ),
    "urdu": (
        "A warm, neutral female voice speaks natural Urdu at a moderate pace. "
        "Very clear audio with no background noise."
    ),
    "punjabi": (
        "A friendly, neutral male voice speaks natural Punjabi at a moderate pace. "
        "Very clear audio with no background noise."
    ),
}

def numpy_to_mp3(audio_array: np.ndarray, sampling_rate: int) -> bytes:
    """
    Converts a mono int16/float array to MP3 (320 kbps).
    Uses pydub/ffmpeg; falls back to WAV if pydub is not available.
    """
    try:
        from pydub import AudioSegment

        # normalize float → int16
        if np.issubdtype(audio_array.dtype, np.floating):
            max_val = np.max(np.abs(audio_array)) or 1.0
            audio_array = (audio_array / max_val) * 32767
            audio_array = audio_array.astype(np.int16)
        seg = AudioSegment(
            audio_array.tobytes(),
            frame_rate=sampling_rate,
            sample_width=audio_array.dtype.itemsize,
            channels=1,
        )
        buf = io.BytesIO()
        seg.export(buf, format="mp3", bitrate="320k")
        out = buf.getvalue()
        buf.close()
        return out
    except Exception:
        # fallback: WAV keeps things working even without ffmpeg
        import soundfile as sf

        buf = io.BytesIO()
        sf.write(buf, audio_array, sampling_rate, format="WAV", subtype="PCM_16")
        return buf.getvalue()
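
# A minimal sketch of how numpy_to_mp3 behaves, using a synthetic one-second
# 440 Hz sine tone (invented purely for illustration). It is defined as a
# helper so nothing runs at import time; call it manually to check whether
# pydub/ffmpeg are wired up — the WAV fallback starts with b"RIFF".
def _demo_numpy_to_mp3() -> None:
    t = np.linspace(0, 1.0, SAMPLE_RATE, endpoint=False)
    tone = 0.5 * np.sin(2 * np.pi * 440.0 * t)  # float mono signal in [-0.5, 0.5]
    data = numpy_to_mp3(tone, SAMPLE_RATE)
    kind = "wav-fallback" if data[:4] == b"RIFF" else "mp3"
    print(f"{kind}: {len(data)} bytes")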

def split_text_into_chunks(text: str, max_words: int = 25) -> List[str]:
    """Greedily pack whole sentences into chunks of roughly max_words words."""
    sentences = nltk.sent_tokenize(text)
    curr = ""
    chunks: List[str] = []
    for s in sentences:
        candidate = (curr + " " + s).strip() if curr else s
        if len(candidate.split()) >= max_words and curr:
            # adding this sentence would overflow: flush the current chunk
            chunks.append(curr.strip())
            curr = s
        else:
            curr = candidate
    if curr.strip():
        chunks.append(curr.strip())
    return chunks
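
# Quick illustrative check of the chunker (the example sentences are made up).
# With max_words=6, the two short sentences merge into one chunk and the long
# sentence starts a new one:
#
#   split_text_into_chunks(
#       "Hi there. How are you? This sentence is long enough to overflow.",
#       max_words=6,
#   )
#   # -> ['Hi there. How are you?',
#   #     'This sentence is long enough to overflow.']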

def synthesize(text: str, description: str) -> np.ndarray:
    # the voice description is tokenized once and reused for every text chunk
    inputs = description_tokenizer(description, return_tensors="pt").to(DEVICE)
    chunks = split_text_into_chunks(text, max_words=25)
    all_audio = []
    for chunk in chunks:
        prompt = tokenizer(chunk, return_tensors="pt").to(DEVICE)
        generation = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            prompt_input_ids=prompt.input_ids,
            prompt_attention_mask=prompt.attention_mask,
            do_sample=True,
            return_dict_in_generate=True,
        )
        if hasattr(generation, "sequences") and hasattr(generation, "audios_length"):
            # trim padding: audios_length holds the valid length per sample
            audio = generation.sequences[0, : generation.audios_length[0]]
            audio_np = audio.to(torch.float32).cpu().numpy().squeeze()
            if audio_np.ndim > 1:
                audio_np = audio_np.flatten()
            all_audio.append(audio_np)
    if not all_audio:
        raise RuntimeError("TTS generation produced no audio.")
    return np.concatenate(all_audio)
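
# A quick manual smoke test for synthesize, kept commented out so the model
# never runs at import time (the sample sentence is invented):
#
#   audio = synthesize("Hello from the demo.", DEFAULT_DESCRIPTIONS["english"])
#   print(audio.shape[0] / SAMPLE_RATE, "seconds of audio")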

# ---- API schemas ----
class TTSRequest(BaseModel):
    text: str
    language: Optional[Literal["english", "urdu", "punjabi"]] = None
    voice_description: Optional[str] = None
    # "mp3" (default) or "wav" (force WAV fallback)
    format: Optional[Literal["mp3", "wav"]] = "mp3"


@app.get("/health")
def health() -> Dict[str, Any]:
    return {"status": "ok", "device": DEVICE, "sample_rate": SAMPLE_RATE}

@app.post("/tts")
def tts(body: TTSRequest):
    if not body.text or not body.text.strip():
        raise HTTPException(status_code=400, detail="`text` is required.")
    # choose description: explicit override > per-language default > generic
    description = (
        body.voice_description
        or DEFAULT_DESCRIPTIONS.get((body.language or "").lower())
        or "The speaker speaks naturally with a neutral tone. "
           "The recording is very high quality with no background noise."
    )
    try:
        audio = synthesize(body.text, description)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"generation_error: {e}")
    # return bytes stream
    if body.format == "wav":
        import soundfile as sf

        buf = io.BytesIO()
        sf.write(buf, audio, SAMPLE_RATE, format="WAV", subtype="PCM_16")
        buf.seek(0)
        return StreamingResponse(buf, media_type="audio/wav")
    # default: mp3 (the helper falls back to WAV internally if mp3 fails)
    mp3_bytes = numpy_to_mp3(audio, SAMPLE_RATE)
    # crude detection of whether the fallback produced WAV (RIFF magic bytes)
    if mp3_bytes[:4] == b"RIFF":
        return StreamingResponse(io.BytesIO(mp3_bytes), media_type="audio/wav")
    return StreamingResponse(io.BytesIO(mp3_bytes), media_type="audio/mpeg")

# uvicorn entrypoint (Spaces sets PORT)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))
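
# Assumed dependencies (all imported above; pin versions to taste):
#   fastapi, uvicorn, pydantic, torch, transformers, nltk, numpy,
#   soundfile, pydub (plus ffmpeg on PATH for MP3 output), and
#   parler_tts, typically installed from the huggingface/parler-tts
#   GitHub repository.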