Remsky commited on
Commit
165abce
·
1 Parent(s): 80023bb

Add audio and text utility modules, update requirements, and revise README

Browse files
Files changed (8) hide show
  1. README.md +26 -27
  2. app.py +114 -30
  3. lib/__init__.py +34 -0
  4. lib/audio_utils.py +23 -0
  5. lib/file_utils.py +101 -0
  6. lib/text_utils.py +56 -0
  7. requirements.txt +1 -1
  8. tts_model.py +213 -210
README.md CHANGED
@@ -1,47 +1,46 @@
1
  ---
2
- title: Kokoro TTS Zero
3
- emoji: 🎴
4
- colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.9.1
8
  app_file: app.py
9
- pinned: true
10
- license: apache-2.0
11
- short_description: A100 GPU Accelerated Inference applied to Kokoro-82M TTS
12
- models:
13
- - hexgrad/Kokoro-82M
14
  ---
15
 
16
  # Kokoro TTS Demo Space
17
 
18
  A Zero GPU-optimized Hugging Face Space for the Kokoro TTS model.
19
-
20
  ## Overview
21
 
22
  This Space provides a Gradio interface for the Kokoro TTS model, allowing users to:
23
  - Convert text to speech using multiple voices
24
  - Adjust speech speed
25
- - Get instant audio playback
26
-
27
- ## Technical Details
28
-
29
- - Zero GPU for efficient GPU resource management
30
- - Dynamically loads required modules from hexgrad/Kokoro-82M repository
31
-
32
- All dependencies are automatically handled:
33
- - Core modules (kokoro.py, models.py, etc.) are downloaded from hexgrad/Kokoro-82M
34
- - Model weights and voice files are cached in /data/.huggingface
35
- - System dependencies (espeak-ng) are installed via packages.txt
36
-
37
- ## Environment
38
-
39
- - Python 3.10.13
 
 
40
  - PyTorch 2.2.2
41
  - Gradio 5.9.1
42
- - A100 Zero GPU Enabled
 
 
 
43
 
44
 
45
 
46
- ## Notes
47
- - Model Warm-Up takes some time, it shines at longer lengths.
 
1
  ---
2
+ title: Kokoro TTS Demo
3
+ emoji: 🎙️
4
+ colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.9.1
8
  app_file: app.py
9
+ pinned: false
10
+ license: mit
 
 
 
11
  ---
12
 
13
  # Kokoro TTS Demo Space
14
 
15
  A Zero GPU-optimized Hugging Face Space for the Kokoro TTS model.
 
16
  ## Overview
17
 
18
  This Space provides a Gradio interface for the Kokoro TTS model, allowing users to:
19
  - Convert text to speech using multiple voices
20
  - Adjust speech speed
21
+ ## Project Structure
22
+
23
+ ```
24
+ .
25
+ ├── app.py # Main Gradio interface
26
+ ├── tts_model.py # GPU-accelerated TTS model manager
27
+ ├── lib/ # Utility modules
28
+ │ ├── __init__.py # Package exports
29
+ │ ├── text_utils.py # Text processing utilities
30
+ │ ├── file_utils.py # File operations
31
+ │ └── audio_utils.py # Audio processing
32
+ └── requirements.txt # Project dependencies
33
+ ```
34
+
35
+ ## Dependencies
36
+
37
+ Main dependencies:
38
  - PyTorch 2.2.2
39
  - Gradio 5.9.1
40
+ - Transformers 4.47.1
41
+ - HuggingFace Hub ≥0.25.1
42
+
43
+ For a complete list, see requirements.txt.
44
 
45
 
46
 
 
 
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import os
2
  import gradio as gr
3
  import spaces
 
4
  from tts_model import TTSModel
5
- import numpy as np
6
 
7
  # Set HF_HOME for faster restarts with cached models/voices
8
  os.environ["HF_HOME"] = "/data/.huggingface"
@@ -22,81 +23,164 @@ def initialize_model():
22
  voice_list = initialize_model()
23
 
24
  @spaces.GPU(duration=120) # Allow 5 minutes for processing
25
- def generate_speech_from_ui(text, voice_name, speed):
26
  """Handle text-to-speech generation from the Gradio UI"""
27
  try:
28
- audio_array, duration = model.generate_speech(text, voice_name, speed)
29
- # Convert float array to int16 range (-32768 to 32767)
30
- audio_array = np.array(audio_array, dtype=np.float32)
31
- audio_array = (audio_array * 32767).astype(np.int16)
32
- return (24000, audio_array), f"Audio Duration: {duration:.2f} seconds\nProcessing complete - check console for detailed metrics"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  except Exception as e:
34
- raise gr.Error(str(e))
35
 
36
  # Create Gradio interface
37
  with gr.Blocks(title="Kokoro TTS Demo") as demo:
38
  gr.HTML(
39
  """
 
 
 
 
 
 
40
  <div style="text-align: center; max-width: 800px; margin: 0 auto;">
41
  <h1>Kokoro TTS Demo</h1>
42
  <p>Convert text to natural-sounding speech using various voices.</p>
43
  </div>
 
44
  """
45
  )
46
 
47
  with gr.Row():
48
- with gr.Column(scale=3):
49
- # Input components
50
  text_input = gr.TextArea(
51
  label="Text to speak",
52
- placeholder="Enter text here...",
53
- lines=3,
54
  value=open("the_time_machine_hgwells.txt").read()[:1000]
55
  )
56
- voice_dropdown = gr.Dropdown(
57
- label="Voice",
58
- choices=voice_list,
59
- value=voice_list[0] if voice_list else None,
60
- allow_custom_value=True # Allow custom values to avoid warnings
 
 
61
  )
