ChrisPreston commited on
Commit
e5442c4
1 Parent(s): 85bc08a

Update infer_tools/infer_tool.py

Browse files
Files changed (1) hide show
  1. infer_tools/infer_tool.py +4 -4
infer_tools/infer_tool.py CHANGED
@@ -96,12 +96,12 @@ class Svc:
96
  @timeit
97
  def diff_infer():
98
  spk_embed = batch.get('spk_embed') if not hparams['use_spk_id'] else batch.get('spk_ids')
99
- energy = batch.get('energy').cuda() if batch.get('energy') else None
100
  if spk_embed is None:
101
- spk_embed = torch.LongTensor([0])
102
  diff_outputs = self.model(
103
- hubert=batch['hubert'].cuda(), spk_embed_id=spk_embed.cuda(), mel2ph=batch['mel2ph'].cuda(),
104
- f0=batch['f0'].cuda(), energy=energy, ref_mels=batch["mels"].cuda(), infer=True)
105
  return diff_outputs
106
 
107
  outputs = diff_infer()
 
96
  @timeit
97
  def diff_infer():
98
  spk_embed = batch.get('spk_embed') if not hparams['use_spk_id'] else batch.get('spk_ids')
99
+ energy = batch.get('energy').cpu() if batch.get('energy') else None
100
  if spk_embed is None:
101
+ spk_embed = torch.LongTensor([0]).cpu()
102
  diff_outputs = self.model(
103
+ hubert=batch['hubert'].cpu(), spk_embed_id=spk_embed.cpu(), mel2ph=batch['mel2ph'].cpu(),
104
+ f0=batch['f0'].cpu(), energy=energy, ref_mels=batch["mels"].cpu(), infer=True)
105
  return diff_outputs
106
 
107
  outputs = diff_infer()