jbetker committed on
Commit
2888ae0
1 Parent(s): cdf44d7

Fix bug with k>1

Browse files
Files changed (1) hide show
  1. tortoise/api.py +2 -1
tortoise/api.py CHANGED
@@ -416,7 +416,8 @@ class TextToSpeech:
416
  # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
417
  # results, but will increase memory usage.
418
  self.autoregressive = self.autoregressive.cuda()
419
- best_latents = self.autoregressive(auto_conditioning, text_tokens, torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
 
420
  torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
421
  return_latent=True, clip_inputs=False)
422
  self.autoregressive = self.autoregressive.cpu()
 
416
  # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
417
  # results, but will increase memory usage.
418
  self.autoregressive = self.autoregressive.cuda()
419
+ best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
420
+ torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
421
  torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
422
  return_latent=True, clip_inputs=False)
423
  self.autoregressive = self.autoregressive.cpu()