62
- speed_slider = gr.Slider(
63
- label="Speed",
64
- minimum=0.5,
65
- maximum=2.0,
66
- value=1.0,
67
- step=0.1
 
 
 
 
 
 
 
68
  )
69
- submit_btn = gr.Button("Generate Speech")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- with gr.Column(scale=2):
72
- # Output components
73
  audio_output = gr.Audio(
74
  label="Generated Speech",
75
  type="numpy",
76
  format="wav",
77
  autoplay=False
78
  )
 
 
 
 
 
 
79
  duration_text = gr.Textbox(
80
  label="Processing Info",
81
  interactive=False,
82
- lines=4
83
  )
84
 
85
  # Set up event handler
86
  submit_btn.click(
87
  fn=generate_speech_from_ui,
88
  inputs=[text_input, voice_dropdown, speed_slider],
89
- outputs=[audio_output, duration_text]
 
90
  )
91
-
92
 
93
  # Add text analysis info
94
  with gr.Row():
95
  with gr.Column():
96
  gr.Markdown("""
97
  ### Demo Text Info
98
- The preloaded text is from H.G. Wells' "The Time Machine" (Public Domain)
99
  """)
 
100
 
101
  # Launch the app
102
  if __name__ == "__main__":
 
1
  import os
2
  import gradio as gr
3
  import spaces
4
+ import time
5
  from tts_model import TTSModel
6
+ from lib import format_audio_output
7
 
8
  # Set HF_HOME for faster restarts with cached models/voices
9
  os.environ["HF_HOME"] = "/data/.huggingface"
 
23
  voice_list = initialize_model()
24
 
25
  @spaces.GPU(duration=120) # Allow 5 minutes for processing
26
+ def generate_speech_from_ui(text, voice_name, speed, progress=gr.Progress(track_tqdm=False)):
27
  """Handle text-to-speech generation from the Gradio UI"""
28
  try:
29
+ start_time = time.time()
30
+ gpu_timeout = 120 # seconds
31
+
32
+ # Create progress state
33
+ progress_state = {
34
+ "progress": 0.0,
35
+ "tokens_per_sec": 0.0,
36
+ "gpu_time_left": gpu_timeout
37
+ }
38
+
39
+ def update_progress(chunk_num, total_chunks, tokens_per_sec, rtf):
40
+ progress_state["progress"] = chunk_num / total_chunks
41
+ progress_state["tokens_per_sec"] = tokens_per_sec
42
+
43
+ # Update GPU time remaining
44
+ elapsed = time.time() - start_time
45
+ gpu_time_left = max(0, gpu_timeout - elapsed)
46
+ progress_state["gpu_time_left"] = gpu_time_left
47
+
48
+ # Only update progress display during processing
49
+ progress(progress_state["progress"], desc=f"Processing chunk {chunk_num}/{total_chunks} | GPU Time Left: {int(gpu_time_left)}s")
50
+
51
+ # Generate speech with progress tracking
52
+ audio_array, duration = model.generate_speech(
53
+ text,
54
+ voice_name,
55
+ speed,
56
+ progress_callback=update_progress
57
+ )
58
+
59
+ # Format output for Gradio
60
+ audio_output, duration_text = format_audio_output(audio_array)
61
+
62
+ # Calculate final metrics
63
+ total_time = time.time() - start_time
64
+ total_duration = len(audio_array) / 24000 # audio duration in seconds
65
+ final_rtf = total_time / total_duration if total_duration > 0 else 0
66
+
67
+ # Prepare final metrics display
68
+ metrics_text = (
69
+ f"Tokens/sec: {progress_state['tokens_per_sec']:.1f}\n" +
70
+ f"Real-time factor: {final_rtf:.2f}x (Processing Time / Audio Duration)\n" +
71
+ f"GPU Time Used: {int(total_time)}s of {gpu_timeout}s"
72
+ )
73
+
74
+ return (
75
+ audio_output,
76
+ metrics_text,
77
+ duration_text
78
+ )
79
  except Exception as e:
80
+ raise gr.Error(f"Generation failed: {str(e)}")
81
 
82
  # Create Gradio interface
83
  with gr.Blocks(title="Kokoro TTS Demo") as demo:
84
  gr.HTML(
85
  """
86
+ <div style="display: flex; justify-content: flex-end; padding: 10px; gap: 10px;">
87
+ <a href="https://huggingface.co/hexgrad/Kokoro-82M" target="_blank">
88
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md-dark.svg" alt="Model on HF">
89
+ </a>
90
+ <a class="github-button" href="https://github.com/remsky/Kokoro-FastAPI" data-color-scheme="no-preference: light; light: light; dark: dark;" data-size="large" data-show-count="true" aria-label="Star remsky/Kokoro-FastAPI on GitHub">Repo for Local Use</a>
91
+ </div>
92
  <div style="text-align: center; max-width: 800px; margin: 0 auto;">
93
  <h1>Kokoro TTS Demo</h1>
94
  <p>Convert text to natural-sounding speech using various voices.</p>
95
  </div>
96
+ <script async defer src="https://buttons.github.io/buttons.js"></script>
97
  """
98
  )
99
 
100
  with gr.Row():
101
+ # Column 1: Text Input
102
+ with gr.Column():
103
  text_input = gr.TextArea(
104
  label="Text to speak",
105
+ placeholder="Enter text here or upload a .txt file",
106
+ lines=10,
107
  value=open("the_time_machine_hgwells.txt").read()[:1000]
108
  )
109
+
110
+ # Column 2: Controls
111
+ with gr.Column():
112
+ file_input = gr.File(
113
+ label="Upload .txt file",
114
+ file_types=[".txt"],
115
+ type="binary"
116
  )
