Spaces:
Paused
Paused
Create coqui.py
Browse files- TextGen/coqui.py +399 -0
TextGen/coqui.py
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import io, os, stat
|
3 |
+
import subprocess
|
4 |
+
import random
|
5 |
+
from zipfile import ZipFile
|
6 |
+
import uuid
|
7 |
+
import time
|
8 |
+
import torch
|
9 |
+
import torchaudio
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
os.system("pip install gradio-client @ git+https://github.com/gradio-app/gradio@bed454c3d22cfacedc047eb3b0ba987b485ac3fd")
|
13 |
+
os.system("pip install git+https://github.com/gradio-app/gradio.git@5.0-dev")
|
14 |
+
#update gradio to faster streaming
|
15 |
+
#download for mecab
|
16 |
+
os.system('python -m unidic download')
|
17 |
+
|
18 |
+
# By using XTTS you agree to CPML license https://coqui.ai/cpml
|
19 |
+
os.environ["COQUI_TOS_AGREED"] = "1"
|
20 |
+
|
21 |
+
# langid is used to detect language for longer text
|
22 |
+
# Most users expect text to be their own language, there is checkbox to disable it
|
23 |
+
import langid
|
24 |
+
import base64
|
25 |
+
import csv
|
26 |
+
from io import StringIO
|
27 |
+
import datetime
|
28 |
+
import re
|
29 |
+
|
30 |
+
import gradio as gr
|
31 |
+
from scipy.io.wavfile import write
|
32 |
+
from pydub import AudioSegment
|
33 |
+
|
34 |
+
from TTS.api import TTS
|
35 |
+
from TTS.tts.configs.xtts_config import XttsConfig
|
36 |
+
from TTS.tts.models.xtts import Xtts
|
37 |
+
from TTS.utils.generic_utils import get_user_data_dir
|
38 |
+
|
39 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
40 |
+
|
41 |
+
from huggingface_hub import HfApi
|
42 |
+
os.system("pip install git+https://github.com/gradio-app/gradio.git@5.0-dev")
|
43 |
+
# will use api to restart space on a unrecoverable error
|
44 |
+
api = HfApi(token=HF_TOKEN)
|
45 |
+
repo_id = "coqui/xtts"
|
46 |
+
|
47 |
+
# Use never ffmpeg binary for Ubuntu20 to use denoising for microphone input
|
48 |
+
print("Export newer ffmpeg binary for denoise filter")
|
49 |
+
ZipFile("ffmpeg.zip").extractall()
|
50 |
+
print("Make ffmpeg binary executable")
|
51 |
+
st = os.stat("ffmpeg")
|
52 |
+
os.chmod("ffmpeg", st.st_mode | stat.S_IEXEC)
|
53 |
+
|
54 |
+
# This will trigger downloading model
|
55 |
+
print("Downloading if not downloaded Coqui XTTS V2")
|
56 |
+
from TTS.utils.manage import ModelManager
|
57 |
+
|
58 |
+
model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
|
59 |
+
ModelManager().download_model(model_name)
|
60 |
+
model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
|
61 |
+
print("XTTS downloaded")
|
62 |
+
|
63 |
+
config = XttsConfig()
|
64 |
+
config.load_json(os.path.join(model_path, "config.json"))
|
65 |
+
|
66 |
+
model = Xtts.init_from_config(config)
|
67 |
+
model.load_checkpoint(
|
68 |
+
config,
|
69 |
+
checkpoint_path=os.path.join(model_path, "model.pth"),
|
70 |
+
vocab_path=os.path.join(model_path, "vocab.json"),
|
71 |
+
eval=True,
|
72 |
+
use_deepspeed=True,
|
73 |
+
)
|
74 |
+
model.cuda()
|
75 |
+
|
76 |
+
# This is for debugging purposes only
|
77 |
+
DEVICE_ASSERT_DETECTED = 0
|
78 |
+
DEVICE_ASSERT_PROMPT = None
|
79 |
+
DEVICE_ASSERT_LANG = None
|
80 |
+
|
81 |
+
supported_languages = config.languages
|
82 |
+
def numpy_to_mp3(audio_array, sampling_rate):
|
83 |
+
# Normalize audio_array if it's floating-point
|
84 |
+
if np.issubdtype(audio_array.dtype, np.floating):
|
85 |
+
max_val = np.max(np.abs(audio_array))
|
86 |
+
audio_array = (audio_array / max_val) * 32767 # Normalize to 16-bit range
|
87 |
+
audio_array = audio_array.astype(np.int16)
|
88 |
+
|
89 |
+
# Create an audio segment from the numpy array
|
90 |
+
audio_segment = AudioSegment(
|
91 |
+
audio_array.tobytes(),
|
92 |
+
frame_rate=sampling_rate,
|
93 |
+
sample_width=audio_array.dtype.itemsize,
|
94 |
+
channels=1
|
95 |
+
)
|
96 |
+
|
97 |
+
# Export the audio segment to MP3 bytes - use a high bitrate to maximise quality
|
98 |
+
mp3_io = io.BytesIO()
|
99 |
+
audio_segment.export(mp3_io, format="mp3", bitrate="320k")
|
100 |
+
|
101 |
+
# Get the MP3 bytes
|
102 |
+
mp3_bytes = mp3_io.getvalue()
|
103 |
+
mp3_io.close()
|
104 |
+
|
105 |
+
return mp3_bytes
|
106 |
+
|
107 |
+
def predict(
|
108 |
+
prompt,
|
109 |
+
language,
|
110 |
+
audio_file_pth,
|
111 |
+
mic_file_path,
|
112 |
+
use_mic,
|
113 |
+
voice_cleanup,
|
114 |
+
no_lang_auto_detect,
|
115 |
+
agree,
|
116 |
+
):
|
117 |
+
if agree == True:
|
118 |
+
if language not in supported_languages:
|
119 |
+
gr.Warning(
|
120 |
+
f"Language you put {language} in is not in is not in our Supported Languages, please choose from dropdown"
|
121 |
+
)
|
122 |
+
|
123 |
+
return (
|
124 |
+
None,
|
125 |
+
)
|
126 |
+
|
127 |
+
language_predicted = langid.classify(prompt)[
|
128 |
+
0
|
129 |
+
].strip() # strip need as there is space at end!
|
130 |
+
|
131 |
+
# tts expects chinese as zh-cn
|
132 |
+
if language_predicted == "zh":
|
133 |
+
# we use zh-cn
|
134 |
+
language_predicted = "zh-cn"
|
135 |
+
|
136 |
+
print(f"Detected language:{language_predicted}, Chosen language:{language}")
|
137 |
+
|
138 |
+
# After text character length 15 trigger language detection
|
139 |
+
if len(prompt) > 15:
|
140 |
+
# allow any language for short text as some may be common
|
141 |
+
# If user unchecks language autodetection it will not trigger
|
142 |
+
# You may remove this completely for own use
|
143 |
+
if language_predicted != language and not no_lang_auto_detect:
|
144 |
+
# Please duplicate and remove this check if you really want this
|
145 |
+
# Or auto-detector fails to identify language (which it can on pretty short text or mixed text)
|
146 |
+
gr.Warning(
|
147 |
+
f"It looks like your text isn’t the language you chose , if you’re sure the text is the same language you chose, please check disable language auto-detection checkbox"
|
148 |
+
)
|
149 |
+
|
150 |
+
return (
|
151 |
+
None,
|
152 |
+
)
|
153 |
+
|
154 |
+
if use_mic == True:
|
155 |
+
if mic_file_path is not None:
|
156 |
+
speaker_wav = mic_file_path
|
157 |
+
else:
|
158 |
+
gr.Warning(
|
159 |
+
"Please record your voice with Microphone, or uncheck Use Microphone to use reference audios"
|
160 |
+
)
|
161 |
+
return (
|
162 |
+
None,
|
163 |
+
)
|
164 |
+
|
165 |
+
else:
|
166 |
+
speaker_wav = audio_file_pth
|
167 |
+
|
168 |
+
# Filtering for microphone input, as it has BG noise, maybe silence in beginning and end
|
169 |
+
# This is fast filtering not perfect
|
170 |
+
|
171 |
+
# Apply all on demand
|
172 |
+
lowpassfilter = denoise = trim = loudness = True
|
173 |
+
|
174 |
+
if lowpassfilter:
|
175 |
+
lowpass_highpass = "lowpass=8000,highpass=75,"
|
176 |
+
else:
|
177 |
+
lowpass_highpass = ""
|
178 |
+
|
179 |
+
if trim:
|
180 |
+
# better to remove silence in beginning and end for microphone
|
181 |
+
trim_silence = "areverse,silenceremove=start_periods=1:start_silence=0:start_threshold=0.02,areverse,silenceremove=start_periods=1:start_silence=0:start_threshold=0.02,"
|
182 |
+
else:
|
183 |
+
trim_silence = ""
|
184 |
+
|
185 |
+
if voice_cleanup:
|
186 |
+
try:
|
187 |
+
out_filename = (
|
188 |
+
speaker_wav + str(uuid.uuid4()) + ".wav"
|
189 |
+
) # ffmpeg to know output format
|
190 |
+
|
191 |
+
# we will use newer ffmpeg as that has afftn denoise filter
|
192 |
+
shell_command = f"./ffmpeg -y -i {speaker_wav} -af {lowpass_highpass}{trim_silence} {out_filename}".split(
|
193 |
+
" "
|
194 |
+
)
|
195 |
+
|
196 |
+
command_result = subprocess.run(
|
197 |
+
[item for item in shell_command],
|
198 |
+
capture_output=False,
|
199 |
+
text=True,
|
200 |
+
check=True,
|
201 |
+
)
|
202 |
+
speaker_wav = out_filename
|
203 |
+
print("Filtered microphone input")
|
204 |
+
except subprocess.CalledProcessError:
|
205 |
+
# There was an error - command exited with non-zero code
|
206 |
+
print("Error: failed filtering, use original microphone input")
|
207 |
+
else:
|
208 |
+
speaker_wav = speaker_wav
|
209 |
+
|
210 |
+
if len(prompt) < 2:
|
211 |
+
gr.Warning("Please give a longer prompt text")
|
212 |
+
return (
|
213 |
+
None,
|
214 |
+
)
|
215 |
+
if len(prompt) > 1000:
|
216 |
+
gr.Warning(
|
217 |
+
"Text length limited to 200 characters for this demo, please try shorter text. You can clone this space and edit code for your own usage"
|
218 |
+
)
|
219 |
+
return (
|
220 |
+
None,
|
221 |
+
)
|
222 |
+
global DEVICE_ASSERT_DETECTED
|
223 |
+
if DEVICE_ASSERT_DETECTED:
|
224 |
+
global DEVICE_ASSERT_PROMPT
|
225 |
+
global DEVICE_ASSERT_LANG
|
226 |
+
# It will likely never come here as we restart space on first unrecoverable error now
|
227 |
+
print(
|
228 |
+
f"Unrecoverable exception caused by language:{DEVICE_ASSERT_LANG} prompt:{DEVICE_ASSERT_PROMPT}"
|
229 |
+
)
|
230 |
+
|
231 |
+
# HF Space specific.. This error is unrecoverable need to restart space
|
232 |
+
space = api.get_space_runtime(repo_id=repo_id)
|
233 |
+
if space.stage != "BUILDING":
|
234 |
+
api.restart_space(repo_id=repo_id)
|
235 |
+
else:
|
236 |
+
print("TRIED TO RESTART but space is building")
|
237 |
+
|
238 |
+
try:
|
239 |
+
metrics_text = ""
|
240 |
+
t_latent = time.time()
|
241 |
+
|
242 |
+
# note diffusion_conditioning not used on hifigan (default mode), it will be empty but need to pass it to model.inference
|
243 |
+
try:
|
244 |
+
(
|
245 |
+
gpt_cond_latent,
|
246 |
+
speaker_embedding,
|
247 |
+
) = model.get_conditioning_latents(audio_path=speaker_wav, gpt_cond_len=30, gpt_cond_chunk_len=4, max_ref_length=60)
|
248 |
+
except Exception as e:
|
249 |
+
print("Speaker encoding error", str(e))
|
250 |
+
gr.Warning(
|
251 |
+
"It appears something wrong with reference, did you unmute your microphone?"
|
252 |
+
)
|
253 |
+
return (
|
254 |
+
None,
|
255 |
+
)
|
256 |
+
|
257 |
+
latent_calculation_time = time.time() - t_latent
|
258 |
+
# metrics_text=f"Embedding calculation time: {latent_calculation_time:.2f} seconds\n"
|
259 |
+
|
260 |
+
# temporary comma fix
|
261 |
+
prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
|
262 |
+
|
263 |
+
wav_chunks = []
|
264 |
+
## Direct mode
|
265 |
+
"""
|
266 |
+
print("I: Generating new audio...")
|
267 |
+
t0 = time.time()
|
268 |
+
out = model.inference(
|
269 |
+
prompt,
|
270 |
+
language,
|
271 |
+
gpt_cond_latent,
|
272 |
+
speaker_embedding,
|
273 |
+
repetition_penalty=5.0,
|
274 |
+
temperature=0.75,
|
275 |
+
)
|
276 |
+
inference_time = time.time() - t0
|
277 |
+
print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
|
278 |
+
metrics_text+=f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
|
279 |
+
real_time_factor= (time.time() - t0) / out['wav'].shape[-1] * 24000
|
280 |
+
print(f"Real-time factor (RTF): {real_time_factor}")
|
281 |
+
metrics_text+=f"Real-time factor (RTF): {real_time_factor:.2f}\n"
|
282 |
+
torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
|
283 |
+
"""
|
284 |
+
print("I: Generating new audio in streaming mode...")
|
285 |
+
t0 = time.time()
|
286 |
+
chunks = model.inference_stream(
|
287 |
+
prompt,
|
288 |
+
language,
|
289 |
+
gpt_cond_latent,
|
290 |
+
speaker_embedding,
|
291 |
+
repetition_penalty=7.0,
|
292 |
+
temperature=0.85,
|
293 |
+
)
|
294 |
+
|
295 |
+
first_chunk = True
|
296 |
+
for i, chunk in enumerate(chunks):
|
297 |
+
if first_chunk:
|
298 |
+
first_chunk_time = time.time() - t0
|
299 |
+
metrics_text += f"Latency to first audio chunk: {round(first_chunk_time*1000)} milliseconds\n"
|
300 |
+
first_chunk = False
|
301 |
+
|
302 |
+
# Convert chunk to numpy array and return it
|
303 |
+
chunk_np = chunk.cpu().numpy()
|
304 |
+
print('chunk',i)
|
305 |
+
yield (24000, chunk_np)
|
306 |
+
wav_chunks.append(chunk)
|
307 |
+
|
308 |
+
print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
|
309 |
+
inference_time = time.time() - t0
|
310 |
+
print(
|
311 |
+
f"I: Time to generate audio: {round(inference_time*1000)} milliseconds"
|
312 |
+
)
|
313 |
+
# metrics_text += (
|
314 |
+
# f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
|
315 |
+
#)
|
316 |
+
|
317 |
+
except RuntimeError as e:
|
318 |
+
if "device-side assert" in str(e):
|
319 |
+
# cannot do anything on cuda device side error, need tor estart
|
320 |
+
print(
|
321 |
+
f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
|
322 |
+
flush=True,
|
323 |
+
)
|
324 |
+
gr.Warning("Unhandled Exception encounter, please retry in a minute")
|
325 |
+
print("Cuda device-assert Runtime encountered need restart")
|
326 |
+
if not DEVICE_ASSERT_DETECTED:
|
327 |
+
DEVICE_ASSERT_DETECTED = 1
|
328 |
+
DEVICE_ASSERT_PROMPT = prompt
|
329 |
+
DEVICE_ASSERT_LANG = language
|
330 |
+
|
331 |
+
# just before restarting save what caused the issue so we can handle it in future
|
332 |
+
# Uploading Error data only happens for unrecovarable error
|
333 |
+
error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
|
334 |
+
error_data = [
|
335 |
+
error_time,
|
336 |
+
prompt,
|
337 |
+
language,
|
338 |
+
audio_file_pth,
|
339 |
+
mic_file_path,
|
340 |
+
use_mic,
|
341 |
+
voice_cleanup,
|
342 |
+
no_lang_auto_detect,
|
343 |
+
agree,
|
344 |
+
]
|
345 |
+
error_data = [str(e) if type(e) != str else e for e in error_data]
|
346 |
+
print(error_data)
|
347 |
+
print(speaker_wav)
|
348 |
+
write_io = StringIO()
|
349 |
+
csv.writer(write_io).writerows([error_data])
|
350 |
+
csv_upload = write_io.getvalue().encode()
|
351 |
+
|
352 |
+
filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
|
353 |
+
print("Writing error csv")
|
354 |
+
error_api = HfApi()
|
355 |
+
error_api.upload_file(
|
356 |
+
path_or_fileobj=csv_upload,
|
357 |
+
path_in_repo=filename,
|
358 |
+
repo_id="coqui/xtts-flagged-dataset",
|
359 |
+
repo_type="dataset",
|
360 |
+
)
|
361 |
+
|
362 |
+
# speaker_wav
|
363 |
+
print("Writing error reference audio")
|
364 |
+
speaker_filename = (
|
365 |
+
error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
|
366 |
+
)
|
367 |
+
error_api = HfApi()
|
368 |
+
error_api.upload_file(
|
369 |
+
path_or_fileobj=speaker_wav,
|
370 |
+
path_in_repo=speaker_filename,
|
371 |
+
repo_id="coqui/xtts-flagged-dataset",
|
372 |
+
repo_type="dataset",
|
373 |
+
)
|
374 |
+
|
375 |
+
# HF Space specific.. This error is unrecoverable need to restart space
|
376 |
+
space = api.get_space_runtime(repo_id=repo_id)
|
377 |
+
if space.stage != "BUILDING":
|
378 |
+
api.restart_space(repo_id=repo_id)
|
379 |
+
else:
|
380 |
+
print("TRIED TO RESTART but space is building")
|
381 |
+
|
382 |
+
else:
|
383 |
+
if "Failed to decode" in str(e):
|
384 |
+
print("Speaker encoding error", str(e))
|
385 |
+
gr.Warning(
|
386 |
+
"It appears something wrong with reference, did you unmute your microphone?"
|
387 |
+
)
|
388 |
+
else:
|
389 |
+
print("RuntimeError: non device-side assert error:", str(e))
|
390 |
+
gr.Warning("Something unexpected happened please retry again.")
|
391 |
+
return (
|
392 |
+
None,
|
393 |
+
)
|
394 |
+
|
395 |
+
else:
|
396 |
+
gr.Warning("Please accept the Terms & Condition!")
|
397 |
+
return (
|
398 |
+
None,
|
399 |
+
)
|