Hecheng0625 committed
Commit 1adbad7
1 Parent(s): a63132d

Update Amphion/models/ns3_codec/facodec.py

Files changed (1)
  1. Amphion/models/ns3_codec/facodec.py +454 -0
Amphion/models/ns3_codec/facodec.py CHANGED
@@ -10,6 +10,7 @@ from einops import rearrange
 from einops.layers.torch import Rearrange
 from .transformer import TransformerEncoder
 from .gradient_reversal import GradientReversal
+from .melspec import MelSpectrogram
 
 
 def init_weights(m):
@@ -761,3 +762,456 @@ class FACodecRedecoder(nn.Module):
         x = x * gamma + beta
         x = self.model(x)
         return x
+
+
+class FACodecEncoderV2(nn.Module):
+    def __init__(
+        self,
+        ngf=32,
+        up_ratios=(2, 4, 5, 5),
+        out_channels=1024,
+    ):
+        super().__init__()
+        self.hop_length = np.prod(up_ratios)
+        self.up_ratios = up_ratios
+
+        # Create first convolution
+        d_model = ngf
+        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+        # Create EncoderBlocks that double channels as they downsample by `stride`
+        for stride in up_ratios:
+            d_model *= 2
+            self.block += [EncoderBlock(d_model, stride=stride)]
+
+        # Create last convolution
+        self.block += [
+            Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)),
+            WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
+        ]
+
+        # Wrap block into nn.Sequential
+        self.block = nn.Sequential(*self.block)
+        self.enc_dim = d_model
+
+        self.mel_transform = MelSpectrogram(
+            n_fft=1024,
+            num_mels=80,
+            sampling_rate=16000,
+            hop_size=200,
+            win_size=800,
+            fmin=0,
+            fmax=8000,
+        )
+
+        self.reset_parameters()
+
+    def forward(self, x):
+        out = self.block(x)
+        return out
+
+    def inference(self, x):
+        return self.block(x)
+
+    def get_prosody_feature(self, x):
+        return self.mel_transform(x.squeeze(1))[:, :20, :]
+
+    def remove_weight_norm(self):
+        """Remove weight normalization module from all of the layers."""
+
+        def _remove_weight_norm(m):
+            try:
+                torch.nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(_remove_weight_norm)
+
+    def apply_weight_norm(self):
+        """Apply weight normalization module from all of the layers."""
+
+        def _apply_weight_norm(m):
+            if isinstance(m, nn.Conv1d):
+                torch.nn.utils.weight_norm(m)
+
+        self.apply(_apply_weight_norm)
+
+    def reset_parameters(self):
+        self.apply(init_weights)
+
+
+class FACodecDecoderV2(nn.Module):
+    def __init__(
+        self,
+        in_channels=256,
+        upsample_initial_channel=1536,
+        ngf=32,
+        up_ratios=(5, 5, 4, 2),
+        vq_num_q_c=2,
+        vq_num_q_p=1,
+        vq_num_q_r=3,
+        vq_dim=1024,
+        vq_commit_weight=0.005,
+        vq_weight_init=False,
+        vq_full_commit_loss=False,
+        codebook_dim=8,
+        codebook_size_prosody=10,  # true codebook size is equal to 2^codebook_size
+        codebook_size_content=10,
+        codebook_size_residual=10,
+        quantizer_dropout=0.0,
+        dropout_type="linear",
+        use_gr_content_f0=False,
+        use_gr_prosody_phone=False,
+        use_gr_residual_f0=False,
+        use_gr_residual_phone=False,
+        use_gr_x_timbre=False,
+        use_random_mask_residual=True,
+        prob_random_mask_residual=0.75,
+    ):
+        super().__init__()
+        self.hop_length = np.prod(up_ratios)
+        self.ngf = ngf
+        self.up_ratios = up_ratios
+
+        self.use_random_mask_residual = use_random_mask_residual
+        self.prob_random_mask_residual = prob_random_mask_residual
+
+        self.vq_num_q_p = vq_num_q_p
+        self.vq_num_q_c = vq_num_q_c
+        self.vq_num_q_r = vq_num_q_r
+
+        self.codebook_size_prosody = codebook_size_prosody
+        self.codebook_size_content = codebook_size_content
+        self.codebook_size_residual = codebook_size_residual
+
+        quantizer_class = ResidualVQ
+
+        self.quantizer = nn.ModuleList()
+
+        # prosody
+        quantizer = quantizer_class(
+            num_quantizers=vq_num_q_p,
+            dim=vq_dim,
+            codebook_size=codebook_size_prosody,
+            codebook_dim=codebook_dim,
+            threshold_ema_dead_code=2,
+            commitment=vq_commit_weight,
+            weight_init=vq_weight_init,
+            full_commit_loss=vq_full_commit_loss,
+            quantizer_dropout=quantizer_dropout,
+            dropout_type=dropout_type,
+        )
+        self.quantizer.append(quantizer)
+
+        # phone
+        quantizer = quantizer_class(
+            num_quantizers=vq_num_q_c,
+            dim=vq_dim,
+            codebook_size=codebook_size_content,
+            codebook_dim=codebook_dim,
+            threshold_ema_dead_code=2,
+            commitment=vq_commit_weight,
+            weight_init=vq_weight_init,
+            full_commit_loss=vq_full_commit_loss,
+            quantizer_dropout=quantizer_dropout,
+            dropout_type=dropout_type,
+        )
+        self.quantizer.append(quantizer)
+
+        # residual
+        if self.vq_num_q_r > 0:
+            quantizer = quantizer_class(
+                num_quantizers=vq_num_q_r,
+                dim=vq_dim,
+                codebook_size=codebook_size_residual,
+                codebook_dim=codebook_dim,
+                threshold_ema_dead_code=2,
+                commitment=vq_commit_weight,
+                weight_init=vq_weight_init,
+                full_commit_loss=vq_full_commit_loss,
+                quantizer_dropout=quantizer_dropout,
+                dropout_type=dropout_type,
+            )
+            self.quantizer.append(quantizer)
+
+        # Add first conv layer
+        channels = upsample_initial_channel
+        layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
+
+        # Add upsampling + MRF blocks
+        for i, stride in enumerate(up_ratios):
+            input_dim = channels // 2**i
+            output_dim = channels // 2 ** (i + 1)
+            layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+        # Add final conv layer
+        layers += [
+            Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
+            WNConv1d(output_dim, 1, kernel_size=7, padding=3),
+            nn.Tanh(),
+        ]
+
+        self.model = nn.Sequential(*layers)
+
+        self.timbre_encoder = TransformerEncoder(
+            enc_emb_tokens=None,
+            encoder_layer=4,
+            encoder_hidden=256,
+            encoder_head=4,
+            conv_filter_size=1024,
+            conv_kernel_size=5,
+            encoder_dropout=0.1,
+            use_cln=False,
+        )
+
+        self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
+        self.timbre_linear.bias.data[:in_channels] = 1
+        self.timbre_linear.bias.data[in_channels:] = 0
+        self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
+
+        self.f0_predictor = CNNLSTM(in_channels, 1, 2)
+        self.phone_predictor = CNNLSTM(in_channels, 5003, 1)
+
+        self.use_gr_content_f0 = use_gr_content_f0
+        self.use_gr_prosody_phone = use_gr_prosody_phone
+        self.use_gr_residual_f0 = use_gr_residual_f0
+        self.use_gr_residual_phone = use_gr_residual_phone
+        self.use_gr_x_timbre = use_gr_x_timbre
+
+        if self.vq_num_q_r > 0 and self.use_gr_residual_f0:
+            self.res_f0_predictor = nn.Sequential(
+                GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
+            )
+
+        if self.vq_num_q_r > 0 and self.use_gr_residual_phone > 0:
+            self.res_phone_predictor = nn.Sequential(
+                GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
+            )
+
+        if self.use_gr_content_f0:
+            self.content_f0_predictor = nn.Sequential(
+                GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
+            )
+
+        if self.use_gr_prosody_phone:
+            self.prosody_phone_predictor = nn.Sequential(
+                GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
+            )
+
+        if self.use_gr_x_timbre:
+            self.x_timbre_predictor = nn.Sequential(
+                GradientReversal(alpha=1),
+                CNNLSTM(in_channels, 245200, 1, global_pred=True),
+            )
+
+        self.melspec_linear = nn.Linear(20, 256)
+        self.melspec_encoder = TransformerEncoder(
+            enc_emb_tokens=None,
+            encoder_layer=4,
+            encoder_hidden=256,
+            encoder_head=4,
+            conv_filter_size=1024,
+            conv_kernel_size=5,
+            encoder_dropout=0.1,
+            use_cln=False,
+            cfg=None,
+        )
+
+        self.reset_parameters()
+
+    def quantize(self, x, prosody_feature, n_quantizers=None):
+        outs, qs, commit_loss, quantized_buf = 0, [], [], []
+
+        # prosody
+        f0_input = prosody_feature.transpose(1, 2)  # (B, T, 20)
+        f0_input = self.melspec_linear(f0_input)
+        f0_input = self.melspec_encoder(f0_input, None, None)
+        f0_input = f0_input.transpose(1, 2)
+        f0_quantizer = self.quantizer[0]
+        out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers)
+        outs += out
+        qs.append(q)
+        quantized_buf.append(quantized.sum(0))
+        commit_loss.append(commit)
+
+        # phone
+        phone_input = x
+        phone_quantizer = self.quantizer[1]
+        out, q, commit, quantized = phone_quantizer(
+            phone_input, n_quantizers=n_quantizers
+        )
+        outs += out
+        qs.append(q)
+        quantized_buf.append(quantized.sum(0))
+        commit_loss.append(commit)
+
+        # residual
+        if self.vq_num_q_r > 0:
+            residual_quantizer = self.quantizer[2]
+            residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach()
+            out, q, commit, quantized = residual_quantizer(
+                residual_input, n_quantizers=n_quantizers
+            )
+            outs += out
+            qs.append(q)
+            quantized_buf.append(quantized.sum(0))  # [L, B, C, T] -> [B, C, T]
+            commit_loss.append(commit)
+
+        qs = torch.cat(qs, dim=0)
+        commit_loss = torch.cat(commit_loss, dim=0)
+        return outs, qs, commit_loss, quantized_buf
+
+    def forward(
+        self,
+        x,
+        prosody_feature,
+        vq=True,
+        get_vq=False,
+        eval_vq=True,
+        speaker_embedding=None,
+        n_quantizers=None,
+        quantized=None,
+    ):
+        if get_vq:
+            return self.quantizer.get_emb()
+        if vq is True:
+            if eval_vq:
+                self.quantizer.eval()
+            x_timbre = x
+            outs, qs, commit_loss, quantized_buf = self.quantize(
+                x, prosody_feature, n_quantizers=n_quantizers
+            )
+
+            x_timbre = x_timbre.transpose(1, 2)
+            x_timbre = self.timbre_encoder(x_timbre, None, None)
+            x_timbre = x_timbre.transpose(1, 2)
+            spk_embs = torch.mean(x_timbre, dim=2)
+            return outs, qs, commit_loss, quantized_buf, spk_embs
+
+        out = {}
+
+        layer_0 = quantized[0]
+        f0, uv = self.f0_predictor(layer_0)
+        f0 = rearrange(f0, "... 1 -> ...")
+        uv = rearrange(uv, "... 1 -> ...")
+
+        layer_1 = quantized[1]
+        (phone,) = self.phone_predictor(layer_1)
+
+        out = {"f0": f0, "uv": uv, "phone": phone}
+
+        if self.use_gr_prosody_phone:
+            (prosody_phone,) = self.prosody_phone_predictor(layer_0)
+            out["prosody_phone"] = prosody_phone
+
+        if self.use_gr_content_f0:
+            content_f0, content_uv = self.content_f0_predictor(layer_1)
+            content_f0 = rearrange(content_f0, "... 1 -> ...")
+            content_uv = rearrange(content_uv, "... 1 -> ...")
+            out["content_f0"] = content_f0
+            out["content_uv"] = content_uv
+
+        if self.vq_num_q_r > 0:
+            layer_2 = quantized[2]
+
+            if self.use_gr_residual_f0:
+                res_f0, res_uv = self.res_f0_predictor(layer_2)
+                res_f0 = rearrange(res_f0, "... 1 -> ...")
+                res_uv = rearrange(res_uv, "... 1 -> ...")
+                out["res_f0"] = res_f0
+                out["res_uv"] = res_uv
+
+            if self.use_gr_residual_phone:
+                (res_phone,) = self.res_phone_predictor(layer_2)
+                out["res_phone"] = res_phone
+
+        style = self.timbre_linear(speaker_embedding).unsqueeze(2)  # (B, 2d, 1)
+        gamma, beta = style.chunk(2, 1)  # (B, d, 1)
+        if self.vq_num_q_r > 0:
+            if self.use_random_mask_residual:
+                bsz = quantized[2].shape[0]
+                res_mask = np.random.choice(
+                    [0, 1],
+                    size=bsz,
+                    p=[
+                        self.prob_random_mask_residual,
+                        1 - self.prob_random_mask_residual,
+                    ],
+                )
+                res_mask = (
+                    torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)
+                )  # (B, 1, 1)
+                res_mask = res_mask.to(
+                    device=quantized[2].device, dtype=quantized[2].dtype
+                )
+                x = (
+                    quantized[0].detach()
+                    + quantized[1].detach()
+                    + quantized[2] * res_mask
+                )
+                # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask
+            else:
+                x = quantized[0].detach() + quantized[1].detach() + quantized[2]
+                # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2]
+        else:
+            x = quantized[0].detach() + quantized[1].detach()
+            # x = quantized_perturbe[0].detach() + quantized[1].detach()
+
+        if self.use_gr_x_timbre:
+            (x_timbre,) = self.x_timbre_predictor(x)
+            out["x_timbre"] = x_timbre
+
+        x = x.transpose(1, 2)
+        x = self.timbre_norm(x)
+        x = x.transpose(1, 2)
+        x = x * gamma + beta
+
+        x = self.model(x)
+        out["audio"] = x
+
+        return out
+
+    def vq2emb(self, vq, use_residual=True):
+        # vq: [num_quantizer, B, T]
+        self.quantizer = self.quantizer.eval()
+        out = 0
+        out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p])
+        out += self.quantizer[1].vq2emb(
+            vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c]
+        )
+        if self.vq_num_q_r > 0 and use_residual:
+            out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :])
+        return out
+
+    def inference(self, x, speaker_embedding):
+        style = self.timbre_linear(speaker_embedding).unsqueeze(2)  # (B, 2d, 1)
+        gamma, beta = style.chunk(2, 1)  # (B, d, 1)
+        x = x.transpose(1, 2)
+        x = self.timbre_norm(x)
+        x = x.transpose(1, 2)
+        x = x * gamma + beta
+        x = self.model(x)
+        return x
+
+    def remove_weight_norm(self):
+        """Remove weight normalization module from all of the layers."""
+
+        def _remove_weight_norm(m):
+            try:
+                torch.nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(_remove_weight_norm)
+
+    def apply_weight_norm(self):
+        """Apply weight normalization module from all of the layers."""
+
+        def _apply_weight_norm(m):
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
+                torch.nn.utils.weight_norm(m)
+
+        self.apply(_apply_weight_norm)
+
+    def reset_parameters(self):
+        self.apply(init_weights)
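
The commit adds the V2 encoder/decoder pair without usage notes, so the following is a minimal, hypothetical sketch of how the two new classes appear to fit together, inferred only from the signatures added above (forward, get_prosody_feature, quantize, inference). The constructor values (out_channels=256, vq_dim=256, upsample_initial_channel=1024), the import path, and the one-second 16 kHz input are assumptions chosen so the tensor shapes line up; the configuration used for any released checkpoints may differ.

import torch

# Assumed import path; adjust to wherever the Amphion package sits on sys.path.
from models.ns3_codec.facodec import FACodecEncoderV2, FACodecDecoderV2

# Assumed, shape-consistent configuration (not necessarily the released one).
encoder = FACodecEncoderV2(ngf=32, up_ratios=(2, 4, 5, 5), out_channels=256)
decoder = FACodecDecoderV2(
    in_channels=256,
    upsample_initial_channel=1024,
    up_ratios=(5, 5, 4, 2),
    vq_dim=256,
)

wav = torch.randn(1, 1, 16000)  # (B, 1, T) mono waveform at 16 kHz

with torch.no_grad():
    enc_out = encoder(wav)                      # (B, 256, T / 200) frame-level features
    prosody = encoder.get_prosody_feature(wav)  # first 20 mel bins, (B, 20, frames)

    # vq=True path: quantize into prosody / content / residual streams and
    # pool a timbre (speaker) embedding from the pre-quantization frames.
    vq_post_emb, vq_ids, _, quantized, spk_emb = decoder(
        enc_out, prosody, vq=True, eval_vq=False
    )

    # Decode the summed post-quantization embedding, conditioned on the timbre
    # embedding via the timbre_norm / timbre_linear (gamma, beta) modulation.
    recon = decoder.inference(vq_post_emb, spk_emb)  # (B, 1, ~T) waveform

With vq=False, forward instead takes the per-stream quantized list plus a speaker_embedding and returns the prediction dict (f0, uv, phone, the optional gradient-reversal heads, and "audio"), which is the training-time path shown in the diff.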