117
+
118
+ def load_text_from_file(file_bytes):
119
+ if file_bytes is None:
120
+ return None
121
+ try:
122
+ return file_bytes.decode('utf-8')
123
+ except Exception as e:
124
+ raise gr.Error(f"Failed to read file: {str(e)}")
125
+
126
+ file_input.change(
127
+ fn=load_text_from_file,
128
+ inputs=[file_input],
129
+ outputs=[text_input]
130
  )
131
+
132
+ with gr.Group():
133
+ voice_dropdown = gr.Dropdown(
134
+ label="Voice",
135
+ choices=voice_list,
136
+ value=voice_list[0] if voice_list else None,
137
+ allow_custom_value=True
138
+ )
139
+ speed_slider = gr.Slider(
140
+ label="Speed",
141
+ minimum=0.5,
142
+ maximum=2.0,
143
+ value=1.0,
144
+ step=0.1
145
+ )
146
+ submit_btn = gr.Button("Generate Speech", variant="primary")
147
 
148
+ # Column 3: Output
149
+ with gr.Column():
150
  audio_output = gr.Audio(
151
  label="Generated Speech",
152
  type="numpy",
153
  format="wav",
154
  autoplay=False
155
  )
156
+ progress_bar = gr.Progress(track_tqdm=False)
157
+ metrics_text = gr.Textbox(
158
+ label="Processing Metrics",
159
+ interactive=False,
160
+ lines=3
161
+ )
162
  duration_text = gr.Textbox(
163
  label="Processing Info",
164
  interactive=False,
165
+ lines=2
166
  )
167
 
168
  # Set up event handler
169
  submit_btn.click(
170
  fn=generate_speech_from_ui,
171
  inputs=[text_input, voice_dropdown, speed_slider],
172
+ outputs=[audio_output, metrics_text, duration_text],
173
+ show_progress=True
174
  )
 
175
 
176
  # Add text analysis info
177
  with gr.Row():
178
  with gr.Column():
179
  gr.Markdown("""
180
  ### Demo Text Info
181
+ The demo text is loaded from H.G. Wells' "The Time Machine". This classic text demonstrates the system's ability to handle long-form content through chunking.
182
  """)
183
+
184
 
185
  # Launch the app
186
  if __name__ == "__main__":
