OlaWod commited on
Commit
94a6ab2
·
verified ·
1 Parent(s): fafb74b
Files changed (3) hide show
  1. Utils/JDC/model.py +1 -1
  2. app.py +2 -3
  3. models.py +11 -76
Utils/JDC/model.py CHANGED
@@ -134,7 +134,7 @@ class JDCNet(nn.Module):
134
  # sizes: (b, 31, 722), (b, 31, 2)
135
  # classifier output consists of predicted pitch classes per frame
136
  # detector output consists of: (isvoice, notvoice) estimates per frame
137
- return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
138
 
139
  @staticmethod
140
  def init_weights(m):
 
134
  # sizes: (b, 31, 722), (b, 31, 2)
135
  # classifier output consists of predicted pitch classes per frame
136
  # detector output consists of: (isvoice, notvoice) estimates per frame
137
+ return torch.abs(classifier_out.squeeze(-1)), GAN_feature, poolblock_out
138
 
139
  @staticmethod
140
  def init_weights(m):
app.py CHANGED
@@ -13,7 +13,6 @@ from transformers import WavLMModel
13
  from env import AttrDict
14
  from meldataset import mel_spectrogram, MAX_WAV_VALUE
15
  from models import Generator
16
- from stft import TorchSTFT
17
  from Utils.JDC.model import JDCNet
18
 
19
 
@@ -38,7 +37,6 @@ h = AttrDict(json_config)
38
  # load models
39
  F0_model = JDCNet(num_class=1, seq_len=192)
40
  generator = Generator(h, F0_model).to(device)
41
- stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)
42
 
43
  state_dict_g = torch.load(ptfile, map_location=device)
44
  generator.load_state_dict(state_dict_g['generator'], strict=True)
@@ -84,6 +82,7 @@ def convert(tgt_spk, src_wav, f0_shift=0):
84
  spk_emb = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
85
 
86
  f0_mean_tgt = f0_stats[tgt_spk]["mean"]
 
87
 
88
  # src
89
  wav, sr = librosa.load(src_wav, sr=16000)
@@ -98,7 +97,7 @@ def convert(tgt_spk, src_wav, f0_shift=0):
98
  f0 = generator.get_f0(mel, f0_mean_tgt)
99
  f0 = tune_f0(f0, f0_shift)
100
  x = generator.get_x(x, spk_emb, spk_id)
101
- y = generator.infer(x, f0, stft)
102
 
103
  audio = y.squeeze()
104
  audio = audio / torch.max(torch.abs(audio)) * 0.95
 
13
  from env import AttrDict
14
  from meldataset import mel_spectrogram, MAX_WAV_VALUE
15
  from models import Generator
 
16
  from Utils.JDC.model import JDCNet
17
 
18
 
 
37
  # load models
38
  F0_model = JDCNet(num_class=1, seq_len=192)
39
  generator = Generator(h, F0_model).to(device)
 
40
 
41
  state_dict_g = torch.load(ptfile, map_location=device)
42
  generator.load_state_dict(state_dict_g['generator'], strict=True)
 
82
  spk_emb = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
83
 
84
  f0_mean_tgt = f0_stats[tgt_spk]["mean"]
85
+ f0_mean_tgt = torch.FloatTensor([f0_mean_tgt]).unsqueeze(0).to(device)
86
 
87
  # src
88
  wav, sr = librosa.load(src_wav, sr=16000)
 
97
  f0 = generator.get_f0(mel, f0_mean_tgt)
98
  f0 = tune_f0(f0, f0_shift)
99
  x = generator.get_x(x, spk_emb, spk_id)
100
+ y = generator.infer(x, f0)
101
 
102
  audio = y.squeeze()
103
  audio = audio / torch.max(torch.abs(audio)) * 0.95
models.py CHANGED
@@ -1,4 +1,3 @@
1
- import math
2
  import torch
3
  import torch.nn.functional as F
4
  import torch.nn as nn
@@ -486,9 +485,6 @@ class Generator(torch.nn.Module):
486
  g = g + spk_emb.unsqueeze(-1)
487
 
488
  f0, _, _ = self.F0_model(mel.unsqueeze(1))
489
- if len(f0.shape) == 1:
490
- f0 = f0.unsqueeze(0)
491
-
492
  f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
493
 
494
  har_source, _, _ = self.m_source(f0)
@@ -526,28 +522,21 @@ class Generator(torch.nn.Module):
526
 
527
  return spec, phase
528
 
529
- def get_f0(self, mel, f0_mean_tgt, voiced_threshold=10, interp=True):
530
  f0, _, _ = self.F0_model(mel.unsqueeze(1))
531
-
532
  voiced = f0 > voiced_threshold
533
 
534
  lf0 = torch.log(f0)
535
- lf0_mean = lf0[voiced].mean()
536
- lf0_adj = lf0 - lf0_mean + math.log(f0_mean_tgt)
 
 
537
  f0_adj = torch.exp(lf0_adj)
538
 
