Manmay committed on
Commit
ee8d888
1 Parent(s): b137e4d
Files changed (1) hide show
  1. tortoise/api.py +19 -20
tortoise/api.py CHANGED
@@ -275,9 +275,7 @@ class TextToSpeech:
275
  for vs in voice_samples:
276
  auto_conds.append(format_conditioning(vs, device=self.device))
277
  auto_conds = torch.stack(auto_conds, dim=1)
278
- self.autoregressive = self.autoregressive.to(self.device)
279
  auto_latent = self.autoregressive.get_conditioning(auto_conds)
280
- self.autoregressive = self.autoregressive.cpu()
281
 
282
  diffusion_conds = []
283
  for sample in voice_samples:
@@ -288,9 +286,7 @@ class TextToSpeech:
288
  diffusion_conds.append(cond_mel)
289
  diffusion_conds = torch.stack(diffusion_conds, dim=1)
290
 
291
- self.diffusion = self.diffusion.to(self.device)
292
  diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
293
- self.diffusion = self.diffusion.cpu()
294
 
295
  if return_mels:
296
  return auto_latent, diffusion_latent, auto_conds, diffusion_conds
@@ -405,22 +401,25 @@ class TextToSpeech:
405
  calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
406
  if verbose:
407
  print("Generating autoregressive samples..")
408
- codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
409
- do_sample=True,
410
- top_p=top_p,
411
- temperature=temperature,
412
- num_return_sequences=num_autoregressive_samples,
413
- length_penalty=length_penalty,
414
- repetition_penalty=repetition_penalty,
415
- max_generate_length=max_mel_tokens,
416
- **hf_generate_kwargs)
417
- # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
418
- # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
419
- # results, but will increase memory usage.
420
- best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
421
- torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
422
- torch.tensor([codes.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
423
- return_latent=True, clip_inputs=False)
 
 
 
424
  del auto_conditioning
425
 
426
  if verbose:
 
275
  for vs in voice_samples:
276
  auto_conds.append(format_conditioning(vs, device=self.device))
277
  auto_conds = torch.stack(auto_conds, dim=1)
 
278
  auto_latent = self.autoregressive.get_conditioning(auto_conds)
 
279
 
280
  diffusion_conds = []
281
  for sample in voice_samples:
 
286
  diffusion_conds.append(cond_mel)
287
  diffusion_conds = torch.stack(diffusion_conds, dim=1)
288
 
 
289
  diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
 
290
 
291
  if return_mels:
292
  return auto_latent, diffusion_latent, auto_conds, diffusion_conds
 
401
  calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
402
  if verbose:
403
  print("Generating autoregressive samples..")
404
+ with torch.autocast(
405
+ device_type="cuda" , dtype=torch.float16, enabled=self.half
406
+ ):
407
+ codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
408
+ do_sample=True,
409
+ top_p=top_p,
410
+ temperature=temperature,
411
+ num_return_sequences=num_autoregressive_samples,
412
+ length_penalty=length_penalty,
413
+ repetition_penalty=repetition_penalty,
414
+ max_generate_length=max_mel_tokens,
415
+ **hf_generate_kwargs)
416
+ # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
417
+ # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
418
+ # results, but will increase memory usage.
419
+ best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
420
+ torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
421
+ torch.tensor([codes.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
422
+ return_latent=True, clip_inputs=False)
423
  del auto_conditioning
424
 
425
  if verbose: