Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import os | |
| from TTS.api import TTS | |
| from huggingface_hub import hf_hub_download | |
| # --- ROMANIZER IMPORT --- | |
try:
    from romanizer import sinhala_to_roman
except ImportError:
    # Best-effort fallback: keep the app running without the optional
    # romanizer module, but say so at startup instead of degrading silently.
    print(
        "Warning: 'romanizer' module not found; "
        "Sinhala text will be passed to the model without romanization."
    )

    def sinhala_to_roman(text):
        """Return ``text`` unchanged.

        Identity stand-in used only when ``romanizer`` cannot be imported.
        """
        return text
| # --- CONSOLIDATED MODEL LOADING --- | |
def load_standard_model(repo_id):
    """Fetch a checkpoint and its config for ``repo_id`` and build a TTS model.

    Args:
        repo_id: Repository identifier passed to ``hf_hub_download``; the
            repository is expected to contain ``best_model.pth`` and
            ``config.json``.

    Returns:
        A ``TTS`` instance built from the two downloaded files with
        ``gpu=False``.
    """
    # The generator is consumed left to right, so the checkpoint is fetched
    # before the config.
    weights_file, config_file = (
        hf_hub_download(repo_id=repo_id, filename=name)
        for name in ("best_model.pth", "config.json")
    )
    return TTS(model_path=weights_file, config_path=config_file, gpu=False)
def load_eng_model_with_surgery():
    """Load the English model, trimming its text-embedding table if needed.

    Downloads ``best_model.pth`` and ``config.json`` from the English
    repository. If the ``text_encoder.emb.weight`` tensor in the checkpoint
    has 137 rows, only the first 131 rows are kept, the patched checkpoint
    is written to ``fixed_eng_model.pth`` in the working directory, and the
    model is built from that file. Otherwise the downloaded checkpoint is
    used as-is.

    Returns:
        A ``TTS`` instance constructed with ``gpu=False``.
    """
    repo = "E-motionAssistant/text-to-speech-VITS-english"
    weights_file = hf_hub_download(repo_id=repo, filename="best_model.pth")
    config_file = hf_hub_download(repo_id=repo, filename="config.json")

    # NOTE(review): torch.load unpickles the downloaded file; only use with
    # a trusted repository.
    state = torch.load(weights_file, map_location="cpu")
    embedding = state['model']['text_encoder.emb.weight']

    # Nothing to patch: use the checkpoint exactly as downloaded.
    if embedding.shape[0] != 137:
        return TTS(model_path=weights_file, config_path=config_file, gpu=False)

    # Presumably 131 is the vocabulary size the config expects - TODO confirm.
    state['model']['text_encoder.emb.weight'] = embedding[:131, :]
    patched_file = "fixed_eng_model.pth"
    torch.save(state, patched_file)
    return TTS(model_path=patched_file, config_path=config_file, gpu=False)
| # --- INITIALIZATION --- | |
print("Loading all models... this may take a moment.")
# Keyword arguments are evaluated left to right, so the models load in the
# order Sinhala, Tamil, English.
models = dict(
    sinhala=load_standard_model("E-motionAssistant/text-to-speech-VITS-sinhala"),
    tamil=load_standard_model("E-motionAssistant/text-to-speech-VITS-tamil"),
    english=load_eng_model_with_surgery(),
)
| # --- SPECIFIC ENDPOINT FUNCTIONS --- | |
def tts_english(text):
    """Synthesize ``text`` with the English model and return the WAV path.

    The audio is always written to ``english_out.wav`` in the working
    directory.
    NOTE(review): the path is fixed, so overlapping requests would write the
    same file - confirm how the server schedules concurrent calls.
    """
    wav_path = "english_out.wav"
    english_model = models["english"]
    english_model.tts_to_file(text=text, file_path=wav_path)
    return wav_path
def tts_sinhala(text):
    """Synthesize Sinhala ``text`` and return the WAV path.

    The input is first passed through ``sinhala_to_roman``; the result is
    fed to the Sinhala model, which writes ``sinhala_out.wav`` in the
    working directory.
    NOTE(review): the path is fixed, so overlapping requests would write the
    same file - confirm how the server schedules concurrent calls.
    """
    wav_path = "sinhala_out.wav"
    sinhala_model = models["sinhala"]
    sinhala_model.tts_to_file(text=sinhala_to_roman(text), file_path=wav_path)
    return wav_path
def tts_tamil(text):
    """Synthesize ``text`` with the Tamil model and return the WAV path.

    The audio is always written to ``tamil_out.wav`` in the working
    directory.
    NOTE(review): the path is fixed, so overlapping requests would write the
    same file - confirm how the server schedules concurrent calls.
    """
    wav_path = "tamil_out.wav"
    tamil_model = models["tamil"]
    tamil_model.tts_to_file(text=text, file_path=wav_path)
    return wav_path
| # --- GRADIO UI WITH TABS --- | |
with gr.Blocks(title="Multilingual TTS API") as demo:
    gr.Markdown("# Trilingual TTS System")
    gr.Markdown("Choose a tab below to use a specific language endpoint.")

    # Each tab wires one textbox -> one synthesis function -> one audio
    # player. api_name gives the handler a named API endpoint; the exact URL
    # path depends on the installed Gradio version.
    with gr.Tab("English"):
        input_eng = gr.Textbox(label="English Text")
        output_eng = gr.Audio(label="English Audio", type="filepath")
        btn_eng = gr.Button("Synthesize English")
        btn_eng.click(
            fn=tts_english,
            inputs=input_eng,
            outputs=output_eng,
            api_name="english_tts",
        )

    with gr.Tab("Sinhala"):
        input_sin = gr.Textbox(label="Sinhala Text (Input Unicode)")
        output_sin = gr.Audio(label="Sinhala Audio", type="filepath")
        btn_sin = gr.Button("Synthesize Sinhala")
        btn_sin.click(
            fn=tts_sinhala,
            inputs=input_sin,
            outputs=output_sin,
            api_name="sinhala_tts",
        )

    with gr.Tab("Tamil"):
        input_tam = gr.Textbox(label="Tamil Text")
        output_tam = gr.Audio(label="Tamil Audio", type="filepath")
        btn_tam = gr.Button("Synthesize Tamil")
        btn_tam.click(
            fn=tts_tamil,
            inputs=input_tam,
            outputs=output_tam,
            api_name="tamil_tts",
        )

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()