kevinwang676 commited on
Commit
c94a011
1 Parent(s): db212a6

Update usr/diff/shallow_diffusion_tts.py

Browse files
Files changed (1) hide show
  1. usr/diff/shallow_diffusion_tts.py +2 -5
usr/diff/shallow_diffusion_tts.py CHANGED
@@ -15,7 +15,7 @@ from modules.fastspeech.fs2 import FastSpeech2
15
  from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
16
  from utils.hparams import hparams
17
 
18
- import spaces
19
 
20
  def exists(x):
21
  return x is not None
@@ -227,11 +227,9 @@ class GaussianDiffusion(nn.Module):
227
 
228
  return loss
229
 
230
- @spaces.GPU(duration=120)
231
  def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
232
  ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
233
- b, *_, device = *txt_tokens.shape, "cuda:0" #txt_tokens.device
234
- print(f"在{device}上运行")
235
  ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
236
  skip_decoder=(not infer), infer=infer, **kwargs)
237
  cond = ret['decoder_inp'].transpose(1, 2)
@@ -258,7 +256,6 @@ class GaussianDiffusion(nn.Module):
258
  x = torch.randn(shape, device=device)
259
 
260
  if hparams.get('pndm_speedup'):
261
- print("pndm_speedup 加速中...")
262
  self.noise_list = deque(maxlen=4)
263
  iteration_interval = hparams['pndm_speedup']
264
  for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step',
 
15
  from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
16
  from utils.hparams import hparams
17
 
18
+
19
 
20
  def exists(x):
21
  return x is not None
 
227
 
228
  return loss
229
 
 
230
  def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
231
  ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
232
+ b, *_, device = *txt_tokens.shape, txt_tokens.device
 
233
  ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
234
  skip_decoder=(not infer), infer=infer, **kwargs)
235
  cond = ret['decoder_inp'].transpose(1, 2)
 
256
  x = torch.randn(shape, device=device)
257
 
258
  if hparams.get('pndm_speedup'):
 
259
  self.noise_list = deque(maxlen=4)
260
  iteration_interval = hparams['pndm_speedup']
261
  for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step',