jblast94 committed on
Commit
ed9e433
·
verified ·
1 Parent(s): 24c936f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -97
app.py CHANGED
@@ -1,130 +1,179 @@
 
 
 
 
 
1
  import gradio as gr
 
 
2
  import torch
3
  import os
4
 
5
- # You must use the exact same model name as your repo
6
- MODEL_ID = "nineninesix/Kani-TTS-370m"
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # --- Global variable to store loaded models ---
9
- MODELS = {}
10
 
11
  @spaces.GPU
12
- def load_models():
13
- """Load models into GPU memory and store in a global variable."""
14
- global MODELS
15
- if not MODELS:
16
- print("Loading models into GPU memory...")
17
- from transformers import AutoModel, AutoConfig
18
-
19
- model_path = MODEL_ID
20
-
21
- # Load both the main model and its configuration
22
- model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
23
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
24
-
25
- # Store the loaded model and its configuration in the global variable
26
- MODELS = {
27
- "Kani TTS 370M": (model, config)
28
- }
29
-
30
- print(f"Models loaded. Available speakers: {list(config.speaker_id.keys()) if config.speaker_id else []}")
31
- return MODELS
32
-
33
- # --- Define a separate function for updating the stats display ---
34
- def update_stats_display():
35
- """This function gets the agent's stats and returns a formatted string for Gradio."""
36
- # This assumes 'agent' is a global instance of your ConversationalAgent class
37
- stats_text = agent.get_memory_stats()
38
- return gr.Markdown(f"### 📊 Memory Stats\n{stats_text}")
39
-
40
- def generate_speech(text: str, model_choice: str, speaker_display: str):
41
- """Generate speech using the selected model."""
42
  if not text.strip():
43
- return "Please enter text for speech generation.", None
 
 
 
44
 
45
  try:
46
  device = "cuda" if torch.cuda.is_available() else "cpu"
47
  print(f"Using device: {device}")
48
 
49
- # Ensure models are loaded
50
- if not MODELS:
51
- load_models()
52
-
53
- # Get the selected model from the global variable
54
- if model_choice not in MODELS:
55
- return f"Model '{model_choice}' not found.", None
56
-
57
- selected_model = MODELS[model_choice]
58
-
59
- # --- This is the key part to load a specific model ---
60
- model_to_generate = selected_model[0]
61
- cfg = selected_model[1] # Model config
62
  speaker_map = cfg.get('speaker_id', {}) if cfg is not None else {}
63
  if speaker_display and speaker_map:
64
  speaker_id = speaker_map.get(speaker_display)
65
  else:
66
  speaker_id = None
67
-
68
- print(f"Generating speech with {model_choice}...")
69
 
70
- # --- Use the specific part of the model for generation ---
71
- audio, _, time_report = model_to_generate.run_model(
72
- text=text,
73
- speaker_id=speaker_id,
74
- temperature=0.7,
75
- repetition_penalty=1.2,
76
- max_tokens=1024
77
- )
78
 
79
- sample_rate = 22050
80
  print("Speech generation completed!")
81
 
82
- return (sample_rate, audio), time_report
83
-
84
- # --- Create and configure the Gradio interface ---
85
- MODELS = load_models()
 
86
 
87
- with gr.Blocks(title="😻 KaniTTS - Text to Speech") as demo:
 
88
  gr.Markdown("# 😻 KaniTTS: Fast and Expressive Speech Generation Model")
 
89
 