lib/__init__.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .text_utils import normalize_text, chunk_text, count_tokens
2
+ from .file_utils import (
3
+ load_module_from_file,
4
+ download_model_files,
5
+ list_voice_files,
6
+ download_voice_files,
7
+ ensure_dir
8
+ )
9
+ from .audio_utils import (
10
+ convert_float_to_int16,
11
+ get_audio_duration,
12
+ format_audio_output,
13
+ concatenate_audio_chunks
14
+ )
15
+
16
+ __all__ = [
17
+ # Text utilities
18
+ 'normalize_text',
19
+ 'chunk_text',
20
+ 'count_tokens',
21
+
22
+ # File utilities
23
+ 'load_module_from_file',
24
+ 'download_model_files',
25
+ 'list_voice_files',
26
+ 'download_voice_files',
27
+ 'ensure_dir',
28
+
29
+ # Audio utilities
30
+ 'convert_float_to_int16',
31
+ 'get_audio_duration',
32
+ 'format_audio_output',
33
+ 'concatenate_audio_chunks'
34
+ ]
lib/audio_utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import Tuple
3
+
4
+ def convert_float_to_int16(audio_array: np.ndarray) -> np.ndarray:
5
+ """Convert float audio array to int16 format"""
6
+ # Convert to float32 first to ensure proper scaling
7
+ audio_array = np.array(audio_array, dtype=np.float32)
8
+ # Scale to int16 range (-32768 to 32767)
9
+ return (audio_array * 32767).astype(np.int16)
10
+
11
+ def get_audio_duration(audio_array: np.ndarray, sample_rate: int = 24000) -> float:
12
+ """Calculate duration of audio in seconds"""
13
+ return len(audio_array) / sample_rate
14
+
15
+ def format_audio_output(audio_array: np.ndarray, sample_rate: int = 24000) -> Tuple[Tuple[int, np.ndarray], str]:
16
+ """Format audio array for Gradio output with duration info"""
17
+ audio_array = convert_float_to_int16(audio_array)
18
+ duration = get_audio_duration(audio_array, sample_rate)
19
+ return (sample_rate, audio_array), f"Audio Duration: {duration:.2f} seconds"
20
+
21
+ def concatenate_audio_chunks(chunks: list[np.ndarray]) -> np.ndarray:
22
+ """Concatenate multiple audio chunks into a single array"""
23
+ return np.concatenate(chunks)
lib/file_utils.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import importlib.util
3
+ import sys
4
+ from huggingface_hub import hf_hub_download
5
+ from typing import List, Optional
6
+
7
+ def load_module_from_file(module_name: str, file_path: str):
8
+ """Load a Python module from file path"""
9
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
10
+ if spec is None or spec.loader is None:
11
+ raise ImportError(f"Cannot load module {module_name} from {file_path}")
12
+ module = importlib.util.module_from_spec(spec)
13
+ sys.modules[module_name] = module
14
+ spec.loader.exec_module(module)
15
+ return module
16
+
17
+ def download_model_files(repo_id: str, filenames: List[str], local_dir: Optional[str] = None) -> List[str]:
18
+ """Download multiple files from Hugging Face Hub"""
19
+ paths = []
20
+ for filename in filenames:
21
+ try:
22
+ path = hf_hub_download(
23
+ repo_id=repo_id,
24
+ filename=filename,
25
+ local_dir=local_dir,
26
+ local_dir_use_symlinks=False
27
+ )
28
+ paths.append(path)
29
+ except Exception as e:
30
+ print(f"Error downloading {filename}: {str(e)}")
31
+ raise
32
+ return paths
33
+
34
+ def ensure_dir(path: str) -> None:
35
+ """Ensure directory exists, create if it doesn't"""
36
+ os.makedirs(path, exist_ok=True)
37
+
38
+ def list_voice_files(voices_dir: str) -> List[str]:
39
+ """List available voice files in directory"""
40
+ voices = []
41
+ try:
42
+ if not os.path.exists(voices_dir):
43
+ print(f"Voices directory does not exist: {voices_dir}")
44
+ return voices
45
+
46
+ files = os.listdir(voices_dir)
47
+ print(f"Found {len(files)} files in voices directory")
48
+
49
+ for file in files:
50
+ if file.endswith(".pt"):
51
+ voice_name = file[:-3] # Remove .pt extension
52
+ print(f"Found voice: {voice_name}")
53
+ voices.append(voice_name)
54
+
55
+ if not voices:
56
+ print("No voice files found in voices directory")
57
+
58
+ except Exception as e:
59
+ print(f"Error listing voices: {str(e)}")
60
+ import traceback
61
+ traceback.print_exc()
62
+
63
+ return sorted(voices)
64
+
65
+ def download_voice_files(repo_id: str, voices: List[str], voices_dir: str) -> None:
66
+ """Download voice files from Hugging Face Hub"""
67
+ ensure_dir(voices_dir)
68
+
69
+ for voice in voices:
70
+ try:
71
+ voice_path = os.path.join(voices_dir, voice)
72
+ print(f"Attempting to download voice {voice} to {voice_path}")
73
+
74
+ try:
75
+ downloaded_path = hf_hub_download(
76
+ repo_id=repo_id,
77
+ filename=f"voices/{voice}",
78
+ local_dir=voices_dir,
79
+ local_dir_use_symlinks=False,
80
+ force_filename=voice
81
+ )
82
+ print(f"Download completed to: {downloaded_path}")
83
+
84
+ if not os.path.exists(voice_path):
85
+ print(f"Warning: File not found at expected path {voice_path}")
86
+ print(f"Checking download location: {downloaded_path}")
87
+ if os.path.exists(downloaded_path):
88
+ print(f"Moving file from {downloaded_path} to {voice_path}")
89
+ os.rename(downloaded_path, voice_path)
90
+ else:
91
+ print(f"Verified voice file exists: {voice_path}")
92
+
93
+ except Exception as e:
94
+ print(f"Error downloading voice {voice}: {str(e)}")
95
+ import traceback
96
+ traceback.print_exc()
97
+
98
+ except Exception as e:
99
+ print(f"Error downloading voice {voice}: {str(e)}")
100
+ import traceback
101
+ traceback.print_exc()
lib/text_utils.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tiktoken
2
+
3
+ def normalize_text(text: str) -> str:
4
+ """Normalize text for TTS processing"""
5
+ if not text:
6
+ return ""
7
+ # Basic normalization - can be expanded based on needs
8
+ return text.strip()
9
+
10
+ def chunk_text(text: str, max_chars: int = 300) -> list[str]:
11
+ """Break text into chunks at natural boundaries"""
12
+ chunks = []
13
+ current_chunk = ""
14
+
15
+ # Split on sentence boundaries first
16
+ sentences = text.replace(".", ".|").replace("!", "!|").replace("?", "?|").replace(";", ";|").split("|")
17
+
18
+ for sentence in sentences:
19
+ if not sentence.strip():
20
+ continue
21
+
22
+ # If sentence is already too long, break on commas
23
+ if len(sentence) > max_chars:
24
+ parts = sentence.split(",")
25
+ for part in parts:
26
+ if len(current_chunk) + len(part) <= max_chars:
27
+ current_chunk += part + ","
28
+ else:
29
+ # If part is still too long, break on whitespace
30
+ if len(part) > max_chars:
31
+ words = part.split()
32
+ for word in words:
33
+ if len(current_chunk) + len(word) > max_chars:
34
+ chunks.append(current_chunk.strip())
35
+ current_chunk = word + " "
36
+ else:
37
+ current_chunk += word + " "
38
+ else:
39
+ chunks.append(current_chunk.strip())
40
+ current_chunk = part + ","
41
+ else:
42
+ if len(current_chunk) + len(sentence) <= max_chars:
43
+ current_chunk += sentence
44
+ else:
45
+ chunks.append(current_chunk.strip())
46
+ current_chunk = sentence
47
+
48
+ if current_chunk:
49
+ chunks.append(current_chunk.strip())
50
+
51
+ return chunks
52
+
53
+ def count_tokens(text: str) -> int:
54
+ """Count tokens in text using tiktoken"""
55
+ enc = tiktoken.get_encoding("cl100k_base")
56
+ return len(enc.encode(text))
requirements.txt CHANGED
@@ -9,4 +9,4 @@ regex==2024.11.6
9
  tiktoken==0.8.0
10
  transformers==4.47.1
11
  munch==4.0.0
12
-
 
9
  tiktoken==0.8.0
10
  transformers==4.47.1
11
  munch==4.0.0
12
+ matplotlib==3.4.3
tts_model.py CHANGED
@@ -1,122 +1,61 @@
1
  import os
2
- import io
3
- import spaces
4
  import torch
5
  import numpy as np
6
  import time
