johnbridges committed on
Commit
3461da4
·
1 Parent(s): 683efa7

first commit

Browse files
Files changed (8) hide show
  1. Dockerfile +38 -0
  2. README.md +3 -3
  3. app.py +373 -0
  4. commit +3 -0
  5. kokoro.py +165 -0
  6. packages.txt +5 -0
  7. requirements.txt +17 -0
  8. tts_processor.py +163 -0
Dockerfile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ # Install system dependencies
4
+ RUN apt-get update && apt-get install -y --no-install-recommends \
5
+ libsndfile1 \
6
+ espeak-ng \
7
+ ffmpeg \
8
+ git \
9
+ wget \
10
+ && rm -rf /var/lib/apt/lists/*
11
+ RUN useradd -m -u 1000 user
12
+
13
+ # Switch to the "user" user
14
+ USER user
15
+
16
+ # Set home to the user's home directory
17
+ ENV HOME=/home/user \
18
+ PATH=/home/user/.local/bin:$PATH
19
+
20
+ # Set the working directory to the user's home directory
21
+ WORKDIR $HOME/app
22
+
23
+ # Create the files directory
24
+ RUN mkdir -p $HOME/app/files
25
+
26
+ # Copy and install Python dependencies
27
+ COPY requirements.txt $HOME/app/
28
+ RUN pip install --no-cache-dir -r requirements.txt && pip install --upgrade pip
29
+
30
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
31
+ COPY --chown=user . $HOME/app
32
+
33
+
34
+ # Expose port
35
+ EXPOSE 7860
36
+
37
+ # Run the application
38
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: NetMonTTS3
3
- emoji: 🔥
4
- colorFrom: purple
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 5.46.0
8
  app_file: app.py
 
1
  ---
2
  title: NetMonTTS3
3
+ emoji: 🏃
4
+ colorFrom: indigo
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.46.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, send_from_directory, abort
2
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
+ import librosa
4
+ import torch
5
+ import numpy as np
6
+ from onnxruntime import InferenceSession
7
+ import soundfile as sf
8
+ import os
9
+ import sys
10
+ import uuid
11
+ import logging
12
+ from flask_cors import CORS
13
+ import threading
14
+ import werkzeug
15
+ import tempfile
16
+ from huggingface_hub import snapshot_download
17
+ from tts_processor import preprocess_all
18
+ import hashlib
19
+ import os
20
+ import torch
21
+ import numpy as np
22
+ import onnxruntime as ort
23
+
24
+ # ---------------------------
25
+ # THREAD LIMIT CONFIG
26
+ # ---------------------------
27
+ MAX_THREADS = 2 # <-- change this number to control all thread usage
28
+
29
+ SERVE_DIR = "/home/user/app/files"
30
+ os.makedirs(SERVE_DIR, exist_ok=True)
31
+
32
+ # Limit NumPy / BLAS / MKL threads
33
+ os.environ["OMP_NUM_THREADS"] = str(MAX_THREADS)
34
+ os.environ["OPENBLAS_NUM_THREADS"] = str(MAX_THREADS)
35
+ os.environ["MKL_NUM_THREADS"] = str(MAX_THREADS)
36
+ os.environ["VECLIB_MAXIMUM_THREADS"] = str(MAX_THREADS)
37
+ os.environ["NUMEXPR_NUM_THREADS"] = str(MAX_THREADS)
38
+
39
+ # Torch thread limits
40
+ torch.set_num_threads(MAX_THREADS)
41
+ torch.set_num_interop_threads(1) # keep inter-op small to avoid overhead
42
+
43
+ # ONNXRuntime session options (use when creating the session)
44
+ sess_options = ort.SessionOptions()
45
+ sess_options.intra_op_num_threads = MAX_THREADS
46
+ sess_options.inter_op_num_threads = 1
47
+
48
+
49
+ # Configure logging
50
+ logging.basicConfig(level=logging.INFO)
51
+ logger = logging.getLogger(__name__)
52
+
53
+ app = Flask(__name__)
54
+ CORS(app, resources={r"/*": {"origins": "*"}})
55
+
56
+ # Global lock to ensure one method runs at a time
57
+ global_lock = threading.Lock()
58
+
59
+ # Repository ID and paths
60
+ kokoro_model_id = 'onnx-community/Kokoro-82M-v1.0-ONNX'
61
+ model_path = 'kokoro_model'
62
+ voice_name = 'am_adam' # Example voice: af (adjust as needed)
63
+
64
+ # Directory to serve files from
65
+ SERVE_DIR = os.environ.get("SERVE_DIR", "./files") # Default to './files' if not provided
66
+
67
+ os.makedirs(SERVE_DIR, exist_ok=True)
68
def validate_audio_file(file):
    """Check an uploaded audio file's type, size and header.

    Args:
        file: The upload to check; must be a werkzeug ``FileStorage``.

    Raises:
        ValueError: If ``file`` is not a ``FileStorage``, its declared
            content type is not in the accepted list, it is larger than
            10 MB, or ``verify_audio_header`` rejects its leading bytes.
    """
    if not isinstance(file, werkzeug.datastructures.FileStorage):
        raise ValueError("Invalid file type")

    # Accepted declared MIME types (WebM and Ogg cover Opus recordings).
    accepted_mime_types = (
        "audio/wav",
        "audio/x-wav",
        "audio/mpeg",
        "audio/mp3",
        "audio/webm",
        "audio/ogg",
    )
    if file.content_type not in accepted_mime_types:
        raise ValueError(f"Unsupported file type. Must be one of: {', '.join(accepted_mime_types)}")

    # Measure the upload by seeking to its end, then rewind for later readers.
    file.seek(0, os.SEEK_END)
    size_in_bytes = file.tell()
    file.seek(0)

    size_limit = 10 * 1024 * 1024  # 10 MB
    if size_in_bytes > size_limit:
        raise ValueError(f"File is too large (max {size_limit//(1024*1024)} MB)")

    # Finally make sure the leading bytes agree with the declared type.
    header_ok = verify_audio_header(file)
    if not header_ok:
        raise ValueError("File header doesn't match declared content type")
99
def verify_audio_header(file):
    """Quickly check that the file's leading bytes match its declared type.

    Reads the first four bytes and rewinds the stream to offset 0 before
    returning, so the caller can read the file again from the start.

    Args:
        file: A seekable binary stream with a ``content_type`` attribute.

    Returns:
        bool: True when the header is plausible for ``file.content_type``,
        or when the content type is not one this function knows how to
        check; False on a mismatch.
    """
    header = file.read(4)
    file.seek(0)  # Rewind after reading

    content_type = file.content_type
    if content_type == "audio/webm":
        # WebM (EBML container) magic number.
        return header.startswith(b'\x1aE\xdf\xa3')
    if content_type == "audio/ogg":
        return header.startswith(b'OggS')
    if content_type in ("audio/wav", "audio/x-wav"):
        return header.startswith(b'RIFF')
    if content_type in ("audio/mpeg", "audio/mp3"):
        # Many MP3 files begin with an ID3v2 tag rather than an audio frame.
        if header.startswith(b'ID3'):
            return True
        # Otherwise expect the 11-bit frame sync (0xFF followed by 0b111xxxxx);
        # this covers every MPEG version/layer, not only the b'\xff\xfb' case.
        return len(header) >= 2 and header[0] == 0xFF and (header[1] & 0xE0) == 0xE0
    return True  # Skip verification for other types
115
+
116
def validate_text_input(text):
    """Reject text that is not a non-blank string of at most 1024 characters.

    Args:
        text: The candidate input for speech synthesis.

    Raises:
        ValueError: If ``text`` is not a ``str``, is empty or whitespace
            only, or is longer than 1024 characters.
    """
    max_chars = 1024
    if isinstance(text, str):
        if not text.strip():
            problem = "Text input cannot be empty"
        elif len(text) > max_chars:
            problem = "Text input is too long (max 1024 characters)"
        else:
            problem = None
    else:
        problem = "Text input must be a string"

    if problem is not None:
        raise ValueError(problem)
123
+
124
# Maps a file path to True once that path has been seen on disk.
file_cache = {}


def is_cached(cached_file_path):
    """Return True if ``cached_file_path`` exists, memoising positive hits.

    Only positive results are remembered. A path that does not exist yet is
    re-checked on disk on every call, so a file written after the first miss
    is picked up on the next lookup instead of being reported as missing
    forever.

    Args:
        cached_file_path: Path of the file to look for.

    Returns:
        bool: Whether the file is known to exist. A remembered hit is not
        re-validated, so a file deleted later is still reported as present.
    """
    if file_cache.get(cached_file_path):
        return True  # Known hit: skip the disk check
    exists = os.path.exists(cached_file_path)
    if exists:
        file_cache[cached_file_path] = True
    return exists
136
+
137
+ # Initialize models
138
def initialize_models():
    """Load the Kokoro ONNX session, the voice style vector and Whisper.

    Populates the module globals ``sess``, ``voice_style``, ``processor`` and
    ``whisper_model``. Any failure is logged and re-raised.

    Raises:
        FileNotFoundError: If ``model.onnx`` or the ``<voice_name>.bin`` file
            cannot be found under the Kokoro model directory.
    """
    global sess, voice_style, processor, whisper_model

    def _first_match(base_dir, wanted_name):
        # First path under base_dir (os.walk order) whose file name matches.
        return next(
            (
                os.path.join(folder, wanted_name)
                for folder, _, names in os.walk(base_dir)
                if wanted_name in names
            ),
            None,
        )

    try:
        # Only download when the local model directory is absent.
        if os.path.exists(model_path):
            kokoro_dir = model_path
            logger.info(f"Using cached Kokoro model directory: {kokoro_dir}")
        else:
            logger.info("Downloading and loading Kokoro model...")
            kokoro_dir = snapshot_download(kokoro_model_id, cache_dir=model_path)
            logger.info(f"Kokoro model directory: {kokoro_dir}")

        # Locate and load the ONNX graph.
        onnx_path = _first_match(kokoro_dir, 'model.onnx')
        if not onnx_path or not os.path.exists(onnx_path):
            raise FileNotFoundError(f"ONNX file not found after redownload at {kokoro_dir}")

        logger.info("Loading ONNX session...")
        sess = InferenceSession(onnx_path, sess_options)
        logger.info(f"ONNX session loaded successfully from {onnx_path}")

        # Locate and load the voice style vector.
        voice_style_path = _first_match(kokoro_dir, f'{voice_name}.bin')
        if not voice_style_path or not os.path.exists(voice_style_path):
            raise FileNotFoundError(f"Voice style file not found at {voice_style_path}")

        logger.info("Loading voice style vector...")
        raw_style = np.fromfile(voice_style_path, dtype=np.float32)
        voice_style = raw_style.reshape(-1, 1, 256)
        logger.info(f"Voice style vector loaded successfully from {voice_style_path}")

        # Whisper handles the speech-to-text direction.
        logger.info("Downloading and loading Whisper model...")
        processor = WhisperProcessor.from_pretrained("openai/whisper-base")
        whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
        whisper_model.config.forced_decoder_ids = None
        logger.info("Whisper model loaded successfully")

    except Exception as e:
        logger.error(f"Error initializing models: {str(e)}")
        raise
189
+
190
+ # Initialize models
191
+ initialize_models()
192
+
193
+ # Health check endpoint
194
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: reply with a JSON ``healthy`` status and HTTP 200.

    If building the response raises, the error is logged and a JSON
    ``unhealthy`` status is returned with HTTP 500.
    """
    try:
        healthy_response = jsonify({"status": "healthy"})
    except Exception as e:
        logger.error(f"Health check failed: {str(e)}")
        return jsonify({"status": "unhealthy"}), 500
    return healthy_response, 200
201
+
202
+ # Text-to-Speech (T2S) Endpoint
203
@app.route('/generate_audio', methods=['POST'])
def generate_audio():
    """Text-to-Speech (T2S) endpoint.

    Expects a JSON object with a ``text`` field. The text is validated,
    preprocessed, and hashed; the hash names the ``.wav`` file written to
    ``SERVE_DIR``, so repeated requests for the same text reuse that file.

    Returns:
        A JSON response ``{"status": "success", "filename": ...}`` on
        success. Invalid client input (missing or non-JSON body, missing
        ``text``, failed validation, or text that yields no tokens) returns
        ``{"status": "error", "message": ...}`` with HTTP 400; any other
        failure returns the same shape with HTTP 500.
    """
    with global_lock:
        try:
            logger.debug("Received request to /generate_audio")

            # silent=True yields None instead of raising on a bad/missing JSON
            # body, so the problem can be reported as a client error.
            data = request.get_json(silent=True)
            if not isinstance(data, dict) or 'text' not in data:
                return jsonify({
                    "status": "error",
                    "message": "Request body must be a JSON object with a 'text' field",
                }), 400
            text = data['text']

            try:
                validate_text_input(text)
            except ValueError as e:
                return jsonify({"status": "error", "message": str(e)}), 400

            # Preprocess & stable hash
            text = preprocess_all(text)
            text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
            filename = f"{text_hash}.wav"
            cached_file_path = os.path.join(SERVE_DIR, filename)

            # Cache hit
            if is_cached(cached_file_path):
                logger.info("Returning cached audio")
                return jsonify({"status": "success", "filename": filename})

            # Tokenize
            from kokoro import phonemize, tokenize  # lazy import is fine
            tokens = tokenize(phonemize(text, 'a'))
            if not tokens:
                # Nothing in the text maps to a known symbol; do not run the model.
                return jsonify({
                    "status": "error",
                    "message": "Text contains no pronounceable content",
                }), 400
            if len(tokens) > 510:
                logger.warning("Text too long; truncating to 510 tokens.")
                tokens = tokens[:510]
            tokens = [[0, *tokens, 0]]  # wrap with the pad token on both sides

            # Style vector, indexed by the token count without the two pads
            ref_s = voice_style[len(tokens[0]) - 2]  # (1,256)

            # ONNX inference
            audio = sess.run(None, dict(
                input_ids=np.array(tokens, dtype=np.int64),
                style=ref_s,
                speed=np.ones(1, dtype=np.float32),
            ))[0]

            # Save
            audio = np.squeeze(audio).astype(np.float32)
            sf.write(cached_file_path, audio, 24000)
            # Record the new file so later lookups hit the in-memory cache even
            # if an earlier lookup for this path was remembered as a miss.
            file_cache[cached_file_path] = True

            logger.info(f"Audio saved: {cached_file_path}")
            return jsonify({"status": "success", "filename": filename})
        except Exception as e:
            logger.error(f"Error generating audio: {str(e)}")
            return jsonify({"status": "error", "message": str(e)}), 500
252
+
253
+ # Speech-to-Text (S2T) Endpoint
254
+ # Add these imports at the top with the other imports
255
+ import subprocess
256
+ import tempfile
257
+ from pathlib import Path
258
+
259
+ # Then update the transcribe_audio function:
260
@app.route('/transcribe_audio', methods=['POST'])
def transcribe_audio():
    """Speech-to-Text (S2T) endpoint with automatic format conversion.

    Expects a multipart upload under the ``file`` field. The upload is saved
    to a temporary file, converted by ffmpeg to 16 kHz mono 16-bit PCM WAV
    (with the cleanup filter chain below), and transcribed with Whisper.
    Both temporary files are removed before returning.

    Returns:
        A JSON response ``{"status": "success", "transcription": ...}`` on
        success; ``{"status": "error", "message": ...}`` with HTTP 400 when
        no file was uploaded, or HTTP 500 on any other failure.
    """
    with global_lock:  # Acquire global lock to ensure only one instance runs
        input_audio_path = None
        converted_audio_path = None
        try:
            logger.debug("Received request to /transcribe_audio")
            file = request.files.get('file')
            if file is None:
                return jsonify({
                    "status": "error",
                    "message": "No audio file uploaded under the 'file' field",
                }), 400

            # Keep the upload's extension; the filename may be missing, in
            # which case no suffix is used.
            input_suffix = Path(file.filename or "").suffix
            with tempfile.NamedTemporaryFile(delete=False, suffix=input_suffix) as input_temp:
                input_audio_path = input_temp.name
            file.save(input_audio_path)
            logger.debug(f"Original audio file saved to {input_audio_path}")

            # Create a temporary file for the converted WAV
            with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as output_temp:
                converted_audio_path = output_temp.name

            # Convert to WAV with ffmpeg (16kHz, mono)
            logger.debug(f"Converting audio to 16kHz mono WAV format...")
            conversion_command = [
                'ffmpeg',
                '-y',  # Force overwrite without prompting
                '-i', input_audio_path,
                '-acodec', 'pcm_s16le',  # 16-bit PCM
                '-ac', '1',  # mono
                '-ar', '16000',  # 16kHz sample rate
                '-af', 'highpass=f=80,lowpass=f=7500,afftdn=nr=10:nf=-25,loudnorm=I=-16:TP=-1.5:LRA=11',  # Audio cleanup filters
                converted_audio_path
            ]
            result = subprocess.run(
                conversion_command,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True
            )

            if result.returncode != 0:
                logger.error(f"FFmpeg conversion error: {result.stderr}")
                raise Exception(f"Audio conversion failed: {result.stderr}")

            logger.debug(f"Audio successfully converted to {converted_audio_path}")

            # Load and process the converted audio
            logger.debug("Processing audio for transcription...")
            audio_array, sampling_rate = librosa.load(converted_audio_path, sr=16000)

            input_features = processor(
                audio_array,
                sampling_rate=sampling_rate,
                return_tensors="pt"
            ).input_features

            # Generate transcription
            logger.debug("Generating transcription...")
            predicted_ids = whisper_model.generate(input_features)
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            logger.info(f"Transcription: {transcription}")

            return jsonify({"status": "success", "transcription": transcription})
        except Exception as e:
            logger.error(f"Error transcribing audio: {str(e)}")
            return jsonify({"status": "error", "message": str(e)}), 500
        finally:
            # Clean up temporary files (best effort: a failed delete is only logged)
            for path in [input_audio_path, converted_audio_path]:
                if path and os.path.exists(path):
                    try:
                        os.remove(path)
                        logger.debug(f"Temporary file {path} removed")
                    except Exception as e:
                        logger.warning(f"Failed to remove temporary file {path}: {e}")
334
+
335
@app.route('/files/<filename>', methods=['GET'])
def serve_wav_file(filename):
    """Send a ``.wav`` file out of ``SERVE_DIR``.

    Aborts with 400 when ``filename`` does not end in ``.wav`` (case
    insensitive) and with 404 when no such file exists in ``SERVE_DIR``.
    """
    has_wav_extension = filename.lower().endswith('.wav')
    if not has_wav_extension:
        abort(400, "Only .wav files are allowed.")

    file_path = os.path.join(SERVE_DIR, filename)
    logger.debug(f"Looking for file at: {file_path}")
    if os.path.isfile(file_path):
        return send_from_directory(SERVE_DIR, filename)

    # Not on disk: log the miss and answer 404.
    logger.error(f"File not found: {file_path}")
    abort(404, "File not found.")
354
+
355
+ # Error handlers
356
@app.errorhandler(400)
def bad_request(error):
    """Render a 400 error as a JSON-style body carrying the error text."""
    body = {"error": "Bad Request", "message": str(error)}
    return body, 400


@app.errorhandler(404)
def not_found(error):
    """Render a 404 error as a JSON-style body carrying the error text."""
    body = {"error": "Not Found", "message": str(error)}
    return body, 404


@app.errorhandler(500)
def internal_error(error):
    """Render a 500 error with a fixed message; the error text is not exposed."""
    body = {"error": "Internal Server Error", "message": "An unexpected error occurred."}
    return body, 500
370
+
371
+ if __name__ == "__main__":
372
+ app.run(host="0.0.0.0", port=7860, threaded=False, processes=1)
373
+
commit ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ git add .
2
+ git commit -m "$*"
3
+ git push
kokoro.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import phonemizer
2
+ import re
3
+ import torch
4
+ import numpy as np
5
+
6
def split_num(num):
    """Regex callback: rewrite a matched time or year into a speakable form.

    Args:
        num: A regex match whose text is a decimal, an ``H:MM`` time, or a
            four-digit year with an optional trailing ``s``.

    Returns:
        str: Decimals are returned unchanged. Times become ``"H o'clock"``,
        ``"H oh M"`` or ``"H M"``. Years below 1100, or whose last three
        digits are below 10, are returned unchanged; other years are split
        into two two-digit halves (``"19 99"``, ``"19 hundred"``,
        ``"19 oh 5"``), keeping a trailing ``s``.
    """
    token = num.group()
    if '.' in token:
        return token

    if ':' in token:
        hour, minute = (int(part) for part in token.split(':'))
        if minute == 0:
            return f"{hour} o'clock"
        if minute < 10:
            return f'{hour} oh {minute}'
        return f'{hour} {minute}'

    year = int(token[:4])
    if year < 1100 or year % 1000 < 10:
        return token

    century = token[:2]
    tail = int(token[2:4])
    plural = 's' if token.endswith('s') else ''
    # The "hundred"/"oh" spellings only apply when the last three digits are >= 100.
    if 100 <= year % 1000 <= 999 and tail < 10:
        if tail == 0:
            return f'{century} hundred{plural}'
        return f'{century} oh {tail}{plural}'
    return f'{century} {tail}{plural}'
28
+
29
def flip_money(m):
    """Regex callback: move the currency word after a matched money amount.

    Args:
        m: A regex match whose text starts with ``$`` or another currency
            sign (anything other than ``$`` is read as pounds).

    Returns:
        str: The amount with "dollar(s)"/"pound(s)" spelled out, plus
        "and N cent(s)" / "penny" / "pence" when a fractional part is present.
    """
    amount = m.group()
    symbol = amount[0]
    body = amount[1:]
    unit = 'dollar' if symbol == '$' else 'pound'

    # Amounts ending in a word ("$5 million") always take the plural unit.
    if amount[-1].isalpha():
        return f'{body} {unit}s'

    if '.' not in amount:
        plural = '' if body == '1' else 's'
        return f'{body} {unit}{plural}'

    whole, fraction = body.split('.')
    plural = '' if whole == '1' else 's'
    minor = int(fraction.ljust(2, '0'))  # ".5" means 50 minor units
    if symbol == '$':
        minor_name = 'cent' if minor == 1 else 'cents'
    else:
        minor_name = 'penny' if minor == 1 else 'pence'
    return f'{whole} {unit}{plural} and {minor} {minor_name}'
42
+
43
def point_num(num):
    """Regex callback: spell a decimal as "<int> point <d> <d> ...".

    Args:
        num: A regex match whose text contains exactly one ``.``.

    Returns:
        str: The integer part, the word "point", then the fractional digits
        separated by single spaces.
    """
    integer_part, fraction = num.group().split('.')
    spoken_digits = ' '.join(fraction)
    return f'{integer_part} point {spoken_digits}'
46
+
47
+ def normalize_text(text):
48
+ text = text.replace(chr(8216), "'").replace(chr(8217), "'")
49
+ text = text.replace('«', chr(8220)).replace('»', chr(8221))
50
+ text = text.replace(chr(8220), '"').replace(chr(8221), '"')
51
+ text = text.replace('(', '«').replace(')', '»')
52
+ for a, b in zip('、。!,:;?', ',.!,:;?'):
53
+ text = text.replace(a, b+' ')
54
+ text = re.sub(r'[^\S \n]', ' ', text)
55
+ text = re.sub(r' +', ' ', text)
56
+ text = re.sub(r'(?<=\n) +(?=\n)', '', text)
57
+ text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
58
+ text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
59
+ text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
60
+ text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
61
+ text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
62
+ text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
63
+ text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
64
+ text = re.sub(r'(?<=\d),(?=\d)', '', text)
65
+ text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
66
+ text = re.sub(r'\d*\.\d+', point_num, text)
67
+ text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
68
+ text = re.sub(r'(?<=\d)S', ' S', text)
69
+ text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
70
+ text = re.sub(r"(?<=X')S\b", 's', text)
71
+ text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
72
+ text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
73
+ return text.strip()
74
+
75
+ def get_vocab():
76
+ _pad = "$"
77
+ _punctuation = ';:,.!?¡¿—…"«»“” '
78
+ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
79
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
80
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
81
+ dicts = {}
82
+ for i in range(len((symbols))):
83
+ dicts[symbols[i]] = i
84
+ return dicts
85
+
86
+ VOCAB = get_vocab()
87
+ def tokenize(ps):
88
+ return [i for i in map(VOCAB.get, ps) if i is not None]
89
+
90
+ phonemizers = dict(
91
+ a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
92
+ b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
93
+ )
94
+ def phonemize(text, lang, norm=True):
95
+ if norm:
96
+ text = normalize_text(text)
97
+ ps = phonemizers[lang].phonemize([text])
98
+ ps = ps[0] if ps else ''
99
+ # https://en.wiktionary.org/wiki/kokoro#English
100
+ ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
101
+ ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
102
+ ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
103
+ ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
104
+ if lang == 'a':
105
+ ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
106
+ ps = ''.join(filter(lambda p: p in VOCAB, ps))
107
+ return ps.strip()
108
+
109
def length_to_mask(lengths):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Args:
        lengths: 1-D tensor of lengths, one per batch row.

    Returns:
        Boolean tensor of shape ``(len(lengths), lengths.max())`` that is
        True at positions at or beyond each row's length.
    """
    batch_size = lengths.shape[0]
    positions = torch.arange(lengths.max()).unsqueeze(0)
    grid = positions.expand(batch_size, -1).type_as(lengths)
    # 1-based position greater than the length means the slot is padding.
    return torch.gt(grid + 1, lengths.unsqueeze(1))
113
+
114
+ @torch.no_grad()
115
+ def forward(model, tokens, ref_s, speed):
116
+ device = ref_s.device
117
+ tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
118
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
119
+ text_mask = length_to_mask(input_lengths).to(device)
120
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
121
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
122
+ s = ref_s[:, 128:]
123
+ d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
124
+ x, _ = model.predictor.lstm(d)
125
+ duration = model.predictor.duration_proj(x)
126
+ duration = torch.sigmoid(duration).sum(axis=-1) / speed
127
+ pred_dur = torch.round(duration).clamp(min=1).long()
128
+ pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
129
+ c_frame = 0
130
+ for i in range(pred_aln_trg.size(0)):
131
+ pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
132
+ c_frame += pred_dur[0,i].item()
133
+ en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
134
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
135
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
136
+ asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
137
+ return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
138
+
139
def generate(model, text, voicepack, lang='a', speed=1, ps=None):
    """Synthesize one utterance, truncating to at most 510 tokens.

    Args:
        model: Model object handed to ``forward``.
        text: Input text; phonemized with ``phonemize`` unless ``ps`` is given.
        voicepack: Indexable by token count to obtain the style reference
            (presumably one style entry per length — confirm with the caller).
        lang: Phonemizer key passed to ``phonemize``.
        speed: Speed factor passed to ``forward``.
        ps: Optional phoneme string used instead of phonemizing ``text``.

    Returns:
        ``None`` when no tokens are produced; otherwise a tuple of the
        ``forward`` output and the phoneme string rebuilt from the tokens
        actually used.
    """
    phonemes = ps or phonemize(text, lang)
    tokens = tokenize(phonemes)
    if not tokens:
        return None
    if len(tokens) > 510:
        tokens = tokens[:510]
        print('Truncated to 510 tokens')

    ref_s = voicepack[len(tokens)]
    out = forward(model, tokens, ref_s, speed)

    # Map token ids back to symbols to report what was actually spoken.
    symbol_by_id = {}
    for symbol, token_id in VOCAB.items():
        symbol_by_id.setdefault(token_id, symbol)
    used_phonemes = ''.join(symbol_by_id[token_id] for token_id in tokens)
    return out, used_phonemes
151
+
152
def generate_full(model, text, voicepack, lang='a', speed=1, ps=None):
    """Synthesize text of any length by processing 510-token chunks.

    Args:
        model: Model object handed to ``forward``.
        text: Input text; phonemized with ``phonemize`` unless ``ps`` is given.
        voicepack: Indexable by chunk token count to obtain the style
            reference (presumably one style entry per length — confirm).
        lang: Phonemizer key passed to ``phonemize``.
        speed: Speed factor passed to ``forward``.
        ps: Optional phoneme string used instead of phonemizing ``text``.

    Returns:
        ``None`` when no tokens are produced; otherwise a tuple of the
        concatenated ``forward`` outputs and the phoneme string rebuilt from
        all tokens.
    """
    phonemes = ps or phonemize(text, lang)
    tokens = tokenize(phonemes)
    if not tokens:
        return None

    chunk_size = 510
    pieces = []
    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start:start + chunk_size]
        ref_s = voicepack[len(chunk)]
        pieces.append(forward(model, chunk, ref_s, speed))
    audio = np.concatenate(pieces)

    # Map token ids back to symbols to report what was actually spoken.
    symbol_by_id = {}
    for symbol, token_id in VOCAB.items():
        symbol_by_id.setdefault(token_id, symbol)
    used_phonemes = ''.join(symbol_by_id[token_id] for token_id in tokens)
    return audio, used_phonemes
packages.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ffmpeg
2
+ libsndfile1
3
+ espeak-ng
4
+ espeak-ng-data
5
+
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ flask-cors
3
+ transformers
4
+ librosa
5
+ numpy
6
+ soundfile
7
+ huggingface_hub
8
+ phonemizer
9
+ munch
10
+ werkzeug
11
+ num2words
12
+ dateparser
13
+ inflect
14
+ ftfy
15
+ sentencepiece
16
+ torch --index-url https://download.pytorch.org/whl/cpu
17
+ onnxruntime
tts_processor.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from dateutil.parser import parse
3
+ from num2words import num2words
4
+ import inflect
5
+ from ftfy import fix_text
6
+
7
+ # Initialize the inflect engine
8
+ inflect_engine = inflect.engine()
9
+
10
+ # Define alphabet pronunciation mapping
11
+ alphabet_map = {
12
+ "A": " Eh ", "B": " Bee ", "C": " See ", "D": " Dee ", "E": " Eee ",
13
+ "F": " Eff ", "G": " Jee ", "H": " Aitch ", "I": " Eye ", "J": " Jay ",
14
+ "K": " Kay ", "L": " El ", "M": " Emm ", "N": " Enn ", "O": " Ohh ",
15
+ "P": " Pee ", "Q": " Queue ", "R": " Are ", "S": " Ess ", "T": " Tee ",
16
+ "U": " You ", "V": " Vee ", "W": " Double You ", "X": " Ex ", "Y": " Why ", "Z": " Zed "
17
+ }
18
+
19
+ # Function to add ordinal suffix to a number
20
def add_ordinal_suffix(day):
    """Add an English ordinal suffix to a number (e.g., 13 -> "13th").

    Args:
        day: Non-negative integer, typically a day of the month.

    Returns:
        str: The number followed by "st", "nd", "rd" or "th".
    """
    # Any number ending in 11, 12 or 13 takes "th" (11th, 112th), so test the
    # last two digits rather than the whole value.
    if 11 <= day % 100 <= 13:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(day % 10, "th")
    return f"{day}{suffix}"
32
+
33
+ # Function to format dates in a human-readable form
34
def format_date(parsed_date, include_time=True):
    """Format a parsed date into a human-readable, TTS-friendly string.

    Args:
        parsed_date: A ``datetime``; a falsy value yields ``None``.
        include_time: Whether a time of day may be appended.

    Returns:
        str | None: ``"<Month> <day><suffix>, <year>"``, followed by
        ``" at <h>:<MM> <AM/PM>"`` when ``include_time`` is true and the time
        is not exactly midnight (hour and minute both zero), which is treated
        as "no time given".
    """
    if not parsed_date:
        return None

    # Convert the day into an ordinal (e.g., 13 -> 13th)
    day = add_ordinal_suffix(parsed_date.day)

    # Only 00:00 counts as "no time"; 14:00 or 00:30 must keep their time.
    has_time = (parsed_date.hour, parsed_date.minute) != (0, 0)
    if include_time and has_time:
        # 12-hour clock without zero padding, computed here rather than with
        # the platform-specific "%-I" directive.
        hour12 = parsed_date.hour % 12 or 12
        return parsed_date.strftime(f"%B {day}, %Y at {hour12}:%M %p")
    return parsed_date.strftime(f"%B {day}, %Y")  # Only date
46
+
47
+ # Normalize dates in the text
48
def normalize_dates(text):
    """Replace recognizable date strings with a TTS-friendly rendering.

    Matches ``YYYY-MM-DD`` (optionally followed by a space or ``T`` and
    ``HH:MM:SS``), ``DD/DD/YYYY`` and ``D[D] <word> YYYY``. A match that the
    date parser rejects with ``ValueError`` is left unchanged.

    Args:
        text: Input text.

    Returns:
        str: The text with each parsable date replaced by ``format_date``'s
        output.
    """
    # Match common date formats
    date_pattern = r"\b(\d{4}-\d{2}-\d{2}(?:[ T]\d{2}:\d{2}:\d{2})?|\d{2}/\d{2}/\d{4}|\d{1,2} \w+ \d{4})\b"

    def replace_date(match):
        original = match.group(0)
        try:
            moment = parse(original)
            if moment:
                # A "T" or a space is taken as a sign that a time may be present.
                wants_time = any(marker in original for marker in ("T", " "))
                return format_date(moment, wants_time)
        except ValueError:
            return original
        return original

    return re.sub(date_pattern, replace_date, text)
66
+
67
+ # Replace invalid characters and clean text
68
def replace_invalid_chars(string):
    """Repair text with ftfy, then apply a fixed list of substitutions.

    Args:
        string: Input text.

    Returns:
        str: The text after ``fix_text`` and the literal replacements below,
        applied in the listed order.
    """
    cleaned = fix_text(string)
    substitutions = (
        ("**", ""),
        ('&#x27;', "'"),
        ('AI;', 'Artificial Intelligence!'),
        ('iddqd;', 'Immortality cheat code'),
        ('😉;', 'wink wink!'),
        (':D', '*laughs* Ahahaha!'),
        (';D', '*laughs* Ahahaha!'),
    )
    for needle, spoken in substitutions:
        cleaned = cleaned.replace(needle, spoken)
    return cleaned
82
+
83
+ # Replace numbers with their word equivalents
84
def replace_numbers(string):
    """Spell out standalone integers, unless the text looks "technical".

    If the text contains anything matching the IPv4, IPv6, numeric range,
    ISO date or letter/digit-mix patterns below, it is returned unchanged.
    Otherwise every run of digits bounded by word boundaries is converted
    with ``num2words``.

    Args:
        string: Input text.

    Returns:
        str: The text with standalone numbers replaced by words, or the
        original text when any guard pattern matches.
    """
    # Do not process IP addresses, ranges, date patterns, or alphanumerics
    guard_patterns = (
        r'(\b\d{1,3}(\.\d{1,3}){3}\b)',  # IPv4
        r'([0-9a-fA-F]{1,4}:){2,7}[0-9a-fA-F]{1,4}',  # IPv6
        r'\b\d+-\d+\b',  # Ranges like 1-4
        r'\b\d{4}-\d{2}-\d{2}(?:T\d{2}:\d{2}:\d{2})?\b',  # ISO dates
        r'\b[A-Za-z]+\d+|\d+[A-Za-z]+\b',  # Letter/digit mixes
    )
    if any(re.search(guard, string) for guard in guard_patterns):
        return string

    # Convert standalone numbers and port numbers
    def convert_number(match):
        digits = match.group()
        if digits.isdigit():
            return num2words(int(digits))
        return digits

    return re.compile(r'\b\d+\b').sub(convert_number, string)
102
+
103
+ # Replace abbreviations with expanded form
104
def replace_abbreviations(string):
    """Spell out short all-caps words letter by letter.

    A whitespace-separated word is expanded through ``alphabet_map`` when it
    is upper-case, at most four characters long, contains no digit, and is
    not "ID", "AM" or "PM". Characters missing from the map are kept as-is.
    Words are re-joined with single spaces.

    Args:
        string: Input text.

    Returns:
        str: The text with qualifying abbreviations expanded.
    """
    def expand(word):
        if not (word.isupper() and len(word) <= 4):
            return word
        if any(char.isdigit() for char in word):
            return word
        if word in ["ID", "AM", "PM"]:
            return word
        return ''.join(alphabet_map.get(char, char) for char in word)

    return ' '.join(expand(word) for word in string.split())
110
+
111
def clean_whitespace(string):
    """Tidy spacing: no gap before punctuation, no whitespace runs.

    Args:
        string: Input text.

    Returns:
        str: The text with whitespace before ``. , ? !`` removed, runs of two
        or more whitespace characters collapsed to one space, and leading and
        trailing whitespace stripped.
    """
    gap_before_punctuation = re.compile(r'\s+([.,?!])')
    whitespace_run = re.compile(r'\s{2,}')

    tightened = gap_before_punctuation.sub(r'\1', string)
    # Only runs of 2+ are touched, so single separators inside the text stay.
    collapsed = whitespace_run.sub(' ', tightened)
    return collapsed.strip()
117
+
118
def make_dots_tts_friendly(text):
    """Replace dots with the spoken words "dot" or "point".

    Applied in order: dots inside IPv4 addresses and before a listed
    domain/file ending become " dot "; dots inside decimals become
    " point "; any remaining dot directly followed by a word character
    becomes "dot ".

    Args:
        text: Input text.

    Returns:
        str: The text with those dots spelled out.
    """
    spoken_dot_rules = (
        # IP addresses (force "dot")
        (r'\b\d{1,3}(\.\d{1,3}){3}\b', ' dot '),
        # Domain-like endings (force "dot")
        (r'\b([\w-]+)\.(com|net|org|io|gov|edu|exe|dll|local)\b', ' dot '),
        # Decimals (use "point")
        (r'\b\d+\.\d+\b', ' point '),
    )
    for pattern, spoken in spoken_dot_rules:
        text = re.sub(
            pattern,
            lambda found, spoken=spoken: found.group(0).replace('.', spoken),
            text,
        )

    # Leading dot words (.Net -> dot Net)
    return re.sub(r'\.(?=\w)', 'dot ', text)
135
+
136
+ # Main preprocessing pipeline
137
def preprocess_all(string):
    """Run the full text preprocessing pipeline in its fixed order.

    Args:
        string: Raw input text.

    Returns:
        str: The text after date normalization, character cleanup, number
        spelling, abbreviation expansion, dot spelling and whitespace cleanup.
    """
    stages = (
        normalize_dates,
        replace_invalid_chars,
        replace_numbers,
        replace_abbreviations,
        make_dots_tts_friendly,
        clean_whitespace,
    )
    for stage in stages:
        string = stage(string)
    return string
145
+
146
+ # Expose a testing function for external use
147
def test_preprocessing(file_path):
    """Print each line of a file next to its preprocessed form.

    Args:
        file_path: Path of a text file, one sample per line.
    """
    with open(file_path, 'r') as file:
        raw_lines = file.readlines()

    for raw_line in raw_lines:
        original = raw_line.strip()
        processed = preprocess_all(original)
        print(f"Original: {original}")
        print(f"Processed: {processed}\n")
155
+
156
+ if __name__ == "__main__":
157
+ import sys
158
+ if len(sys.argv) > 1:
159
+ test_file = sys.argv[1]
160
+ test_preprocessing(test_file)
161
+ else:
162
+ print("Please provide a file path as an argument.")
163
+