90
- model_dropdown = gr.Dropdown(
91
- choices=list(MODELS.keys()),
92
- value=list(MODELS.keys())[0],
93
- label="Selected Model"
94
- )
95
-
96
- # --- Speaker selector (populated on model load) ---
97
- all_speakers = []
98
- if MODELS and list(MODELS.keys())[0] and MODELS[list(MODELS.keys())[0]][1]:
99
- all_speakers.extend(list(MODELS[list(MODELS.keys())[0]][1].speaker_id.keys()))
100
- all_speakers = sorted(list(set(all_speakers)))
101
- speaker_dropdown = gr.Dropdown(
102
- choices=all_speakers,
103
- value=None,
104
- label="Speaker",
105
- visible=True,
106
- allow_custom_value=True
107
- )
108
-
109
- text_input = gr.Textbox(label="Text", lines=5)
110
-
111
- generate_btn = gr.Button("Generate Speech", variant="primary")
112
-
113
- audio_output = gr.Audio(label="Generated Audio", type="numpy")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- # --- Define the event to update the speakers when the model changes ---
 
 
 
 
 
 
 
 
116
  model_dropdown.change(
117
- fn=lambda choice: gr.update(choices=list(MODELS[choice][1].speaker_id.keys()), value=None, visible=True) if MODELS and MODELS[choice][1].speaker_id else gr.update(visible=False),
118
  inputs=[model_dropdown],
119
  outputs=[speaker_dropdown]
120
  )
121
-
122
- # --- Wire up the main generation button ---
 
 
 
 
 
 
 
123
  generate_btn.click(
124
- fn=generate_speech,
125
- inputs=[text_input, model_dropdown, speaker_dropdown],
126
- outputs=[audio_output]
127
  )
128
 
