ikechan8370 commited on
Commit
23aa815
1 Parent(s): b772f7c

fix: add support for gpu

Browse files
Files changed (1) hide show
  1. models.py +3 -2
models.py CHANGED
@@ -496,9 +496,10 @@ class SynthesizerTrn(nn.Module):
496
  return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
497
 
498
  def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
499
- x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
 
500
  if self.n_speakers > 0:
501
- g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
502
  else:
503
  g = None
504
 
 
496
  return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
497
 
498
  def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
499
+ device = next(self.parameters()).device # 获取模型所在的设备
500
+ x, m_p, logs_p, x_mask = self.enc_p(x.to(device), x_lengths.to(device))
501
  if self.n_speakers > 0:
502
+ g = self.emb_g(sid.to(device)).unsqueeze(-1) # [b, h, 1]
503
  else:
504
  g = None
505