Spaces:
Running
Running
kevinwang676
commited on
Commit
•
c94a011
1
Parent(s):
db212a6
Update usr/diff/shallow_diffusion_tts.py
Browse files
usr/diff/shallow_diffusion_tts.py
CHANGED
@@ -15,7 +15,7 @@ from modules.fastspeech.fs2 import FastSpeech2
|
|
15 |
from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
|
16 |
from utils.hparams import hparams
|
17 |
|
18 |
-
|
19 |
|
20 |
def exists(x):
|
21 |
return x is not None
|
@@ -227,11 +227,9 @@ class GaussianDiffusion(nn.Module):
|
|
227 |
|
228 |
return loss
|
229 |
|
230 |
-
@spaces.GPU(duration=120)
|
231 |
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
232 |
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
233 |
-
b, *_, device = *txt_tokens.shape,
|
234 |
-
print(f"在{device}上运行")
|
235 |
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
236 |
skip_decoder=(not infer), infer=infer, **kwargs)
|
237 |
cond = ret['decoder_inp'].transpose(1, 2)
|
@@ -258,7 +256,6 @@ class GaussianDiffusion(nn.Module):
|
|
258 |
x = torch.randn(shape, device=device)
|
259 |
|
260 |
if hparams.get('pndm_speedup'):
|
261 |
-
print("pndm_speedup 加速中...")
|
262 |
self.noise_list = deque(maxlen=4)
|
263 |
iteration_interval = hparams['pndm_speedup']
|
264 |
for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step',
|
|
|
15 |
from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
|
16 |
from utils.hparams import hparams
|
17 |
|
18 |
+
|
19 |
|
20 |
def exists(x):
|
21 |
return x is not None
|
|
|
227 |
|
228 |
return loss
|
229 |
|
|
|
230 |
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
231 |
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
232 |
+
b, *_, device = *txt_tokens.shape, txt_tokens.device
|
|
|
233 |
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
234 |
skip_decoder=(not infer), infer=infer, **kwargs)
|
235 |
cond = ret['decoder_inp'].transpose(1, 2)
|
|
|
256 |
x = torch.randn(shape, device=device)
|
257 |
|
258 |
if hparams.get('pndm_speedup'):
|
|
|
259 |
self.noise_list = deque(maxlen=4)
|
260 |
iteration_interval = hparams['pndm_speedup']
|
261 |
for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step',
|