File size: 3,950 Bytes
3ac41ff
 
 
 
 
 
7c1c946
3ac41ff
 
 
 
b1088ea
3ac41ff
 
 
 
 
1cfeb40
b1088ea
3ac41ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c1c946
3ac41ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e078f3
3ac41ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
from transformers import VitsModel, AutoTokenizer, VitsConfig
import soundfile as sf
import numpy as np
import gradio as gr
import os
from thaicleantext import clean_thai_text

def load_tts_model(pth_path, speed=1.0):
    """Build a VITS model and its tokenizer from a ``.pth`` checkpoint.

    The checkpoint is read onto the CPU and must be a mapping with a
    ``'config'`` entry (keyword arguments for ``VitsConfig``) and a
    ``'model_state'`` entry (the state dict for ``VitsModel``). The tokenizer
    is always fetched from the ``VIZINTZOR/tts-tha-vits`` repository,
    independent of the checkpoint.

    Args:
        pth_path: Path to the checkpoint file.
        speed: Value stored on the model as ``speaking_rate``.

    Returns:
        ``(model, tokenizer, None)`` on success, or
        ``(None, None, message)`` when any step raises.
    """
    try:
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the installed PyTorch version's ``weights_only`` default -- only
        # load checkpoints from trusted sources.
        cpu = torch.device('cpu')
        checkpoint = torch.load(pth_path, map_location=cpu)

        vits_config = VitsConfig(**checkpoint['config'])
        net = VitsModel(vits_config)
        net.load_state_dict(checkpoint['model_state'])
        net.eval()  # inference mode
        net.speaking_rate = speed

        tok = AutoTokenizer.from_pretrained("VIZINTZOR/tts-tha-vits")
        return net, tok, None
    except Exception as exc:
        # Callers expect an error string rather than an exception.
        message = f"Error loading model: {str(exc)}"
        return None, None, message

def generate_speech(model, tokenizer, text, speed, volume, output_file="output.wav"):
    """Synthesize *text* with *model* and write the audio to *output_file*.

    The waveform returned by the model is peak-normalized to [-1, 1], scaled
    by *volume*, and written with ``soundfile`` at the model's configured
    sampling rate.

    Args:
        model: Object exposing ``speaking_rate`` and ``config.sampling_rate``
            whose call result has a ``waveform`` tensor (a ``VitsModel`` in
            this file).
        tokenizer: Callable turning text into the model's keyword inputs.
        text: Text to synthesize; empty or whitespace-only text is rejected.
        speed: Value assigned to ``model.speaking_rate`` before synthesis.
        volume: Multiplier applied to the normalized waveform.
        output_file: Destination path of the audio file.

    Returns:
        ``(output_file, None)`` on success, or ``(None, message)`` when the
        input is empty or any step raises.
    """
    # Reject empty input up front instead of feeding nothing to the model.
    if not text or (isinstance(text, str) and not text.strip()):
        return None, "Error generating speech: input text is empty"
    try:
        model.speaking_rate = speed
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            waveform = model(**inputs).waveform
        waveform = waveform.squeeze().cpu().numpy()
        if waveform.size == 0:
            return None, "Error generating speech: model produced no audio"
        # Normalize to [-1, 1]; skip the division for an all-zero (silent)
        # waveform, which would otherwise turn every sample into NaN.
        peak = np.max(np.abs(waveform))
        if peak > 0:
            waveform = waveform / peak
        waveform = waveform * volume  # Apply volume adjustment
        sample_rate = model.config.sampling_rate
        sf.write(output_file, waveform, sample_rate)
        return output_file, None
    except Exception as e:
        return None, f"Error generating speech: {str(e)}"

def get_available_models(model_dir="./models"):
    """Return the paths of the ``.pth`` files found directly in *model_dir*.

    Args:
        model_dir: Directory to scan (not recursive).

    Returns:
        A list of ``os.path.join(model_dir, name)`` paths sorted by file
        name, or an empty list when *model_dir* is missing or is not a
        directory.
    """
    # isdir (rather than exists) also covers a path that points at a regular
    # file, which os.listdir would reject with NotADirectoryError.
    if not os.path.isdir(model_dir):
        return []
    # os.listdir returns entries in arbitrary order; sort so the dropdown and
    # its default selection are stable between runs.
    return [
        os.path.join(model_dir, name)
        for name in sorted(os.listdir(model_dir))
        if name.endswith('.pth')
    ]

def tts_interface(text, model_path, speed, volume):
    """Gradio callback: load the selected model and synthesize *text*.

    Args:
        text: Raw text from the input box; it is passed through
            ``clean_thai_text`` before synthesis.
        model_path: Path of the ``.pth`` checkpoint chosen in the dropdown
            (``None`` when no model is available).
        speed: Speaking rate forwarded to the model.
        volume: Output volume multiplier.

    Returns:
        ``(audio_path, status_message)`` on success, or
        ``(None, error_message)`` on failure.
    """
    # Validate cheap preconditions before the expensive model load so the
    # user gets a clear message instead of a low-level loader error.
    if not model_path:
        return None, "Error loading model: no model selected"
    if not text or not text.strip():
        return None, "Error generating speech: input text is empty"

    model, tokenizer, error = load_tts_model(model_path, speed)
    if model is None or tokenizer is None:
        return None, error

    output_file = "output.wav"
    text = clean_thai_text(text)
    audio_file, error = generate_speech(model, tokenizer, text, speed, volume, output_file)
    if audio_file:
        return audio_file, "Audio generated successfully!"
    return None, error

# Create Gradio interface
# Scan the models directory once so the dropdown's choices and its default
# value come from the same snapshot (the directory could change between
# repeated scans).
available_models = get_available_models()

with gr.Blocks(title="Text-to-Speech Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Text-to-Speech Generator")
    gr.Markdown("Enter text, select a model, adjust speed and volume, and generate audio!")
    
    with gr.Row():
        # Left column: text to synthesize and model selection.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter your text here...",
                lines=5
            )
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=available_models,
                # Preselect the first model when any exist.
                value=available_models[0] if available_models else None
            )
            
        # Right column: synthesis controls and the trigger button.
        with gr.Column(scale=1):
            speed_slider = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.05,
                label="Speaking Speed",
                info="1.0 is normal speed"
            )
            volume_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=1,
                step=0.05,
                label="Volume",
                info="Adjust output volume"
            )
            generate_btn = gr.Button("Generate Audio", variant="primary")
    
    with gr.Row():
        audio_output = gr.Audio(label="Generated Audio")
        status_output = gr.Textbox(label="Status", interactive=False)
    
    # Connect the button to the function; the input order matches the
    # parameters of tts_interface and the outputs match its returned pair.
    generate_btn.click(
        fn=tts_interface,
        inputs=[text_input, model_dropdown, speed_slider, volume_slider],
        outputs=[audio_output, status_output]
    )

# Launch the interface
demo.launch()