kevinwang676 commited on
Commit
5dfce5c
1 Parent(s): 901f2c9

Update usr/diff/shallow_diffusion_tts.py

Browse files
Files changed (1) hide show
  1. usr/diff/shallow_diffusion_tts.py +2 -2
usr/diff/shallow_diffusion_tts.py CHANGED
@@ -230,8 +230,8 @@ class GaussianDiffusion(nn.Module):
230
  @spaces.GPU(duration=120)
231
  def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
232
  ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
233
- b, *_, device = *txt_tokens.shape, txt_tokens.device
234
- print(txt_tokens.device)
235
  ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
236
  skip_decoder=(not infer), infer=infer, **kwargs)
237
  cond = ret['decoder_inp'].transpose(1, 2)
 
230
  @spaces.GPU(duration=120)
231
  def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
232
  ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
233
+ b, *_, device = *txt_tokens.shape, "cuda:0" #txt_tokens.device
234
+ print(f"在{device}上运行")
235
  ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
236
  skip_decoder=(not infer), infer=infer, **kwargs)
237
  cond = ret['decoder_inp'].transpose(1, 2)