Manmay commited on
Commit
fc677b5
·
1 Parent(s): 0054bb4

fix cpu and add hub download

Browse files
Files changed (1) hide show
  1. tortoise/api.py +6 -38
tortoise/api.py CHANGED
@@ -38,45 +38,13 @@ MODELS = {
38
  'hifidecoder.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/hifidecoder.pth',
39
  }
40
 
41
- def download_models(specific_models=None):
42
- """
43
- Call to download all the models that Tortoise uses.
44
- """
45
- os.makedirs(MODELS_DIR, exist_ok=True)
46
-
47
- def show_progress(block_num, block_size, total_size):
48
- global pbar
49
- if pbar is None:
50
- pbar = progressbar.ProgressBar(maxval=total_size)
51
- pbar.start()
52
-
53
- downloaded = block_num * block_size
54
- if downloaded < total_size:
55
- pbar.update(downloaded)
56
- else:
57
- pbar.finish()
58
- pbar = None
59
- for model_name, url in MODELS.items():
60
- if specific_models is not None and model_name not in specific_models:
61
- continue
62
- model_path = os.path.join(MODELS_DIR, model_name)
63
- if os.path.exists(model_path):
64
- continue
65
- print(f'Downloading {model_name} from {url}...')
66
- request.urlretrieve(url, model_path, show_progress)
67
- # hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=MODELS_DIR)
68
- print('Done.')
69
-
70
-
71
  def get_model_path(model_name, models_dir=MODELS_DIR):
72
  """
73
  Get path to given model, download it if it doesn't exist.
74
  """
75
  if model_name not in MODELS:
76
  raise ValueError(f'Model {model_name} not found in available models.')
77
- model_path = os.path.join(models_dir, model_name)
78
- if not os.path.exists(model_path) and models_dir == MODELS_DIR:
79
- download_models([model_name])
80
  return model_path
81
 
82
 
@@ -243,14 +211,14 @@ class TextToSpeech:
243
  self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
244
  model_dim=1024,
245
  heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
246
- train_solo_embeddings=False).cuda().eval()
247
  self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False)
248
  self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half)
249
 
250
  self.hifi_decoder = HifiganGenerator(in_channels=1024, out_channels = 1, resblock_type = "1",
251
  resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes = [3, 7, 11],
252
  upsample_kernel_sizes = [16, 16, 4, 4], upsample_initial_channel = 512, upsample_factors = [8, 8, 2, 2],
253
- cond_channels=1024).cuda().eval()
254
  hifi_model = torch.load(get_model_path('hifidecoder.pth'))
255
  self.hifi_decoder.load_state_dict(hifi_model, strict=False)
256
  # Random latent generators (RLGs) are loaded lazily.
@@ -309,7 +277,7 @@ class TextToSpeech:
309
  settings.update(kwargs) # allow overriding of preset settings with kwargs
310
  for audio_frame in self.tts(text, **settings):
311
  yield audio_frame
312
-
313
  def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
314
  """Handle chunk formatting in streaming mode"""
315
  wav_chunk = wav_gen[:-overlap_len]
@@ -413,7 +381,7 @@ class TextToSpeech:
413
  wav_gen_prev = None
414
  wav_overlap = None
415
  is_end = False
416
- first_buffer = 60
417
  while not is_end:
418
  try:
419
  with torch.autocast(
@@ -428,7 +396,7 @@ class TextToSpeech:
428
  if is_end or (stream_chunk_size > 0 and len(codes_) >= max(stream_chunk_size, first_buffer)):
429
  first_buffer = 0
430
  gpt_latents = torch.cat(all_latents, dim=0)[None, :]
431
- wav_gen = self.hifi_decoder.inference(gpt_latents.cuda(), auto_conditioning)
432
  wav_gen = wav_gen.squeeze()
433
  wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
434
  wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
 
38
  'hifidecoder.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/hifidecoder.pth',
39
  }
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def get_model_path(model_name, models_dir=MODELS_DIR):
42
  """
43
  Get path to given model, download it if it doesn't exist.
44
  """
45
  if model_name not in MODELS:
46
  raise ValueError(f'Model {model_name} not found in available models.')
47
+ model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=MODELS_DIR)
 
 
48
  return model_path
49
 
50
 
 
211
  self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
212
  model_dim=1024,
213
  heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
214
+ train_solo_embeddings=False).to(self.device).eval()
215
  self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False)
216
  self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half)
217
 
218
  self.hifi_decoder = HifiganGenerator(in_channels=1024, out_channels = 1, resblock_type = "1",
219
  resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes = [3, 7, 11],
220
  upsample_kernel_sizes = [16, 16, 4, 4], upsample_initial_channel = 512, upsample_factors = [8, 8, 2, 2],
221
+ cond_channels=1024).to(self.device).eval()
222
  hifi_model = torch.load(get_model_path('hifidecoder.pth'))
223
  self.hifi_decoder.load_state_dict(hifi_model, strict=False)
224
  # Random latent generators (RLGs) are loaded lazily.
 
277
  settings.update(kwargs) # allow overriding of preset settings with kwargs
278
  for audio_frame in self.tts(text, **settings):
279
  yield audio_frame
280
+
281
  def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
282
  """Handle chunk formatting in streaming mode"""
283
  wav_chunk = wav_gen[:-overlap_len]
 
381
  wav_gen_prev = None
382
  wav_overlap = None
383
  is_end = False
384
+ first_buffer = 40
385
  while not is_end:
386
  try:
387
  with torch.autocast(
 
396
  if is_end or (stream_chunk_size > 0 and len(codes_) >= max(stream_chunk_size, first_buffer)):
397
  first_buffer = 0
398
  gpt_latents = torch.cat(all_latents, dim=0)[None, :]
399
+ wav_gen = self.hifi_decoder.inference(gpt_latents.to(self.device), auto_conditioning)
400
  wav_gen = wav_gen.squeeze()
401
  wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
402
  wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len