7
- import tiktoken
8
- import scipy.io.wavfile as wavfile
9
- from huggingface_hub import hf_hub_download
10
- import importlib.util
11
- import sys
12
-
13
- def load_module_from_file(module_name, file_path):
14
- """Load a Python module from file path"""
15
- spec = importlib.util.spec_from_file_location(module_name, file_path)
16
- if spec is None or spec.loader is None:
17
- raise ImportError(f"Cannot load module {module_name} from {file_path}")
18
- module = importlib.util.module_from_spec(spec)
19
- sys.modules[module_name] = module
20
- spec.loader.exec_module(module)
21
- return module
22
-
23
- # Download and load required Python modules
24
- py_modules = ["istftnet", "plbert", "models"]
25
- for py_module in py_modules:
26
- path = hf_hub_download(repo_id="hexgrad/Kokoro-82M", filename=f"{py_module}.py")
27
- load_module_from_file(py_module, path)
28
-
29
- # Load the kokoro module
30
- kokoro_path = hf_hub_download(repo_id="hexgrad/Kokoro-82M", filename="kokoro.py")
31
- kokoro = load_module_from_file("kokoro", kokoro_path)
32
-
33
- # Import required functions
34
- generate = kokoro.generate
35
- normalize_text = kokoro.normalize_text
36
- models = sys.modules['models']
37
- build_model = models.build_model
38
-
39
- # Set HF_HOME for faster restarts
40
- os.environ["HF_HOME"] = "/data/.huggingface"
41
 
42
  class TTSModel:
43
- """Self-contained TTS model manager for Hugging Face Spaces"""
44
 
45
  def __init__(self):
46
  self.model = None
47
  self.voices_dir = "voices"
48
  self.model_repo = "hexgrad/Kokoro-82M"
49
- os.makedirs(self.voices_dir, exist_ok=True)
50
 
51
- def initialize(self):
 
 
 
 
 
 
 
 
 
 
 
 
52
  """Initialize model and download voices"""
53
  try:
54
  print("Initializing model...")
55
 
56
- # Download model and config
57
- model_path = hf_hub_download(
58
- repo_id=self.model_repo,
59
- filename="kokoro-v0_19.pth"
60
- )
61
- config_path = hf_hub_download(
62
- repo_id=self.model_repo,
63
- filename="config.json"
64
  )
 
65
 
66
- # Build model directly on GPU if available
67
  with torch.cuda.device(0):
68
  torch.cuda.set_device(0)
69
- self.model = build_model(model_path, 'cuda')
70
  self._model_on_gpu = True
71
 
72
- # Download all available voices
73
- voices = [
74
- "af_bella.pt", "af_nicole.pt", "af_sarah.pt", "af_sky.pt", "af.pt",
75
- "am_adam.pt", "am_michael.pt",
76
- "bf_emma.pt", "bf_isabella.pt",
77
- "bm_george.pt", "bm_lewis.pt"
78
- ]
79
- for voice in voices:
80
- try:
81
- # Download voice file
82
- # Create full destination path
83
- voice_path = os.path.join(self.voices_dir, voice)
84
- print(f"Attempting to download voice {voice} to {voice_path}")
85
-
86
- # Ensure directory exists
87
- os.makedirs(self.voices_dir, exist_ok=True)
88
-
89
- # Download with explicit destination
90
- try:
91
- downloaded_path = hf_hub_download(
92
- repo_id=self.model_repo,
93
- filename=f"voices/{voice}",
94
- local_dir=self.voices_dir,
95
- local_dir_use_symlinks=False,
96
- force_filename=voice
97
- )
98
- print(f"Download completed to: {downloaded_path}")
99
-
100
- # Verify file exists
101
- if not os.path.exists(voice_path):
102
- print(f"Warning: File not found at expected path {voice_path}")
103
- print(f"Checking download location: {downloaded_path}")
104
- if os.path.exists(downloaded_path):
105
- print(f"Moving file from {downloaded_path} to {voice_path}")
106
- os.rename(downloaded_path, voice_path)
107
- else:
108
- print(f"Verified voice file exists: {voice_path}")
109
-
110
- except Exception as e:
111
- print(f"Error downloading voice {voice}: {str(e)}")
112
- import traceback
113
- traceback.print_exc()
114
-
115
- except Exception as e:
116
- print(f"Error downloading voice {voice}: {str(e)}")
117
- import traceback
118
- traceback.print_exc()
119
-
120
  print("Model initialization complete")
121
  return True
122
 
@@ -124,46 +63,35 @@ class TTSModel:
124
  print(f"Error initializing model: {str(e)}")
125
  return False
126
 
127
- def list_voices(self):
128
- """List available voices"""
129
- voices = []
130
  try:
131
- # Verify voices directory exists
132
- if not os.path.exists(self.voices_dir):
133
- print(f"Voices directory does not exist: {self.voices_dir}")
134
- return voices
135
-
136
- # Get list of files
137
- files = os.listdir(self.voices_dir)
138
- print(f"Found {len(files)} files in voices directory")
139
-
140
- # Filter for .pt files
141
- for file in files:
142
- if file.endswith(".pt"):
143
- voices.append(file[:-3]) # Remove .pt extension
144
- print(f"Found voice: {file[:-3]}")
145
-
146
- if not voices:
147
- print("No voice files found in voices directory")
148
-
149
  except Exception as e:
150
- print(f"Error listing voices: {str(e)}")
151
- import traceback
152
- traceback.print_exc()
153
-
154
- return sorted(voices)
 
 
 
 
 
155
 
156
- def _ensure_model_on_gpu(self):
157
  """Ensure model is on GPU and stays there"""
158
  if not hasattr(self, '_model_on_gpu') or not self._model_on_gpu:
159
  print("Moving model to GPU...")
160
  with torch.cuda.device(0):
161
  torch.cuda.set_device(0)
162
- # Move model to GPU using torch.nn.Module method
163
  if hasattr(self.model, 'to'):
164
  self.model.to('cuda')
165
  else:
166
- # Fallback for Munch object - move parameters individually
167
  for name in self.model:
168
  if isinstance(self.model[name], torch.Tensor):
169
  self.model[name] = self.model[name].cuda()
