|
import subprocess |
|
import sys |
|
|
|
|
|
# Install or upgrade Gradio using the pip of the interpreter that is running
# this script. This executes at import time, ahead of the `import gradio`
# statement further down; check_call raises CalledProcessError if pip exits
# with a non-zero status.
_gradio_install_cmd = [sys.executable, "-m", "pip", "install", "--upgrade", "gradio>=4.44.0"]
subprocess.check_call(_gradio_install_cmd)
|
|
|
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification |
|
import gradio as gr |
|
import numpy as np |
|
import scipy.io.wavfile |
|
import tempfile |
|
import os |
|
from transformers import VitsModel, AutoTokenizer |
|
import torch |
|
import re |
|
import traceback |
|
|
|
print("Starting application...")

# Module-level handles that load_models() assigns through its `global`
# statement. Each stays None until its model has loaded, and the functions
# below check for None before using them.
punct_pipe = model = tokenizer = None
|
|
|
def load_models():
    """Fill the module-level ``punct_pipe``, ``model`` and ``tokenizer``.

    The punctuation pipeline and the TTS model/tokenizer pair are loaded in
    two independent steps. If a step raises, the error is printed and the
    global(s) belonging to that step are set to ``None``; the exception is
    not propagated.
    """
    global punct_pipe, model, tokenizer

    print("Loading punctuation model...")
    try:
        punct_checkpoint = "oliverguhr/fullstop-punctuation-multilang-large"
        # Tokenizer first, then the classifier, then the pipeline around both.
        checkpoint_tokenizer = AutoTokenizer.from_pretrained(punct_checkpoint)
        checkpoint_model = AutoModelForTokenClassification.from_pretrained(punct_checkpoint)
        punct_pipe = pipeline(
            "token-classification",
            model=checkpoint_model,
            tokenizer=checkpoint_tokenizer,
            aggregation_strategy="simple",
        )
    except Exception as exc:
        print(f"✗ Error loading punctuation model: {exc}")
        punct_pipe = None
    else:
        print("✓ Punctuation model loaded successfully")

    print("Loading TTS model...")
    tts_checkpoint = "facebook/mms-tts-kmr-script_latin"
    try:
        model = VitsModel.from_pretrained(tts_checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(tts_checkpoint)
    except Exception as exc:
        print(f"✗ Error loading TTS model: {exc}")
        # A half-loaded pair is unusable, so both handles are cleared.
        model = tokenizer = None
    else:
        print("✓ TTS model loaded successfully")


# Load the models once, at import time, before the interface below is built.
load_models()
|
|
|
|
|
# Kurmanji words for the numbers 0 through 10, keyed by the decimal string
# ("0" ... "10"). The position of each word in the list is the number it names.
num2word = {
    str(value): word
    for value, word in enumerate(
        [
            "sifir",
            "yek",
            "du",
            "sê",
            "çar",
            "pênc",
            "şeş",
            "heft",
            "heşt",
            "neh",
            "deh",
        ]
    )
}


def replace_numbers_with_words(text):
    """Return *text* with standalone numbers 0-10 spelled out in Kurmanji.

    Every run of digits bounded by word boundaries is looked up in
    ``num2word`` by its exact text. Runs with no entry (for example ``"11"``
    or ``"05"``) are left unchanged, as are digits embedded in a word.
    """
    return re.sub(
        r'\b\d+\b',
        lambda found: num2word.get(found.group(), found.group()),
        text,
    )
|
|
|
# Punctuation mark to append after a word group, keyed by the label the
# token-classification pipeline reports in ``entity_group``. The literal
# labels (".", ",", "?", "-", ":") are the ones documented for the
# fullstop-punctuation model loaded in this file; "PERIOD" and "COMMA" are
# kept so that models using named labels behave as before. Any other label
# (including the "no punctuation" label "0") adds nothing.
_PUNCT_BY_LABEL = {
    "PERIOD": ".",
    "COMMA": ",",
    ".": ".",
    ",": ",",
    "?": "?",
    "-": "-",
    ":": ":",
}


def restore_punctuation(text):
    """Return *text* with punctuation predicted by ``punct_pipe`` inserted.

    Each group returned by the pipeline contributes its ``word`` followed by
    the punctuation mark mapped from its ``entity_group`` label (nothing for
    unknown or missing labels). Groups are joined with single spaces and the
    result is stripped.

    If ``punct_pipe`` is ``None``, or the pipeline or result handling raises,
    a message is printed and the original *text* is returned unchanged.
    """
    if punct_pipe is None:
        print("Punctuation model not available, skipping...")
        return text

    try:
        pieces = [
            token['word'] + _PUNCT_BY_LABEL.get(token.get('entity_group', ''), "")
            for token in punct_pipe(text)
        ]
    except Exception as e:
        # Best effort: punctuation is optional, so fall back to the raw text.
        print(f"Punctuation error: {e}")
        return text
    return " ".join(pieces).strip()
|
|
|
def text_to_speech(text):
    """Synthesize *text* into a WAV file and return the file's path.

    The text is stripped, passed through ``replace_numbers_with_words``,
    tokenized with the module-level ``tokenizer`` and fed to the module-level
    ``model`` under ``torch.no_grad()``. The squeezed waveform is written with
    ``scipy.io.wavfile.write`` at ``model.config.sampling_rate`` to a
    temporary ``.wav`` file created with ``delete=False``; this function does
    not remove that file.

    Returns ``None`` (after printing the reason) when *text* is empty or
    whitespace, when the model or tokenizer is not loaded, or when any step
    raises. Progress is printed at each stage.
    """
    print("=== TTS Function Called ===")
    print(f"Input text: '{text}'")

    try:
        # Validate the input first, then the model state; report the first
        # problem found.
        failure = None
        if not text or text.strip() == "":
            failure = "Please enter some text"
        elif model is None or tokenizer is None:
            failure = "TTS model not loaded properly"
        if failure is not None:
            print(f"Error: {failure}")
            return None

        print("Processing text...")
        spoken_text = replace_numbers_with_words(text.strip())
        print(f"Processed text: '{spoken_text}'")

        print("Tokenizing...")
        encoded = tokenizer(spoken_text, return_tensors="pt")
        print(f"Tokenized successfully, input_ids shape: {encoded['input_ids'].shape}")

        print("Generating audio...")
        with torch.no_grad():
            audio = model(**encoded).waveform
        print(f"Audio generated, shape: {audio.shape}")

        # Drop size-1 dimensions and convert to a NumPy array for the writer.
        samples = audio.squeeze().numpy()
        print(f"Waveform shape: {samples.shape}")

        print("Saving audio file...")
        # Only the name is needed: the handle is closed on leaving the
        # `with` block, and delete=False keeps the file on disk.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_file:
            wav_path = wav_file.name
        scipy.io.wavfile.write(wav_path, rate=model.config.sampling_rate, data=samples)

        print(f"✓ Audio saved to: {wav_path}")
        print("=== TTS Function Completed Successfully ===")
        return wav_path

    except Exception as exc:
        print(f"✗ Error in TTS: {str(exc)}")
        print("Full traceback:")
        traceback.print_exc()
        return None
|
|
|
print("Creating Gradio interface...")

# Build the pieces of the UI first: one text box in, one audio player out,
# plus sample sentences (each example is a one-element row for the single
# input).
_text_input = gr.Textbox(
    label="Nivîseke bi kurmancî binivîse",
    placeholder="Mînak: Silav! Ez baş im.",
)
_audio_output = gr.Audio(label="Deng")
_example_rows = [
    [sentence]
    for sentence in (
        "Silav! Ez baş im.",
        "Tu çawa yî?",
        "Ez ji Kurdistanê me.",
    )
]

# Wire text_to_speech into a single-function Gradio app.
interface = gr.Interface(
    fn=text_to_speech,
    inputs=_text_input,
    outputs=_audio_output,
    title="Bernameya Nivîs-bo-Deng ya bi kurmancî - Kurmanji Text-to-Speech",
    description="Nivîseke bi kurmancî binivîse ku bo deng bê veguherandin. / Write Kurmanji Kurdish text and listen to it.",
    submit_btn="Bişîne",
    clear_btn="Paqij bike",
    examples=_example_rows,
)

print("Launching interface...")

# The server is started only when the file is executed as a script.
if __name__ == "__main__":
    interface.launch()