539
- f0_adj = torch.where(voiced, f0_adj, 0)
540
-
541
- # interpolate unsilent unvoiced f0 frames
542
- if interp:
543
- f0_adj = self.interp_f0(f0_adj.unsqueeze(0), voiced.unsqueeze(0)).squeeze(0)
544
- energy = torch.sum(mel.squeeze(0), dim=0) # simple vad
545
- unsilent = energy > -700
546
- unsilent = unsilent | voiced
547
- f0_adj = torch.where(unsilent, f0_adj, 0)
548
-
549
- if len(f0_adj.shape) == 1:
550
- f0_adj = f0_adj.unsqueeze(0)
551
 
552
  return f0_adj
553
 
@@ -562,7 +551,7 @@ class Generator(torch.nn.Module):
562
 
563
  return x
564
 
565
- def infer(self, x, f0, stft):
566
  f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
567
 
568
  har_source, _, _ = self.m_source(f0)
@@ -593,62 +582,8 @@ class Generator(torch.nn.Module):
593
  spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
594
  phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
595
 
596
- y = stft.inverse(spec, phase)
597
-
598
  return y
599
-
600
- def interp_f0(self, pitch, voiced):
601
- """Fill unvoiced regions via linear interpolation"""
602
-
603
- # Handle no voiced frames
604
- if not voiced.any():
605
- return pitch
606
-
607
- # Pitch is linear in base-2 log-space
608
- pitch = torch.log2(pitch)
609
-
610
- # Anchor endpoints
611
- pitch[..., 0] = pitch[voiced][..., 0]
612
- pitch[..., -1] = pitch[voiced][..., -1]
613
- voiced[..., 0] = True
614
- voiced[..., -1] = True
615
-
616
- # Interpolate
617
- pitch[~voiced] = self.interp(
618
- torch.where(~voiced[0])[0][None],
619
- torch.where(voiced[0])[0][None],
620
- pitch[voiced][None])
621
-
622
- return 2 ** pitch
623
-
624
- @staticmethod
625
- def interp(x, xp, fp):
626
- """1D linear interpolation for monotonically increasing sample points"""
627
- # Handle edge cases
628
- if xp.shape[-1] == 0:
629
- return x
630
- if xp.shape[-1] == 1:
631
- return torch.full(
632
- x.shape,
633
- fp.squeeze(),
634
- device=fp.device,
635
- dtype=fp.dtype)
636
-
637
- # Get slope and intercept using right-side first-differences
638
- m = (fp[:, 1:] - fp[:, :-1]) / (xp[:, 1:] - xp[:, :-1])
639
- b = fp[:, :-1] - (m.mul(xp[:, :-1]))
640
-
641
- # Get indices to sample slope and intercept
642
- indicies = torch.sum(torch.ge(x[:, :, None], xp[:, None, :]), -1) - 1
643
- indicies = torch.clamp(indicies, 0, m.shape[-1] - 1)
644
- line_idx = torch.linspace(
645
- 0,
646
- indicies.shape[0],
647
- 1,
648
- device=indicies.device).to(torch.long).expand(indicies.shape)
649
-
650
- # Interpolate
651
- return m[line_idx, indicies].mul(x) + b[line_idx, indicies]
652
 
653
  def remove_weight_norm(self):
654
  print('Removing weight norm...')
 
 
1
  import torch
2
  import torch.nn.functional as F
3
  import torch.nn as nn
 
485
  g = g + spk_emb.unsqueeze(-1)
486
 
487
  f0, _, _ = self.F0_model(mel.unsqueeze(1))
 
 
 
488
  f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
489
 
490
  har_source, _, _ = self.m_source(f0)
 
522
 
523
  return spec, phase
524
 
525
+ def get_f0(self, mel, f0_mean_tgt, voiced_threshold=10):
526
  f0, _, _ = self.F0_model(mel.unsqueeze(1))
 
527
  voiced = f0 > voiced_threshold
528
 
529
  lf0 = torch.log(f0)
530
+ lf0_ = lf0 * voiced.float()
531
+ lf0_mean = lf0_.sum(1) / voiced.float().sum(1)
532
+ lf0_mean = lf0_mean.unsqueeze(1)
533
+ lf0_adj = lf0 - lf0_mean + torch.log(f0_mean_tgt)
534
  f0_adj = torch.exp(lf0_adj)
535
 
536
+ energy = mel.sum(1)
537
+ unsilent = energy > -700
538
+ unsilent = unsilent | voiced # simple vad
539
+ f0_adj = f0_adj * unsilent.float()
 
 
 
 
 
 
 
 
540
 
541
  return f0_adj
542
 
 
551
 
552
  return x
553
 
554
+ def infer(self, x, f0):
555
  f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
556
 
557
  har_source, _, _ = self.m_source(f0)
 
582
  spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
583
  phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
584
 
585
+ y = self.stft.inverse(spec, phase)
 
586
  return y
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587
 
588
  def remove_weight_norm(self):
589
  print('Removing weight norm...')