@@ -190,7 +118,7 @@ class TTSModel:
190
  voicepack = voicepack.cuda()
191
 
192
  # Run generation with everything on GPU
193
- audio, _ = generate(
194
  self.model,
195
  text,
196
  voicepack,
@@ -203,63 +131,24 @@ class TTSModel:
203
  except Exception as e:
204
  print(f"Error in audio generation: {str(e)}")
205
  raise e
206
-
207
- def chunk_text(self, text: str, max_chars: int = 300) -> list[str]:
208
- """Break text into chunks at natural boundaries"""
209
- chunks = []
210
- current_chunk = ""
211
-
212
- # Split on sentence boundaries first
213
- sentences = text.replace(".", ".|").replace("!", "!|").replace("?", "?|").replace(";", ";|").split("|")
214
-
215
- for sentence in sentences:
216
- if not sentence.strip():
217
- continue
218
-
219
- # If sentence is already too long, break on commas
220
- if len(sentence) > max_chars:
221
- parts = sentence.split(",")
222
- for part in parts:
223
- if len(current_chunk) + len(part) <= max_chars:
224
- current_chunk += part + ","
225
- else:
226
- # If part is still too long, break on whitespace
227
- if len(part) > max_chars:
228
- words = part.split()
229
- for word in words:
230
- if len(current_chunk) + len(word) > max_chars:
231
- chunks.append(current_chunk.strip())
232
- current_chunk = word + " "
233
- else:
234
- current_chunk += word + " "
235
- else:
236
- chunks.append(current_chunk.strip())
237
- current_chunk = part + ","
238
- else:
239
- if len(current_chunk) + len(sentence) <= max_chars:
240
- current_chunk += sentence
241
- else:
242
- chunks.append(current_chunk.strip())
243
- current_chunk = sentence
244
-
245
- if current_chunk:
246
- chunks.append(current_chunk.strip())
247
-
248
- return chunks
249
 
250
- def generate_speech(self, text: str, voice_name: str, speed: float = 1.0) -> tuple[np.ndarray, float]:
251
- """Generate speech from text. Returns (audio_array, duration)"""
 
 
 
 
 
 
 
252
  try:
253
  if not text or not voice_name:
254
  raise ValueError("Text and voice name are required")
255
 
256
  start_time = time.time()
257
 
258
- # Initialize tokenizer
259
- enc = tiktoken.get_encoding("cl100k_base")
260
- total_tokens = len(enc.encode(text))
261
-
262
- # Normalize text
263
  text = normalize_text(text)
264
  if not text:
265
  raise ValueError("Text is empty after normalization")
@@ -269,49 +158,158 @@ class TTSModel:
269
  torch.cuda.set_device(0)
270
 
271
  voice_path = os.path.join(self.voices_dir, f"{voice_name}.pt")
272
- if not os.path.exists(voice_path):
273
- raise ValueError(f"Voice not found: {voice_name}")
274
 
275
- # Load voice directly to GPU
 
 
276
  voicepack = torch.load(voice_path, map_location='cuda', weights_only=True)
277
 
278
  # Break text into chunks for better memory management
279
- chunks = self.chunk_text(text)
280
  print(f"Processing {len(chunks)} chunks...")
281
 
282
- # Ensure model is initialized and on GPU
283
- if self.model is None:
284
- print("Model not initialized, reinitializing...")
285
- if not self.initialize():
286
- raise ValueError("Failed to initialize model")
287
 
288
- # Move model to GPU if needed
289
- if not hasattr(self, '_model_on_gpu') or not self._model_on_gpu:
290
- print("Moving model to GPU...")
291
- if hasattr(self.model, 'to'):
292
- self.model.to('cuda')
293
- else:
294
- for name in self.model:
295
- if isinstance(self.model[name], torch.Tensor):
296
- self.model[name] = self.model[name].cuda()
297
- self._model_on_gpu = True
298
 
299
- # Process all chunks within same GPU context
300
- audio_chunks = []
301
- for i, chunk in enumerate(chunks):
302
- chunk_start = time.time()
303
- chunk_audio = self._generate_audio(
304
- text=chunk,
305
- voicepack=voicepack,
306
- lang=voice_name[0],
307
- speed=speed
308
- )
309
- chunk_time = time.time() - chunk_start
310
- print(f"Chunk {i+1}/{len(chunks)} processed in {chunk_time:.2f}s")
311
- audio_chunks.append(chunk_audio)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
  # Concatenate audio chunks
314
- audio = np.concatenate(audio_chunks)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
 
316
  # Calculate metrics
317
  total_time = time.time() - start_time
@@ -321,6 +319,11 @@ class TTSModel:
321
  print(f"Total tokens: {total_tokens}")
322
  print(f"Total time: {total_time:.2f}s")
323
  print(f"Tokens per second: {tokens_per_second:.2f}")
 
 
 
 
 
324
 
325
  return audio, len(audio) / 24000 # Return audio array and duration
326
 
 
1
  import os
 
 
2
  import torch
3
  import numpy as np
4
  import time
5
+ import matplotlib.pyplot as plt
6
+ from typing import Tuple, List
7
+ from statistics import mean, median, stdev
8
+ from lib import (
9
+ normalize_text,
10
+ chunk_text,
11
+ count_tokens,
12
+ load_module_from_file,
13
+ download_model_files,
14
+ list_voice_files,
15
+ download_voice_files,
16
+ ensure_dir,
17
+ concatenate_audio_chunks
18
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  class TTSModel:
21
+ """GPU-accelerated TTS model manager"""
22
 
23
  def __init__(self):
24
  self.model = None
25
  self.voices_dir = "voices"
26
  self.model_repo = "hexgrad/Kokoro-82M"
27
+ ensure_dir(self.voices_dir)
28
 
29
+ # Load required modules
30
+ py_modules = ["istftnet", "plbert", "models", "kokoro"]
31
+ module_files = download_model_files(self.model_repo, [f"{m}.py" for m in py_modules])
32
+
33
+ for module_name, file_path in zip(py_modules, module_files):
34
+ load_module_from_file(module_name, file_path)
35
+
36
+ # Import required functions from kokoro module
37
+ kokoro = __import__("kokoro")
38
+ self.generate = kokoro.generate
39
+ self.build_model = __import__("models").build_model
40
+
41
+ def initialize(self) -> bool:
42
  """Initialize model and download voices"""
43
  try:
44
  print("Initializing model...")
45
 
46
+ # Download model files
47
+ model_files = download_model_files(
48
+ self.model_repo,
49
+ ["kokoro-v0_19.pth", "config.json"]
 
 
 
 
50
  )
51
+ model_path = model_files[0] # kokoro-v0_19.pth
52
 
53
+ # Build model directly on GPU
54
  with torch.cuda.device(0):
55
  torch.cuda.set_device(0)
56
+ self.model = self.build_model(model_path, 'cuda')
57
  self._model_on_gpu = True
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  print("Model initialization complete")
60
  return True
61
 
 
63
  print(f"Error initializing model: {str(e)}")
64
  return False
65
 
66
+ def ensure_voice_downloaded(self, voice_name: str) -> bool:
67
+ """Ensure specific voice is downloaded"""
 
68
  try:
69
+ voice_path = os.path.join(self.voices_dir, f"{voice_name}.pt")
70
+ if not os.path.exists(voice_path):
71
+ print(f"Downloading voice {voice_name}.pt...")
72
+ download_voice_files(self.model_repo, [f"{voice_name}.pt"], self.voices_dir)
73
+ return True
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  except Exception as e:
75
+ print(f"Error downloading voice {voice_name}: {str(e)}")
76
+ return False
77
+
78
+ def list_voices(self) -> List[str]:
79
+ """List available voices"""
80
+ return [
81
+ "af_bella", "af_nicole", "af_sarah", "af_sky", "af",
82
+ "am_adam", "am_michael", "bf_emma", "bf_isabella",
83
+ "bm_george", "bm_lewis"
84
+ ]
85
 
86
+ def _ensure_model_on_gpu(self) -> None:
87
  """Ensure model is on GPU and stays there"""
88
  if not hasattr(self, '_model_on_gpu') or not self._model_on_gpu:
89
  print("Moving model to GPU...")
90
  with torch.cuda.device(0):
91
  torch.cuda.set_device(0)
 
92
  if hasattr(self.model, 'to'):
93
  self.model.to('cuda')
94
  else:
 
95
  for name in self.model:
96
  if isinstance(self.model[name], torch.Tensor):
97
  self.model[name] = self.model[name].cuda()
 
118
  voicepack = voicepack.cuda()
119
 
120
  # Run generation with everything on GPU
121
+ audio, _ = self.generate(
122
  self.model,
123
  text,
124
  voicepack,
 
131
  except Exception as e:
132
  print(f"Error in audio generation: {str(e)}")
133
  raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
+ def generate_speech(self, text: str, voice_name: str, speed: float = 1.0, progress_callback=None) -> Tuple[np.ndarray, float]:
136
+ """Generate speech from text. Returns (audio_array, duration)
137
+
138
+ Args:
139
+ text: Input text to convert to speech
140
+ voice_name: Name of voice to use
141
+ speed: Speech speed multiplier
142
+ progress_callback: Optional callback function(chunk_num, total_chunks, tokens_per_sec, rtf)
143
+ """
144
  try:
145
  if not text or not voice_name:
146
  raise ValueError("Text and voice name are required")
147
 
148
  start_time = time.time()
149
 
150
+ # Count tokens and normalize text
151
+ total_tokens = count_tokens(text)
 
 
 
152
  text = normalize_text(text)
153
  if not text:
154
  raise ValueError("Text is empty after normalization")
 
158
  torch.cuda.set_device(0)
159
 
160
  voice_path = os.path.join(self.voices_dir, f"{voice_name}.pt")
 
 
161
 
162
+ # Ensure voice is downloaded and load directly to GPU
163
+ if not self.ensure_voice_downloaded(voice_name):
164
+ raise ValueError(f"Failed to download voice: {voice_name}")
165
  voicepack = torch.load(voice_path, map_location='cuda', weights_only=True)
166
 
167
  # Break text into chunks for better memory management
168
+ chunks = chunk_text(text)
169
  print(f"Processing {len(chunks)} chunks...")
170
 
171
+ # Ensure model is initialized and on GPU
172
+ if self.model is None:
173
+ print("Model not initialized, reinitializing...")
174
+ if not self.initialize():
175
+ raise ValueError("Failed to initialize model")
176
 
177
+ # Move model to GPU if needed
178
+ if not hasattr(self, '_model_on_gpu') or not self._model_on_gpu:
179
+ print("Moving model to GPU...")
180
+ if hasattr(self.model, 'to'):
181
+ self.model.to('cuda')
182
+ else:
183
+ for name in self.model:
184
+ if isinstance(self.model[name], torch.Tensor):
185
+ self.model[name] = self.model[name].cuda()
186
+ self._model_on_gpu = True
187
 
188
+ # Process all chunks within same GPU context
189
+ audio_chunks = []
190
+ chunk_times = []
191
+ chunk_sizes = [] # Store chunk lengths
192
+ total_processed_tokens = 0
193
+ total_processed_time = 0
194
+
195
+ for i, chunk in enumerate(chunks):
196
+ chunk_start = time.time()
197
+ chunk_audio = self._generate_audio(
198
+ text=chunk,
199
+ voicepack=voicepack,
200
+ lang=voice_name[0],
201
+ speed=speed
202
+ )
203
+ chunk_time = time.time() - chunk_start
204
+
205
+ # Update metrics
206
+ chunk_tokens = count_tokens(chunk)
207
+ total_processed_tokens += chunk_tokens
208
+ total_processed_time += chunk_time
209
+ current_tokens_per_sec = total_processed_tokens / total_processed_time
210
+
211
+ # Calculate processing speed metrics
212
+ chunk_duration = len(chunk_audio) / 24000 # audio duration in seconds
213
+ rtf = chunk_time / chunk_duration
214
+ times_faster = 1 / rtf
215
+
216
+ chunk_times.append(chunk_time)
217
+ chunk_sizes.append(len(chunk))
218
+ print(f"Chunk {i+1}/{len(chunks)} processed in {chunk_time:.2f}s")
219
+ print(f"Current tokens/sec: {current_tokens_per_sec:.2f}")
220
+ print(f"Real-time factor: {rtf:.2f}x")
221
+ print(f"{times_faster:.1f}x faster than real-time")
222
+
223
+ audio_chunks.append(chunk_audio)
224
+
225
+ # Call progress callback if provided
226
+ if progress_callback:
227
+ progress_callback(i + 1, len(chunks), current_tokens_per_sec, rtf)
228
 
229
  # Concatenate audio chunks
230
+ audio = concatenate_audio_chunks(audio_chunks)
231
+
232
+ def setup_plot(fig, ax, title):
233
+ """Configure plot styling"""
234
+ # Improve grid
235
+ ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
236
+
237
+ # Set title and labels with better fonts and more padding
238
+ ax.set_title(title, pad=40, fontsize=16, fontweight="bold", color="#ffffff")
239
+ ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
240
+ ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
241
+
242
+ # Improve tick labels
243
+ ax.tick_params(labelsize=12, colors="#ffffff")
244
+
245
+ # Style spines
246
+ for spine in ax.spines.values():
247
+ spine.set_color("#ffffff")
248
+ spine.set_alpha(0.3)
249
+ spine.set_linewidth(0.5)
250
+
251
+ # Set background colors
252
+ ax.set_facecolor("#1a1a2e")
253
+ fig.patch.set_facecolor("#1a1a2e")
254
+
255
+ return fig, ax
256
+
257
+ # Set dark style
258
+ plt.style.use("dark_background")
259
+
260
+ # Create figure with subplots
261
+ fig = plt.figure(figsize=(18, 16))
262
+ fig.patch.set_facecolor("#1a1a2e")
263
+
264
+ # Create subplot grid
265
+ gs = plt.GridSpec(2, 1, left=0.15, right=0.85, top=0.9, bottom=0.15, hspace=0.4)
266
+
267
+ # Processing times plot
268
+ ax1 = plt.subplot(gs[0])
269
+ chunks_x = list(range(1, len(chunks) + 1))
270
+ bars = ax1.bar(chunks_x, chunk_times, color='#ff2a6d', alpha=0.8)
271
+
272
+ # Add statistics lines
273
+ mean_time = mean(chunk_times)
274
+ median_time = median(chunk_times)
275
+ std_time = stdev(chunk_times) if len(chunk_times) > 1 else 0
276
+
277
+ ax1.axhline(y=mean_time, color='#05d9e8', linestyle='--',
278
+ label=f'Mean: {mean_time:.2f}s')
279
+ ax1.axhline(y=median_time, color='#d1f7ff', linestyle=':',
280
+ label=f'Median: {median_time:.2f}s')
281
+
282
+ # Add ±1 std dev range
283
+ if len(chunk_times) > 1:
284
+ ax1.axhspan(mean_time - std_time, mean_time + std_time,
285
+ color='#8c1eff', alpha=0.2, label='±1 Std Dev')
286
+
287
+ # Add value labels on top of bars
288
+ for bar in bars:
289
+ height = bar.get_height()
290
+ ax1.text(bar.get_x() + bar.get_width() / 2.0,
291
+ height,
292
+ f'{height:.2f}s',
293
+ ha='center',
294
+ va='bottom',
295
+ color='white',
296
+ fontsize=10)
297
+
298
+ ax1.set_xlabel('Chunk Number')
299
+ ax1.set_ylabel('Processing Time (seconds)')
300
+ setup_plot(fig, ax1, 'Chunk Processing Times')
301
+ ax1.legend(facecolor="#1a1a2e", edgecolor="#ffffff")
302
+
303
+ # Chunk sizes plot
304
+ ax2 = plt.subplot(gs[1])
305
+ ax2.plot(chunks_x, chunk_sizes, color='#ff9e00', marker='o', linewidth=2)
306
+ ax2.set_xlabel('Chunk Number')
307
+ ax2.set_ylabel('Chunk Size (chars)')
308
+ setup_plot(fig, ax2, 'Chunk Sizes')
309
+
310
+ # Save plot
311
+ plt.savefig('chunk_times.png')
312
+ plt.close()
313
 
314
  # Calculate metrics
315
  total_time = time.time() - start_time
 
319
  print(f"Total tokens: {total_tokens}")
320
  print(f"Total time: {total_time:.2f}s")
321
  print(f"Tokens per second: {tokens_per_second:.2f}")
322
+ print(f"Mean chunk time: {mean_time:.2f}s")
323
+ print(f"Median chunk time: {median_time:.2f}s")
324
+ if len(chunk_times) > 1:
325
+ print(f"Std dev: {std_time:.2f}s")
326
+ print(f"\nChunk time plot saved as 'chunk_times.png'")
327
 
328
  return audio, len(audio) / 24000 # Return audio array and duration
329