Spaces:
Running
on
T4
Running
on
T4
fix cpu and add hub download
Browse files- tortoise/api.py +6 -38
tortoise/api.py
CHANGED
@@ -38,45 +38,13 @@ MODELS = {
|
|
38 |
'hifidecoder.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/hifidecoder.pth',
|
39 |
}
|
40 |
|
41 |
-
def download_models(specific_models=None):
|
42 |
-
"""
|
43 |
-
Call to download all the models that Tortoise uses.
|
44 |
-
"""
|
45 |
-
os.makedirs(MODELS_DIR, exist_ok=True)
|
46 |
-
|
47 |
-
def show_progress(block_num, block_size, total_size):
|
48 |
-
global pbar
|
49 |
-
if pbar is None:
|
50 |
-
pbar = progressbar.ProgressBar(maxval=total_size)
|
51 |
-
pbar.start()
|
52 |
-
|
53 |
-
downloaded = block_num * block_size
|
54 |
-
if downloaded < total_size:
|
55 |
-
pbar.update(downloaded)
|
56 |
-
else:
|
57 |
-
pbar.finish()
|
58 |
-
pbar = None
|
59 |
-
for model_name, url in MODELS.items():
|
60 |
-
if specific_models is not None and model_name not in specific_models:
|
61 |
-
continue
|
62 |
-
model_path = os.path.join(MODELS_DIR, model_name)
|
63 |
-
if os.path.exists(model_path):
|
64 |
-
continue
|
65 |
-
print(f'Downloading {model_name} from {url}...')
|
66 |
-
request.urlretrieve(url, model_path, show_progress)
|
67 |
-
# hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=MODELS_DIR)
|
68 |
-
print('Done.')
|
69 |
-
|
70 |
-
|
71 |
def get_model_path(model_name, models_dir=MODELS_DIR):
|
72 |
"""
|
73 |
Get path to given model, download it if it doesn't exist.
|
74 |
"""
|
75 |
if model_name not in MODELS:
|
76 |
raise ValueError(f'Model {model_name} not found in available models.')
|
77 |
-
model_path =
|
78 |
-
if not os.path.exists(model_path) and models_dir == MODELS_DIR:
|
79 |
-
download_models([model_name])
|
80 |
return model_path
|
81 |
|
82 |
|
@@ -243,14 +211,14 @@ class TextToSpeech:
|
|
243 |
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
244 |
model_dim=1024,
|
245 |
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
246 |
-
train_solo_embeddings=False).
|
247 |
self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False)
|
248 |
self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half)
|
249 |
|
250 |
self.hifi_decoder = HifiganGenerator(in_channels=1024, out_channels = 1, resblock_type = "1",
|
251 |
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes = [3, 7, 11],
|
252 |
upsample_kernel_sizes = [16, 16, 4, 4], upsample_initial_channel = 512, upsample_factors = [8, 8, 2, 2],
|
253 |
-
cond_channels=1024).
|
254 |
hifi_model = torch.load(get_model_path('hifidecoder.pth'))
|
255 |
self.hifi_decoder.load_state_dict(hifi_model, strict=False)
|
256 |
# Random latent generators (RLGs) are loaded lazily.
|
@@ -309,7 +277,7 @@ class TextToSpeech:
|
|
309 |
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
310 |
for audio_frame in self.tts(text, **settings):
|
311 |
yield audio_frame
|
312 |
-
|
313 |
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
|
314 |
"""Handle chunk formatting in streaming mode"""
|
315 |
wav_chunk = wav_gen[:-overlap_len]
|
@@ -413,7 +381,7 @@ class TextToSpeech:
|
|
413 |
wav_gen_prev = None
|
414 |
wav_overlap = None
|
415 |
is_end = False
|
416 |
-
first_buffer =
|
417 |
while not is_end:
|
418 |
try:
|
419 |
with torch.autocast(
|
@@ -428,7 +396,7 @@ class TextToSpeech:
|
|
428 |
if is_end or (stream_chunk_size > 0 and len(codes_) >= max(stream_chunk_size, first_buffer)):
|
429 |
first_buffer = 0
|
430 |
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
|
431 |
-
wav_gen = self.hifi_decoder.inference(gpt_latents.
|
432 |
wav_gen = wav_gen.squeeze()
|
433 |
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
|
434 |
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
|
|
|
38 |
'hifidecoder.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/hifidecoder.pth',
|
39 |
}
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
def get_model_path(model_name, models_dir=MODELS_DIR):
    """
    Get path to given model, download it if it doesn't exist.

    :param model_name: Checkpoint file name; must be one of the keys of MODELS.
    :param models_dir: Directory passed to hf_hub_download as its cache_dir.
                       Defaults to MODELS_DIR.
    :return: The path returned by hf_hub_download for the requested file.
    :raises ValueError: If model_name is not a key of MODELS.
    """
    if model_name not in MODELS:
        raise ValueError(f'Model {model_name} not found in available models.')
    # Pass the caller-supplied directory through. cache_dir used to be pinned to
    # MODELS_DIR, which made the models_dir argument a silent no-op.
    model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir)
    return model_path
|
49 |
|
50 |
|
|
|
211 |
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
212 |
model_dim=1024,
|
213 |
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
214 |
+
train_solo_embeddings=False).to(self.device).eval()
|
215 |
self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False)
|
216 |
self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half)
|
217 |
|
218 |
self.hifi_decoder = HifiganGenerator(in_channels=1024, out_channels = 1, resblock_type = "1",
|
219 |
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes = [3, 7, 11],
|
220 |
upsample_kernel_sizes = [16, 16, 4, 4], upsample_initial_channel = 512, upsample_factors = [8, 8, 2, 2],
|
221 |
+
cond_channels=1024).to(self.device).eval()
|
222 |
hifi_model = torch.load(get_model_path('hifidecoder.pth'))
|
223 |
self.hifi_decoder.load_state_dict(hifi_model, strict=False)
|
224 |
# Random latent generators (RLGs) are loaded lazily.
|
|
|
277 |
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
278 |
for audio_frame in self.tts(text, **settings):
|
279 |
yield audio_frame
|
280 |
+
|
281 |
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
|
282 |
"""Handle chunk formatting in streaming mode"""
|
283 |
wav_chunk = wav_gen[:-overlap_len]
|
|
|
381 |
wav_gen_prev = None
|
382 |
wav_overlap = None
|
383 |
is_end = False
|
384 |
+
first_buffer = 40
|
385 |
while not is_end:
|
386 |
try:
|
387 |
with torch.autocast(
|
|
|
396 |
if is_end or (stream_chunk_size > 0 and len(codes_) >= max(stream_chunk_size, first_buffer)):
|
397 |
first_buffer = 0
|
398 |
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
|
399 |
+
wav_gen = self.hifi_decoder.inference(gpt_latents.to(self.device), auto_conditioning)
|
400 |
wav_gen = wav_gen.squeeze()
|
401 |
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
|
402 |
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
|