129
- # --- This is the API-enabling line ---
130
- demo.queue().launch(show_api=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from create_env import setup_dependencies
2
+
3
+ setup_dependencies()
4
+
5
+ import spaces
6
  import gradio as gr
7
+ from util import NemoAudioPlayer, InitModels, load_config, Examples
8
+ import numpy as np
9
  import torch
10
  import os
11
 
12
+ # Get HuggingFace token
13
+ token_ = os.getenv('HF_TOKEN')
14
+
15
+ config = load_config("./model_config.yaml")
16
+ models_configs = config.models
17
+ nemo_player_cfg = config.nemo_player
18
+
19
+ examples_cfg = load_config("./examples.yaml")
20
+ examples_maker = Examples(examples_cfg)
21
+ examples = examples_maker()
22
+
23
+ player = NemoAudioPlayer(nemo_player_cfg)
24
+ init_models = InitModels(models_configs, player, token_)
25
+ models = init_models()
26
 
 
 
27
 
28
  @spaces.GPU
29
+ def generate_speech_gpu(text, model_choice, speaker_display: str, t, top_p, rp, max_tok):
30
+ """
31
+ Generate speech from text using the selected model on GPU
32
+ """
33
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  if not text.strip():
35
+ return None, "Please enter text for speech generation."
36
+
37
+ if not model_choice:
38
+ return None, "Please select a model."
39
 
40
  try:
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
  print(f"Using device: {device}")
43
 
44
+ selected_model = models[model_choice]
45
+ cfg = models_configs.get(model_choice)
 
 
 
 
 
 
 
 
 
 
 
46
  speaker_map = cfg.get('speaker_id', {}) if cfg is not None else {}
47
  if speaker_display and speaker_map:
48
  speaker_id = speaker_map.get(speaker_display)
49
  else:
50
  speaker_id = None
 
 
51
 
52
+ print(f"Generating speech with {model_choice}...")
53
+ audio, _, time_report = selected_model.run_model(text, speaker_id, t, top_p, rp, max_tok)
 
 
 
 
 
 
54
 
55
+ sample_rate = 22050
56
  print("Speech generation completed!")
57
 
58
+ return (sample_rate, audio), time_report #, f"✅ Audio generated successfully using {model_choice} on {device}"
59
+
60
+ except Exception as e:
61
+ print(f"Error during generation: {str(e)}")
62
+ return None, f"❌ Error during generation: {str(e)}"
63
 
64
+ # Create Gradio interface
65
+ with gr.Blocks(title="😻 KaniTTS - Text to Speech", theme=gr.themes.Ocean()) as demo:
66
  gr.Markdown("# 😻 KaniTTS: Fast and Expressive Speech Generation Model")
67
+ gr.Markdown("Select a model and enter text to generate emotional speech")
68
 
69
+ with gr.Row():
70
+ with gr.Column(scale=1):
71
+ model_dropdown = gr.Dropdown(
72
+ choices=list(models_configs.keys()),
73
+ value=list(models_configs.keys())[0],
74
+ label="Selected Model",
75
+ info="Base generates random voices"
76
+ )
77
+ # Speaker selector (shown only if model has speakers)
78
+ # Pre-populate all available speakers for example table rendering
79
+ all_speakers = []
80
+ for _cfg in models_configs.values():
81
+ if _cfg and _cfg.get('speaker_id'):
82
+ all_speakers.extend(list(_cfg.speaker_id.keys()))
83
+ all_speakers = sorted(list(set(all_speakers)))
84
+ speaker_dropdown = gr.Dropdown(
85
+ choices=all_speakers,
86
+ value=None,
87
+ label="Speaker",
88
+ visible=False,
89
+ allow_custom_value=True
90
+ )
91
+
92
+ text_input = gr.Textbox(
93
+ label="Text",
94
+ placeholder="Enter your text ...",
95
+ lines=3,
96
+ max_lines=10
97
+ )
98
+
99
+ with gr.Accordion("Settings", open=False):
100
+ temp = gr.Slider(
101
+ minimum=0.1, maximum=1.5, value=0.6, step=0.05,
102
+ label="Temp",
103
+ )
104
+ top_p = gr.Slider(
105
+ minimum=0.1, maximum=1.0, value=0.95, step=0.05,
106
+ label="Top P",
107
+ )
108
+ rp = gr.Slider(
109
+ minimum=1.0, maximum=2.0, value=1.1, step=0.05,
110
+ label="Repetition Penalty",
111
+ )
112
+ max_tok = gr.Slider(
113
+ minimum=100, maximum=2000, value=1000, step=100,
114
+ label="Max Tokens",
115
+ )
116
+
117
+ generate_btn = gr.Button("Run", variant="primary", size="lg")
118
+
119
+
120
+ with gr.Column(scale=1):
121
+ audio_output = gr.Audio(
122
+ label="Generated Audio",
123
+ type="numpy"
124
+ )
125
+
126
+ time_report_output = gr.Textbox(
127
+ label="Time Report",
128
+ interactive=False,
129
+ value="Ready to generate speech",
130
+ lines=3
131
+ )
132
 
133
+ # Update speakers when model changes
134
+ def update_speakers(model_choice):
135
+ cfg = models_configs.get(model_choice)
136
+ speakers = list(cfg.speaker_id.keys()) if (cfg and cfg.get('speaker_id')) else []
137
+ if speakers:
138
+ return gr.update(choices=speakers, value=speakers[0], visible=True)
139
+ else:
140
+ return gr.update(choices=[], value=None, visible=False)
141
+
142
  model_dropdown.change(
143
+ fn=update_speakers,
144
  inputs=[model_dropdown],
145
  outputs=[speaker_dropdown]
146
  )
147
+
148
+ # Populate speakers on initial page load based on default model
149
+ demo.load(
150
+ fn=update_speakers,
151
+ inputs=[model_dropdown],
152
+ outputs=[speaker_dropdown]
153
+ )
154
+
155
+ # GPU generation event
156
  generate_btn.click(
157
+ fn=generate_speech_gpu,
158
+ inputs=[text_input, model_dropdown, speaker_dropdown, temp, top_p, rp, max_tok],
159
+ outputs=[audio_output, time_report_output]
160
  )
161
 
162
+ with gr.Row():
163
+
164
+ examples = examples
165
+
166
+ gr.Examples(
167
+ examples=examples,
168
+ inputs=[text_input, model_dropdown, speaker_dropdown, temp, top_p, rp, max_tok],
169
+ fn=generate_speech_gpu,
170
+ outputs=[audio_output, time_report_output],
171
+ cache_examples=True,
172
+ )
173
+
174
+ if __name__ == "__main__":
175
+ demo.launch(
176
+ server_name="0.0.0.0",
177
+ server_port=7860,
178
+ show_error=True
179
+ )