Hecheng0625 committed
Commit cd03576
1 Parent(s): 4c9108a

Upload 12 files
Amphion/models/ns3_codec/README.md ADDED
@@ -0,0 +1,160 @@
+ ## FACodec: Speech Codec with Attribute Factorization used for NaturalSpeech 3
+
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/pdf/2403.03100.pdf)
+ [![demo](https://img.shields.io/badge/FACodec-Demo-red)](https://speechresearch.github.io/naturalspeech3/)
+ [![model](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-pink)](https://huggingface.co/amphion/naturalspeech3_facodec)
+ [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Spaces-yellow)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec)
+
+ ## Overview
+
+ FACodec is a core component of the advanced text-to-speech (TTS) model NaturalSpeech 3. It converts a complex speech waveform into disentangled subspaces representing the speech attributes of content, prosody, timbre, and acoustic details, and reconstructs a high-quality speech waveform from these attributes. Decomposing speech into attribute-specific subspaces in this way simplifies the modeling of speech representations.
+
+ Researchers can use FACodec to develop different kinds of TTS models, such as non-autoregressive discrete diffusion models (NaturalSpeech 3) or autoregressive models (like VALL-E).
+
+ <br>
+ <div align="center">
+ <img src="../../imgs/ns3/ns3_overview.png" width="65%">
+ </div>
+ <br>
+
+ <br>
+ <div align="center">
+ <img src="../../imgs/ns3/ns3_facodec.png" width="100%">
+ </div>
+ <br>
+
+ ## Usage
+
+ Download the pre-trained FACodec model from HuggingFace: [Pretrained FACodec checkpoint](https://huggingface.co/amphion/naturalspeech3_facodec)
+
+ Install Amphion
+ ```bash
+ git clone https://github.com/open-mmlab/Amphion.git
+ ```
+
+ A few lines of code to use the pre-trained FACodec model
+ ```python
+ import torch
+ import librosa
+ import soundfile as sf
+
+ from Amphion.models.ns3_codec import FACodecEncoder, FACodecDecoder
+
+ fa_encoder = FACodecEncoder(
+     ngf=32,
+     up_ratios=[2, 4, 5, 5],
+     out_channels=256,
+ )
+
+ fa_decoder = FACodecDecoder(
+     in_channels=256,
+     upsample_initial_channel=1024,
+     ngf=32,
+     up_ratios=[5, 5, 4, 2],
+     vq_num_q_c=2,
+     vq_num_q_p=1,
+     vq_num_q_r=3,
+     vq_dim=256,
+     codebook_dim=8,
+     codebook_size_prosody=10,
+     codebook_size_content=10,
+     codebook_size_residual=10,
+     use_gr_x_timbre=True,
+     use_gr_residual_f0=True,
+     use_gr_residual_phone=True,
+ )
+
+ fa_encoder.load_state_dict(torch.load("ns3_facodec_encoder.bin"))
+ fa_decoder.load_state_dict(torch.load("ns3_facodec_decoder.bin"))
+
+ fa_encoder.eval()
+ fa_decoder.eval()
+ ```
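+
+ Alternatively, the checkpoints can be fetched programmatically. A minimal sketch using `huggingface_hub` (assuming the checkpoint file names `ns3_facodec_encoder.bin` and `ns3_facodec_decoder.bin` on the model repo, as used above):
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ # Download (or reuse from the local cache) the encoder/decoder checkpoints.
+ encoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder.bin")
+ decoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder.bin")
+
+ fa_encoder.load_state_dict(torch.load(encoder_ckpt))
+ fa_decoder.load_state_dict(torch.load(decoder_ckpt))
+ ```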
+
+ Test
+ ```python
+ test_wav_path = "test.wav"
+ test_wav = librosa.load(test_wav_path, sr=16000)[0]
+ test_wav = torch.from_numpy(test_wav).float()
+ test_wav = test_wav.unsqueeze(0).unsqueeze(0)
+
+ with torch.no_grad():
+
+     # encode
+     enc_out = fa_encoder(test_wav)
+     print(enc_out.shape)
+
+     # quantize
+     vq_post_emb, vq_id, _, quantized, spk_embs = fa_decoder(enc_out, eval_vq=False, vq=True)
+
+     # latent after quantization
+     print(vq_post_emb.shape)
+
+     # codes
+     print("vq id shape:", vq_id.shape)
+
+     # get prosody code
+     prosody_code = vq_id[:1]
+     print("prosody code shape:", prosody_code.shape)
+
+     # get content code
+     content_code = vq_id[1:3]
+     print("content code shape:", content_code.shape)
+
+     # get residual code (acoustic detail codes)
+     residual_code = vq_id[3:]
+     print("residual code shape:", residual_code.shape)
+
+     # speaker embedding
+     print("speaker embedding shape:", spk_embs.shape)
+
+     # decode (recommended)
+     recon_wav = fa_decoder.inference(vq_post_emb, spk_embs)
+     print(recon_wav.shape)
+     sf.write("recon.wav", recon_wav[0][0].cpu().numpy(), 16000)
+ ```
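+
+ The discrete codes alone, together with the speaker embedding, are enough to drive the decoder. Below is a minimal sketch (continuing from the variables above) that maps the codes back to the quantized latent with `FACodecDecoder.vq2emb`, defined in `facodec.py` in this commit, and then re-synthesizes the waveform:
+ ```python
+ with torch.no_grad():
+     # vq_id has shape (6, B, T): 1 prosody + 2 content + 3 acoustic-detail code streams.
+     emb_from_codes = fa_decoder.vq2emb(vq_id)  # (B, 256, T), latent summed over all codebooks
+     recon_from_codes = fa_decoder.inference(emb_from_codes, spk_embs)
+     sf.write("recon_from_codes.wav", recon_from_codes[0][0].cpu().numpy(), 16000)
+ ```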
+
+ ## Some Q&A
+
+ Q1: What audio sample rate does FACodec support? What is the hop size? How many codes are generated for each frame?
+
+ A1: FACodec supports 16 kHz speech audio. The hop size is 200 samples, i.e. 16000 / 200 = 80 frames per second, and each frame is represented by 6 codes (one per codebook), so 480 codes are generated per second of audio.
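+
+ As a quick sanity check of this arithmetic, a short sketch reusing `test_wav` and `vq_id` from the usage snippet above (frame counts may differ slightly due to padding):
+ ```python
+ num_samples = test_wav.shape[-1]   # e.g. 32000 for a 2-second clip at 16 kHz
+ num_frames = num_samples // 200    # hop size 200 -> about 80 frames per second
+ print(num_frames)
+ print(vq_id.shape)                 # (6, B, num_frames): 6 codes per frame
+ ```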
+
+ Q2: Is it possible to train an autoregressive TTS model like VALL-E using FACodec?
+
+ A2: Yes. In fact, the authors of NaturalSpeech 3 have already explored autoregressive generative models for discrete token generation with FACodec. They use an autoregressive language model to generate the prosody codes, followed by a non-autoregressive model to generate the remaining content and acoustic-detail codes.
+
+ Q3: Is it possible to train a latent diffusion TTS model like NaturalSpeech 2 using FACodec?
+
+ A3: Yes. You can use the latent obtained after quantization (`vq_post_emb` in the snippet above) as the modeling target for the latent diffusion model.
+
+ Q4: Can FACodec compress and reconstruct audio from other domains, such as sound effects or music?
+
+ A4: Since FACodec is designed for speech, it may not be suitable for other audio domains. It is still possible to use it to compress and reconstruct such audio, but the reconstruction quality may not be as good as the original audio.
+
+ Q5: Can the content features of FACodec be used for other tasks like voice conversion?
+
+ A5: Yes. Researchers can use the content codes of FACodec as content features for voice conversion. We hope to see more research in this direction.
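+
+ A rough sketch of this idea with the modules defined above: keep the prosody and content codes of the source utterance, drop the acoustic-detail (residual) codes, and decode with the speaker embedding of a reference utterance. Note that `ref.wav` is a hypothetical file and this is only a naive illustration, not the full NaturalSpeech 3 voice-conversion pipeline:
+ ```python
+ with torch.no_grad():
+     # Source codes from the snippet above: 1 prosody stream + 2 content streams.
+     src_codes = vq_id[:3]
+
+     # Speaker embedding of a reference utterance.
+     ref_wav = librosa.load("ref.wav", sr=16000)[0]
+     ref_wav = torch.from_numpy(ref_wav).float().unsqueeze(0).unsqueeze(0)
+     _, _, _, _, ref_spk_emb = fa_decoder(fa_encoder(ref_wav), eval_vq=False, vq=True)
+
+     # Decode the source prosody/content with the reference timbre.
+     emb = fa_decoder.vq2emb(src_codes, use_residual_code=False)
+     converted = fa_decoder.inference(emb, ref_spk_emb)
+     sf.write("converted.wav", converted[0][0].cpu().numpy(), 16000)
+ ```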
+
+ ## Citations
+
+ If you use our FACodec model, please cite the following papers:
+
+ ```bibtex
+ @misc{ju2024naturalspeech,
+   title={NaturalSpeech 3: Zero-Shot Speech Synthesis with Factorized Codec and Diffusion Models},
+   author={Zeqian Ju and Yuancheng Wang and Kai Shen and Xu Tan and Detai Xin and Dongchao Yang and Yanqing Liu and Yichong Leng and Kaitao Song and Siliang Tang and Zhizheng Wu and Tao Qin and Xiang-Yang Li and Wei Ye and Shikun Zhang and Jiang Bian and Lei He and Jinyu Li and Sheng Zhao},
+   year={2024},
+   eprint={2403.03100},
+   archivePrefix={arXiv},
+   primaryClass={eess.AS}
+ }
+
+ @article{zhang2023amphion,
+   title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
+   author={Xueyao Zhang and Liumeng Xue and Yicheng Gu and Yuancheng Wang and Haorui He and Chaoren Wang and Xi Chen and Zihao Fang and Haopeng Chen and Junan Zhang and Tze Ying Tang and Lexiao Zou and Mingxuan Wang and Jun Han and Kai Chen and Haizhou Li and Zhizheng Wu},
+   journal={arXiv},
+   year={2024},
+   volume={abs/2312.09911}
+ }
+ ```
+
Amphion/models/ns3_codec/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .facodec import *
Amphion/models/ns3_codec/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
6
+ from .act import *
Amphion/models/ns3_codec/alias_free_torch/act.py ADDED
@@ -0,0 +1,30 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from .resample import UpSample1d, DownSample1d
6
+
7
+
8
+ class Activation1d(nn.Module):
9
+ def __init__(
10
+ self,
11
+ activation,
12
+ up_ratio: int = 2,
13
+ down_ratio: int = 2,
14
+ up_kernel_size: int = 12,
15
+ down_kernel_size: int = 12,
16
+ ):
17
+ super().__init__()
18
+ self.up_ratio = up_ratio
19
+ self.down_ratio = down_ratio
20
+ self.act = activation
21
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
22
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
23
+
24
+ # x: [B,C,T]
25
+ def forward(self, x):
26
+ x = self.upsample(x)
27
+ x = self.act(x)
28
+ x = self.downsample(x)
29
+
30
+ return x
Amphion/models/ns3_codec/alias_free_torch/filter.py ADDED
@@ -0,0 +1,99 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ if "sinc" in dir(torch):
10
+ sinc = torch.sinc
11
+ else:
12
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
13
+ # https://adefossez.github.io/julius/julius/core.html
14
+ # LICENSE is in incl_licenses directory.
15
+ def sinc(x: torch.Tensor):
16
+ """
17
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19
+ """
20
+ return torch.where(
21
+ x == 0,
22
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
23
+ torch.sin(math.pi * x) / math.pi / x,
24
+ )
25
+
26
+
27
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
28
+ # https://adefossez.github.io/julius/julius/lowpass.html
29
+ # LICENSE is in incl_licenses directory.
30
+ def kaiser_sinc_filter1d(
31
+ cutoff, half_width, kernel_size
32
+ ): # return filter [1,1,kernel_size]
33
+ even = kernel_size % 2 == 0
34
+ half_size = kernel_size // 2
35
+
36
+ # For kaiser window
37
+ delta_f = 4 * half_width
38
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
39
+ if A > 50.0:
40
+ beta = 0.1102 * (A - 8.7)
41
+ elif A >= 21.0:
42
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
43
+ else:
44
+ beta = 0.0
45
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
46
+
47
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
48
+ if even:
49
+ time = torch.arange(-half_size, half_size) + 0.5
50
+ else:
51
+ time = torch.arange(kernel_size) - half_size
52
+ if cutoff == 0:
53
+ filter_ = torch.zeros_like(time)
54
+ else:
55
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
56
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
57
+ # of the constant component in the input signal.
58
+ filter_ /= filter_.sum()
59
+ filter = filter_.view(1, 1, kernel_size)
60
+
61
+ return filter
62
+
63
+
64
+ class LowPassFilter1d(nn.Module):
65
+ def __init__(
66
+ self,
67
+ cutoff=0.5,
68
+ half_width=0.6,
69
+ stride: int = 1,
70
+ padding: bool = True,
71
+ padding_mode: str = "replicate",
72
+ kernel_size: int = 12,
73
+ ):
74
+ # kernel_size should be even number for stylegan3 setup,
75
+ # in this implementation, odd number is also possible.
76
+ super().__init__()
77
+ if cutoff < -0.0:
78
+ raise ValueError("Minimum cutoff must be larger than zero.")
79
+ if cutoff > 0.5:
80
+ raise ValueError("A cutoff above 0.5 does not make sense.")
81
+ self.kernel_size = kernel_size
82
+ self.even = kernel_size % 2 == 0
83
+ self.pad_left = kernel_size // 2 - int(self.even)
84
+ self.pad_right = kernel_size // 2
85
+ self.stride = stride
86
+ self.padding = padding
87
+ self.padding_mode = padding_mode
88
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
89
+ self.register_buffer("filter", filter)
90
+
91
+ # input [B, C, T]
92
+ def forward(self, x):
93
+ _, C, _ = x.shape
94
+
95
+ if self.padding:
96
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
97
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
98
+
99
+ return out
Amphion/models/ns3_codec/alias_free_torch/resample.py ADDED
@@ -0,0 +1,58 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from .filter import LowPassFilter1d
7
+ from .filter import kaiser_sinc_filter1d
8
+
9
+
10
+ class UpSample1d(nn.Module):
11
+ def __init__(self, ratio=2, kernel_size=None):
12
+ super().__init__()
13
+ self.ratio = ratio
14
+ self.kernel_size = (
15
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
16
+ )
17
+ self.stride = ratio
18
+ self.pad = self.kernel_size // ratio - 1
19
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
20
+ self.pad_right = (
21
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
22
+ )
23
+ filter = kaiser_sinc_filter1d(
24
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
25
+ )
26
+ self.register_buffer("filter", filter)
27
+
28
+ # x: [B, C, T]
29
+ def forward(self, x):
30
+ _, C, _ = x.shape
31
+
32
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
33
+ x = self.ratio * F.conv_transpose1d(
34
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
35
+ )
36
+ x = x[..., self.pad_left : -self.pad_right]
37
+
38
+ return x
39
+
40
+
41
+ class DownSample1d(nn.Module):
42
+ def __init__(self, ratio=2, kernel_size=None):
43
+ super().__init__()
44
+ self.ratio = ratio
45
+ self.kernel_size = (
46
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
47
+ )
48
+ self.lowpass = LowPassFilter1d(
49
+ cutoff=0.5 / ratio,
50
+ half_width=0.6 / ratio,
51
+ stride=ratio,
52
+ kernel_size=self.kernel_size,
53
+ )
54
+
55
+ def forward(self, x):
56
+ xx = self.lowpass(x)
57
+
58
+ return xx
Amphion/models/ns3_codec/facodec.py ADDED
@@ -0,0 +1,593 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn, sin, pow
4
+ from torch.nn import Parameter
5
+ import torch.nn.functional as F
6
+ from torch.nn.utils import weight_norm
7
+ from .alias_free_torch import *
8
+ from .quantize import *
9
+ from einops import rearrange
10
+ from einops.layers.torch import Rearrange
11
+ from .transformer import TransformerEncoder
12
+ from .gradient_reversal import GradientReversal
13
+
14
+
15
+ def init_weights(m):
16
+ if isinstance(m, nn.Conv1d):
17
+ nn.init.trunc_normal_(m.weight, std=0.02)
18
+ nn.init.constant_(m.bias, 0)
19
+
20
+
21
+ def WNConv1d(*args, **kwargs):
22
+ return weight_norm(nn.Conv1d(*args, **kwargs))
23
+
24
+
25
+ def WNConvTranspose1d(*args, **kwargs):
26
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
27
+
28
+
29
+ class CNNLSTM(nn.Module):
30
+ def __init__(self, indim, outdim, head, global_pred=False):
31
+ super().__init__()
32
+ self.global_pred = global_pred
33
+ self.model = nn.Sequential(
34
+ ResidualUnit(indim, dilation=1),
35
+ ResidualUnit(indim, dilation=2),
36
+ ResidualUnit(indim, dilation=3),
37
+ Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
38
+ Rearrange("b c t -> b t c"),
39
+ )
40
+ self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
41
+
42
+ def forward(self, x):
43
+ # x: [B, C, T]
44
+ x = self.model(x)
45
+ if self.global_pred:
46
+ x = torch.mean(x, dim=1, keepdim=False)
47
+ outs = [head(x) for head in self.heads]
48
+ return outs
49
+
50
+
51
+ class SnakeBeta(nn.Module):
52
+ """
53
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
54
+ Shape:
55
+ - Input: (B, C, T)
56
+ - Output: (B, C, T), same shape as the input
57
+ Parameters:
58
+ - alpha - trainable parameter that controls frequency
59
+ - beta - trainable parameter that controls magnitude
60
+ References:
61
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
62
+ https://arxiv.org/abs/2006.08195
63
+ Examples:
64
+ >>> a1 = snakebeta(256)
65
+ >>> x = torch.randn(256)
66
+ >>> x = a1(x)
67
+ """
68
+
69
+ def __init__(
70
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
71
+ ):
72
+ """
73
+ Initialization.
74
+ INPUT:
75
+ - in_features: shape of the input
76
+ - alpha - trainable parameter that controls frequency
77
+ - beta - trainable parameter that controls magnitude
78
+ alpha is initialized to 1 by default, higher values = higher-frequency.
79
+ beta is initialized to 1 by default, higher values = higher-magnitude.
80
+ alpha will be trained along with the rest of your model.
81
+ """
82
+ super(SnakeBeta, self).__init__()
83
+ self.in_features = in_features
84
+
85
+ # initialize alpha
86
+ self.alpha_logscale = alpha_logscale
87
+ if self.alpha_logscale: # log scale alphas initialized to zeros
88
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
89
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
90
+ else: # linear scale alphas initialized to ones
91
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
92
+ self.beta = Parameter(torch.ones(in_features) * alpha)
93
+
94
+ self.alpha.requires_grad = alpha_trainable
95
+ self.beta.requires_grad = alpha_trainable
96
+
97
+ self.no_div_by_zero = 0.000000001
98
+
99
+ def forward(self, x):
100
+ """
101
+ Forward pass of the function.
102
+ Applies the function to the input elementwise.
103
+ SnakeBeta := x + 1/b * sin^2 (xa)
104
+ """
105
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
106
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
107
+ if self.alpha_logscale:
108
+ alpha = torch.exp(alpha)
109
+ beta = torch.exp(beta)
110
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
111
+
112
+ return x
113
+
114
+
115
+ class ResidualUnit(nn.Module):
116
+ def __init__(self, dim: int = 16, dilation: int = 1):
117
+ super().__init__()
118
+ pad = ((7 - 1) * dilation) // 2
119
+ self.block = nn.Sequential(
120
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
121
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
122
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
123
+ WNConv1d(dim, dim, kernel_size=1),
124
+ )
125
+
126
+ def forward(self, x):
127
+ return x + self.block(x)
128
+
129
+
130
+ class EncoderBlock(nn.Module):
131
+ def __init__(self, dim: int = 16, stride: int = 1):
132
+ super().__init__()
133
+ self.block = nn.Sequential(
134
+ ResidualUnit(dim // 2, dilation=1),
135
+ ResidualUnit(dim // 2, dilation=3),
136
+ ResidualUnit(dim // 2, dilation=9),
137
+ Activation1d(activation=SnakeBeta(dim // 2, alpha_logscale=True)),
138
+ WNConv1d(
139
+ dim // 2,
140
+ dim,
141
+ kernel_size=2 * stride,
142
+ stride=stride,
143
+ padding=stride // 2 + stride % 2,
144
+ ),
145
+ )
146
+
147
+ def forward(self, x):
148
+ return self.block(x)
149
+
150
+
151
+ class FACodecEncoder(nn.Module):
152
+ def __init__(
153
+ self,
154
+ ngf=32,
155
+ up_ratios=(2, 4, 5, 5),
156
+ out_channels=1024,
157
+ ):
158
+ super().__init__()
159
+ self.hop_length = np.prod(up_ratios)
160
+ self.up_ratios = up_ratios
161
+
162
+ # Create first convolution
163
+ d_model = ngf
164
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
165
+
166
+ # Create EncoderBlocks that double channels as they downsample by `stride`
167
+ for stride in up_ratios:
168
+ d_model *= 2
169
+ self.block += [EncoderBlock(d_model, stride=stride)]
170
+
171
+ # Create last convolution
172
+ self.block += [
173
+ Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)),
174
+ WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
175
+ ]
176
+
177
+ # Wrap block into nn.Sequential
178
+ self.block = nn.Sequential(*self.block)
179
+ self.enc_dim = d_model
180
+
181
+ self.reset_parameters()
182
+
183
+ def forward(self, x):
184
+ out = self.block(x)
185
+ return out
186
+
187
+ def inference(self, x):
188
+ return self.block(x)
189
+
190
+ def remove_weight_norm(self):
191
+ """Remove weight normalization module from all of the layers."""
192
+
193
+ def _remove_weight_norm(m):
194
+ try:
195
+ torch.nn.utils.remove_weight_norm(m)
196
+ except ValueError: # this module didn't have weight norm
197
+ return
198
+
199
+ self.apply(_remove_weight_norm)
200
+
201
+ def apply_weight_norm(self):
202
+ """Apply weight normalization module from all of the layers."""
203
+
204
+ def _apply_weight_norm(m):
205
+ if isinstance(m, nn.Conv1d):
206
+ torch.nn.utils.weight_norm(m)
207
+
208
+ self.apply(_apply_weight_norm)
209
+
210
+ def reset_parameters(self):
211
+ self.apply(init_weights)
212
+
213
+
214
+ class DecoderBlock(nn.Module):
215
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
216
+ super().__init__()
217
+ self.block = nn.Sequential(
218
+ Activation1d(activation=SnakeBeta(input_dim, alpha_logscale=True)),
219
+ WNConvTranspose1d(
220
+ input_dim,
221
+ output_dim,
222
+ kernel_size=2 * stride,
223
+ stride=stride,
224
+ padding=stride // 2 + stride % 2,
225
+ output_padding=stride % 2,
226
+ ),
227
+ ResidualUnit(output_dim, dilation=1),
228
+ ResidualUnit(output_dim, dilation=3),
229
+ ResidualUnit(output_dim, dilation=9),
230
+ )
231
+
232
+ def forward(self, x):
233
+ return self.block(x)
234
+
235
+
236
+ class FACodecDecoder(nn.Module):
237
+ def __init__(
238
+ self,
239
+ in_channels=256,
240
+ upsample_initial_channel=1536,
241
+ ngf=32,
242
+ up_ratios=(5, 5, 4, 2),
243
+ vq_num_q_c=2,
244
+ vq_num_q_p=1,
245
+ vq_num_q_r=3,
246
+ vq_dim=1024,
247
+ vq_commit_weight=0.005,
248
+ vq_weight_init=False,
249
+ vq_full_commit_loss=False,
250
+ codebook_dim=8,
251
+ codebook_size_prosody=10, # true codebook size is equal to 2^codebook_size
252
+ codebook_size_content=10,
253
+ codebook_size_residual=10,
254
+ quantizer_dropout=0.0,
255
+ dropout_type="linear",
256
+ use_gr_content_f0=False,
257
+ use_gr_prosody_phone=False,
258
+ use_gr_residual_f0=False,
259
+ use_gr_residual_phone=False,
260
+ use_gr_x_timbre=False,
261
+ use_random_mask_residual=True,
262
+ prob_random_mask_residual=0.75,
263
+ ):
264
+ super().__init__()
265
+ self.hop_length = np.prod(up_ratios)
266
+ self.ngf = ngf
267
+ self.up_ratios = up_ratios
268
+
269
+ self.use_random_mask_residual = use_random_mask_residual
270
+ self.prob_random_mask_residual = prob_random_mask_residual
271
+
272
+ self.vq_num_q_p = vq_num_q_p
273
+ self.vq_num_q_c = vq_num_q_c
274
+ self.vq_num_q_r = vq_num_q_r
275
+
276
+ self.codebook_size_prosody = codebook_size_prosody
277
+ self.codebook_size_content = codebook_size_content
278
+ self.codebook_size_residual = codebook_size_residual
279
+
280
+ quantizer_class = ResidualVQ
281
+
282
+ self.quantizer = nn.ModuleList()
283
+
284
+ # prosody
285
+ quantizer = quantizer_class(
286
+ num_quantizers=vq_num_q_p,
287
+ dim=vq_dim,
288
+ codebook_size=codebook_size_prosody,
289
+ codebook_dim=codebook_dim,
290
+ threshold_ema_dead_code=2,
291
+ commitment=vq_commit_weight,
292
+ weight_init=vq_weight_init,
293
+ full_commit_loss=vq_full_commit_loss,
294
+ quantizer_dropout=quantizer_dropout,
295
+ dropout_type=dropout_type,
296
+ )
297
+ self.quantizer.append(quantizer)
298
+
299
+ # phone
300
+ quantizer = quantizer_class(
301
+ num_quantizers=vq_num_q_c,
302
+ dim=vq_dim,
303
+ codebook_size=codebook_size_content,
304
+ codebook_dim=codebook_dim,
305
+ threshold_ema_dead_code=2,
306
+ commitment=vq_commit_weight,
307
+ weight_init=vq_weight_init,
308
+ full_commit_loss=vq_full_commit_loss,
309
+ quantizer_dropout=quantizer_dropout,
310
+ dropout_type=dropout_type,
311
+ )
312
+ self.quantizer.append(quantizer)
313
+
314
+ # residual
315
+ if self.vq_num_q_r > 0:
316
+ quantizer = quantizer_class(
317
+ num_quantizers=vq_num_q_r,
318
+ dim=vq_dim,
319
+ codebook_size=codebook_size_residual,
320
+ codebook_dim=codebook_dim,
321
+ threshold_ema_dead_code=2,
322
+ commitment=vq_commit_weight,
323
+ weight_init=vq_weight_init,
324
+ full_commit_loss=vq_full_commit_loss,
325
+ quantizer_dropout=quantizer_dropout,
326
+ dropout_type=dropout_type,
327
+ )
328
+ self.quantizer.append(quantizer)
329
+
330
+ # Add first conv layer
331
+ channels = upsample_initial_channel
332
+ layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
333
+
334
+ # Add upsampling + MRF blocks
335
+ for i, stride in enumerate(up_ratios):
336
+ input_dim = channels // 2**i
337
+ output_dim = channels // 2 ** (i + 1)
338
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
339
+
340
+ # Add final conv layer
341
+ layers += [
342
+ Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
343
+ WNConv1d(output_dim, 1, kernel_size=7, padding=3),
344
+ nn.Tanh(),
345
+ ]
346
+
347
+ self.model = nn.Sequential(*layers)
348
+
349
+ self.timbre_encoder = TransformerEncoder(
350
+ enc_emb_tokens=None,
351
+ encoder_layer=4,
352
+ encoder_hidden=256,
353
+ encoder_head=4,
354
+ conv_filter_size=1024,
355
+ conv_kernel_size=5,
356
+ encoder_dropout=0.1,
357
+ use_cln=False,
358
+ )
359
+
360
+ self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
361
+ self.timbre_linear.bias.data[:in_channels] = 1
362
+ self.timbre_linear.bias.data[in_channels:] = 0
363
+ self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
364
+
365
+ self.f0_predictor = CNNLSTM(in_channels, 1, 2)
366
+ self.phone_predictor = CNNLSTM(in_channels, 5003, 1)
367
+
368
+ self.use_gr_content_f0 = use_gr_content_f0
369
+ self.use_gr_prosody_phone = use_gr_prosody_phone
370
+ self.use_gr_residual_f0 = use_gr_residual_f0
371
+ self.use_gr_residual_phone = use_gr_residual_phone
372
+ self.use_gr_x_timbre = use_gr_x_timbre
373
+
374
+ if self.vq_num_q_r > 0 and self.use_gr_residual_f0:
375
+ self.res_f0_predictor = nn.Sequential(
376
+ GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
377
+ )
378
+
379
+ if self.vq_num_q_r > 0 and self.use_gr_residual_phone > 0:
380
+ self.res_phone_predictor = nn.Sequential(
381
+ GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
382
+ )
383
+
384
+ if self.use_gr_content_f0:
385
+ self.content_f0_predictor = nn.Sequential(
386
+ GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
387
+ )
388
+
389
+ if self.use_gr_prosody_phone:
390
+ self.prosody_phone_predictor = nn.Sequential(
391
+ GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
392
+ )
393
+
394
+ if self.use_gr_x_timbre:
395
+ self.x_timbre_predictor = nn.Sequential(
396
+ GradientReversal(alpha=1),
397
+ CNNLSTM(in_channels, 245200, 1, global_pred=True),
398
+ )
399
+
400
+ self.reset_parameters()
401
+
402
+ def quantize(self, x, n_quantizers=None):
403
+ outs, qs, commit_loss, quantized_buf = 0, [], [], []
404
+
405
+ # prosody
406
+ f0_input = x # (B, d, T)
407
+ f0_quantizer = self.quantizer[0]
408
+ out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers)
409
+ outs += out
410
+ qs.append(q)
411
+ quantized_buf.append(quantized.sum(0))
412
+ commit_loss.append(commit)
413
+
414
+ # phone
415
+ phone_input = x
416
+ phone_quantizer = self.quantizer[1]
417
+ out, q, commit, quantized = phone_quantizer(
418
+ phone_input, n_quantizers=n_quantizers
419
+ )
420
+ outs += out
421
+ qs.append(q)
422
+ quantized_buf.append(quantized.sum(0))
423
+ commit_loss.append(commit)
424
+
425
+ # residual
426
+ if self.vq_num_q_r > 0:
427
+ residual_quantizer = self.quantizer[2]
428
+ residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach()
429
+ out, q, commit, quantized = residual_quantizer(
430
+ residual_input, n_quantizers=n_quantizers
431
+ )
432
+ outs += out
433
+ qs.append(q)
434
+ quantized_buf.append(quantized.sum(0)) # [L, B, C, T] -> [B, C, T]
435
+ commit_loss.append(commit)
436
+
437
+ qs = torch.cat(qs, dim=0)
438
+ commit_loss = torch.cat(commit_loss, dim=0)
439
+ return outs, qs, commit_loss, quantized_buf
440
+
441
+ def forward(
442
+ self,
443
+ x,
444
+ vq=True,
445
+ get_vq=False,
446
+ eval_vq=True,
447
+ speaker_embedding=None,
448
+ n_quantizers=None,
449
+ quantized=None,
450
+ ):
451
+ if get_vq:
452
+ return self.quantizer.get_emb()
453
+ if vq is True:
454
+ if eval_vq:
455
+ self.quantizer.eval()
456
+ x_timbre = x
457
+ outs, qs, commit_loss, quantized_buf = self.quantize(
458
+ x, n_quantizers=n_quantizers
459
+ )
460
+
461
+ x_timbre = x_timbre.transpose(1, 2)
462
+ x_timbre = self.timbre_encoder(x_timbre, None, None)
463
+ x_timbre = x_timbre.transpose(1, 2)
464
+ spk_embs = torch.mean(x_timbre, dim=2)
465
+ return outs, qs, commit_loss, quantized_buf, spk_embs
466
+
467
+ out = {}
468
+
469
+ layer_0 = quantized[0]
470
+ f0, uv = self.f0_predictor(layer_0)
471
+ f0 = rearrange(f0, "... 1 -> ...")
472
+ uv = rearrange(uv, "... 1 -> ...")
473
+
474
+ layer_1 = quantized[1]
475
+ (phone,) = self.phone_predictor(layer_1)
476
+
477
+ out = {"f0": f0, "uv": uv, "phone": phone}
478
+
479
+ if self.use_gr_prosody_phone:
480
+ (prosody_phone,) = self.prosody_phone_predictor(layer_0)
481
+ out["prosody_phone"] = prosody_phone
482
+
483
+ if self.use_gr_content_f0:
484
+ content_f0, content_uv = self.content_f0_predictor(layer_1)
485
+ content_f0 = rearrange(content_f0, "... 1 -> ...")
486
+ content_uv = rearrange(content_uv, "... 1 -> ...")
487
+ out["content_f0"] = content_f0
488
+ out["content_uv"] = content_uv
489
+
490
+ if self.vq_num_q_r > 0:
491
+ layer_2 = quantized[2]
492
+
493
+ if self.use_gr_residual_f0:
494
+ res_f0, res_uv = self.res_f0_predictor(layer_2)
495
+ res_f0 = rearrange(res_f0, "... 1 -> ...")
496
+ res_uv = rearrange(res_uv, "... 1 -> ...")
497
+ out["res_f0"] = res_f0
498
+ out["res_uv"] = res_uv
499
+
500
+ if self.use_gr_residual_phone:
501
+ (res_phone,) = self.res_phone_predictor(layer_2)
502
+ out["res_phone"] = res_phone
503
+
504
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
505
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
506
+ if self.vq_num_q_r > 0:
507
+ if self.use_random_mask_residual:
508
+ bsz = quantized[2].shape[0]
509
+ res_mask = np.random.choice(
510
+ [0, 1],
511
+ size=bsz,
512
+ p=[
513
+ self.prob_random_mask_residual,
514
+ 1 - self.prob_random_mask_residual,
515
+ ],
516
+ )
517
+ res_mask = (
518
+ torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)
519
+ ) # (B, 1, 1)
520
+ res_mask = res_mask.to(
521
+ device=quantized[2].device, dtype=quantized[2].dtype
522
+ )
523
+ x = (
524
+ quantized[0].detach()
525
+ + quantized[1].detach()
526
+ + quantized[2] * res_mask
527
+ )
528
+ # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask
529
+ else:
530
+ x = quantized[0].detach() + quantized[1].detach() + quantized[2]
531
+ # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2]
532
+ else:
533
+ x = quantized[0].detach() + quantized[1].detach()
534
+ # x = quantized_perturbe[0].detach() + quantized[1].detach()
535
+
536
+ if self.use_gr_x_timbre:
537
+ (x_timbre,) = self.x_timbre_predictor(x)
538
+ out["x_timbre"] = x_timbre
539
+
540
+ x = x.transpose(1, 2)
541
+ x = self.timbre_norm(x)
542
+ x = x.transpose(1, 2)
543
+ x = x * gamma + beta
544
+
545
+ x = self.model(x)
546
+ out["audio"] = x
547
+
548
+ return out
549
+
550
+ def vq2emb(self, vq, use_residual_code=True):
551
+ # vq: [num_quantizer, B, T]
552
+ self.quantizer = self.quantizer.eval()
553
+ out = 0
554
+ out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p])
555
+ out += self.quantizer[1].vq2emb(
556
+ vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c]
557
+ )
558
+ if self.vq_num_q_r > 0 and use_residual_code:
559
+ out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :])
560
+ return out
561
+
562
+ def inference(self, x, speaker_embedding):
563
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
564
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
565
+ x = x.transpose(1, 2)
566
+ x = self.timbre_norm(x)
567
+ x = x.transpose(1, 2)
568
+ x = x * gamma + beta
569
+ x = self.model(x)
570
+ return x
571
+
572
+ def remove_weight_norm(self):
573
+ """Remove weight normalization module from all of the layers."""
574
+
575
+ def _remove_weight_norm(m):
576
+ try:
577
+ torch.nn.utils.remove_weight_norm(m)
578
+ except ValueError: # this module didn't have weight norm
579
+ return
580
+
581
+ self.apply(_remove_weight_norm)
582
+
583
+ def apply_weight_norm(self):
584
+ """Apply weight normalization module from all of the layers."""
585
+
586
+ def _apply_weight_norm(m):
587
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
588
+ torch.nn.utils.weight_norm(m)
589
+
590
+ self.apply(_apply_weight_norm)
591
+
592
+ def reset_parameters(self):
593
+ self.apply(init_weights)
Amphion/models/ns3_codec/gradient_reversal.py ADDED
@@ -0,0 +1,30 @@
1
+ from torch.autograd import Function
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class GradientReversal(Function):
7
+ @staticmethod
8
+ def forward(ctx, x, alpha):
9
+ ctx.save_for_backward(x, alpha)
10
+ return x
11
+
12
+ @staticmethod
13
+ def backward(ctx, grad_output):
14
+ grad_input = None
15
+ _, alpha = ctx.saved_tensors
16
+ if ctx.needs_input_grad[0]:
17
+ grad_input = -alpha * grad_output
18
+ return grad_input, None
19
+
20
+
21
+ revgrad = GradientReversal.apply
22
+
23
+
24
+ class GradientReversal(nn.Module):
25
+ def __init__(self, alpha):
26
+ super().__init__()
27
+ self.alpha = torch.tensor(alpha, requires_grad=False)
28
+
29
+ def forward(self, x):
30
+ return revgrad(x, self.alpha)
Amphion/models/ns3_codec/quantize/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .fvq import *
2
+ from .rvq import *
Amphion/models/ns3_codec/quantize/fvq.py ADDED
@@ -0,0 +1,111 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+
11
+ class FactorizedVectorQuantize(nn.Module):
12
+ def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs):
13
+ super().__init__()
14
+ self.codebook_size = codebook_size
15
+ self.codebook_dim = codebook_dim
16
+ self.commitment = commitment
17
+
18
+ if dim != self.codebook_dim:
19
+ self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim))
20
+ self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim))
21
+ else:
22
+ self.in_proj = nn.Identity()
23
+ self.out_proj = nn.Identity()
24
+ self._codebook = nn.Embedding(codebook_size, self.codebook_dim)
25
+
26
+ @property
27
+ def codebook(self):
28
+ return self._codebook
29
+
30
+ def forward(self, z):
31
+ """Quantized the input tensor using a fixed codebook and returns
32
+ the corresponding codebook vectors
33
+
34
+ Parameters
35
+ ----------
36
+ z : Tensor[B x D x T]
37
+
38
+ Returns
39
+ -------
40
+ Tensor[B x D x T]
41
+ Quantized continuous representation of input
42
+ Tensor[1]
43
+ Commitment loss to train encoder to predict vectors closer to codebook
44
+ entries
45
+ Tensor[1]
46
+ Codebook loss to update the codebook
47
+ Tensor[B x T]
48
+ Codebook indices (quantized discrete representation of input)
49
+ Tensor[B x D x T]
50
+ Projected latents (continuous representation of input before quantization)
51
+ """
52
+ # transpose since we use linear
53
+
54
+ z = rearrange(z, "b d t -> b t d")
55
+
56
+ # Factorized codes project input into low-dimensional space
57
+ z_e = self.in_proj(z) # z_e : (B x T x D)
58
+ z_e = rearrange(z_e, "b t d -> b d t")
59
+ z_q, indices = self.decode_latents(z_e)
60
+
61
+ if self.training:
62
+ commitment_loss = (
63
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
64
+ * self.commitment
65
+ )
66
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
67
+ commit_loss = commitment_loss + codebook_loss
68
+ else:
69
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
70
+
71
+ z_q = (
72
+ z_e + (z_q - z_e).detach()
73
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
74
+
75
+ z_q = rearrange(z_q, "b d t -> b t d")
76
+ z_q = self.out_proj(z_q)
77
+ z_q = rearrange(z_q, "b t d -> b d t")
78
+
79
+ return z_q, indices, commit_loss
80
+
81
+ def vq2emb(self, vq, proj=True):
82
+ emb = self.embed_code(vq)
83
+ if proj:
84
+ emb = self.out_proj(emb)
85
+ return emb.transpose(1, 2)
86
+
87
+ def get_emb(self):
88
+ return self.codebook.weight
89
+
90
+ def embed_code(self, embed_id):
91
+ return F.embedding(embed_id, self.codebook.weight)
92
+
93
+ def decode_code(self, embed_id):
94
+ return self.embed_code(embed_id).transpose(1, 2)
95
+
96
+ def decode_latents(self, latents):
97
+ encodings = rearrange(latents, "b d t -> (b t) d")
98
+ codebook = self.codebook.weight # codebook: (N x D)
99
+ # L2 normalize encodings and codebook
100
+ encodings = F.normalize(encodings)
101
+ codebook = F.normalize(codebook)
102
+
103
+ # Compute euclidean distance with codebook
104
+ dist = (
105
+ encodings.pow(2).sum(1, keepdim=True)
106
+ - 2 * encodings @ codebook.t()
107
+ + codebook.pow(2).sum(1, keepdim=True).t()
108
+ )
109
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
110
+ z_q = self.decode_code(indices)
111
+ return z_q, indices
Amphion/models/ns3_codec/quantize/rvq.py ADDED
@@ -0,0 +1,82 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from .fvq import FactorizedVectorQuantize
5
+
6
+
7
+ class ResidualVQ(nn.Module):
8
+ """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
9
+
10
+ def __init__(self, *, num_quantizers, codebook_size, **kwargs):
11
+ super().__init__()
12
+ VQ = FactorizedVectorQuantize
13
+ if type(codebook_size) == int:
14
+ codebook_size = [codebook_size] * num_quantizers
15
+ self.layers = nn.ModuleList(
16
+ [VQ(codebook_size=2**size, **kwargs) for size in codebook_size]
17
+ )
18
+ self.num_quantizers = num_quantizers
19
+ self.quantizer_dropout = kwargs.get("quantizer_dropout", 0.0)
20
+ self.dropout_type = kwargs.get("dropout_type", None)
21
+
22
+ def forward(self, x, n_quantizers=None):
23
+ quantized_out = 0.0
24
+ residual = x
25
+
26
+ all_losses = []
27
+ all_indices = []
28
+ all_quantized = []
29
+
30
+ if n_quantizers is None:
31
+ n_quantizers = self.num_quantizers
32
+ if self.training:
33
+ n_quantizers = torch.ones((x.shape[0],)) * self.num_quantizers + 1
34
+ if self.dropout_type == "linear":
35
+ dropout = torch.randint(1, self.num_quantizers + 1, (x.shape[0],))
36
+ elif self.dropout_type == "exp":
37
+ dropout = torch.randint(
38
+ 1, int(math.log2(self.num_quantizers)), (x.shape[0],)
39
+ )
40
+ dropout = torch.pow(2, dropout)
41
+ n_dropout = int(x.shape[0] * self.quantizer_dropout)
42
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
43
+ n_quantizers = n_quantizers.to(x.device)
44
+
45
+ for idx, layer in enumerate(self.layers):
46
+ if not self.training and idx >= n_quantizers:
47
+ break
48
+ quantized, indices, loss = layer(residual)
49
+
50
+ mask = (
51
+ torch.full((x.shape[0],), fill_value=idx, device=x.device)
52
+ < n_quantizers
53
+ )
54
+
55
+ residual = residual - quantized
56
+
57
+ quantized_out = quantized_out + quantized * mask[:, None, None]
58
+
59
+ # loss
60
+ loss = (loss * mask).mean()
61
+
62
+ all_indices.append(indices)
63
+ all_losses.append(loss)
64
+ all_quantized.append(quantized)
65
+ all_losses, all_indices, all_quantized = map(
66
+ torch.stack, (all_losses, all_indices, all_quantized)
67
+ )
68
+ return quantized_out, all_indices, all_losses, all_quantized
69
+
70
+ def vq2emb(self, vq):
71
+ # vq: [n_quantizers, B, T]
72
+ quantized_out = 0.0
73
+ for idx, layer in enumerate(self.layers):
74
+ quantized = layer.vq2emb(vq[idx])
75
+ quantized_out += quantized
76
+ return quantized_out
77
+
78
+ def get_emb(self):
79
+ embs = []
80
+ for idx, layer in enumerate(self.layers):
81
+ embs.append(layer.get_emb())
82
+ return embs
Amphion/models/ns3_codec/transformer.py ADDED
@@ -0,0 +1,217 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+ from torch.nn import functional as F
6
+
7
+
8
+ class StyleAdaptiveLayerNorm(nn.Module):
9
+ def __init__(self, normalized_shape, eps=1e-5):
10
+ super().__init__()
11
+ self.in_dim = normalized_shape
12
+ self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
13
+ self.style = nn.Linear(self.in_dim, self.in_dim * 2)
14
+ self.style.bias.data[: self.in_dim] = 1
15
+ self.style.bias.data[self.in_dim :] = 0
+
+     def forward(self, x, condition):
+         # Conditional layer norm, needed when use_cln=True since callers invoke
+         # this layer as ln(x, condition). condition: (B, T, d) or (B, d).
+         style = self.style(condition)
+         gamma, beta = style.chunk(2, dim=-1)
+         out = self.norm(x)
+         return gamma * out + beta
16
+
17
+
18
+ class PositionalEncoding(nn.Module):
19
+ def __init__(self, d_model, dropout, max_len=5000):
20
+ super().__init__()
21
+
22
+ self.dropout = dropout
23
+ position = torch.arange(max_len).unsqueeze(1)
24
+ div_term = torch.exp(
25
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
26
+ )
27
+ pe = torch.zeros(max_len, 1, d_model)
28
+ pe[:, 0, 0::2] = torch.sin(position * div_term)
29
+ pe[:, 0, 1::2] = torch.cos(position * div_term)
30
+ self.register_buffer("pe", pe)
31
+
32
+ def forward(self, x):
33
+ x = x + self.pe[: x.size(0)]
34
+ return F.dropout(x, self.dropout, training=self.training)
35
+
36
+
37
+ class TransformerFFNLayer(nn.Module):
38
+ def __init__(
39
+ self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
40
+ ):
41
+ super().__init__()
42
+
43
+ self.encoder_hidden = encoder_hidden
44
+ self.conv_filter_size = conv_filter_size
45
+ self.conv_kernel_size = conv_kernel_size
46
+ self.encoder_dropout = encoder_dropout
47
+
48
+ self.ffn_1 = nn.Conv1d(
49
+ self.encoder_hidden,
50
+ self.conv_filter_size,
51
+ self.conv_kernel_size,
52
+ padding=self.conv_kernel_size // 2,
53
+ )
54
+ self.ffn_1.weight.data.normal_(0.0, 0.02)
55
+ self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
56
+ self.ffn_2.weight.data.normal_(0.0, 0.02)
57
+
58
+ def forward(self, x):
59
+ # x: (B, T, d)
60
+ x = self.ffn_1(x.permute(0, 2, 1)).permute(
61
+ 0, 2, 1
62
+ ) # (B, T, d) -> (B, d, T) -> (B, T, d)
63
+ x = F.relu(x)
64
+ x = F.dropout(x, self.encoder_dropout, training=self.training)
65
+ x = self.ffn_2(x)
66
+ return x
67
+
68
+
69
+ class TransformerEncoderLayer(nn.Module):
70
+ def __init__(
71
+ self,
72
+ encoder_hidden,
73
+ encoder_head,
74
+ conv_filter_size,
75
+ conv_kernel_size,
76
+ encoder_dropout,
77
+ use_cln,
78
+ ):
79
+ super().__init__()
80
+ self.encoder_hidden = encoder_hidden
81
+ self.encoder_head = encoder_head
82
+ self.conv_filter_size = conv_filter_size
83
+ self.conv_kernel_size = conv_kernel_size
84
+ self.encoder_dropout = encoder_dropout
85
+ self.use_cln = use_cln
86
+
87
+ if not self.use_cln:
88
+ self.ln_1 = nn.LayerNorm(self.encoder_hidden)
89
+ self.ln_2 = nn.LayerNorm(self.encoder_hidden)
90
+ else:
91
+ self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
92
+ self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)
93
+
94
+ self.self_attn = nn.MultiheadAttention(
95
+ self.encoder_hidden, self.encoder_head, batch_first=True
96
+ )
97
+
98
+ self.ffn = TransformerFFNLayer(
99
+ self.encoder_hidden,
100
+ self.conv_filter_size,
101
+ self.conv_kernel_size,
102
+ self.encoder_dropout,
103
+ )
104
+
105
+ def forward(self, x, key_padding_mask, condition=None):
106
+ # x: (B, T, d); key_padding_mask: (B, T), mask is 0; condition: (B, T, d)
107
+
108
+ # self attention
109
+ residual = x
110
+ if self.use_cln:
111
+ x = self.ln_1(x, condition)
112
+ else:
113
+ x = self.ln_1(x)
114
+
115
+ if key_padding_mask != None:
116
+ key_padding_mask_input = ~(key_padding_mask.bool())
117
+ else:
118
+ key_padding_mask_input = None
119
+ x, _ = self.self_attn(
120
+ query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
121
+ )
122
+ x = F.dropout(x, self.encoder_dropout, training=self.training)
123
+ x = residual + x
124
+
125
+ # ffn
126
+ residual = x
127
+ if self.use_cln:
128
+ x = self.ln_2(x, condition)
129
+ else:
130
+ x = self.ln_2(x)
131
+ x = self.ffn(x)
132
+ x = residual + x
133
+
134
+ return x
135
+
136
+
137
+ class TransformerEncoder(nn.Module):
138
+ def __init__(
139
+ self,
140
+ enc_emb_tokens=None,
141
+ encoder_layer=4,
142
+ encoder_hidden=256,
143
+ encoder_head=4,
144
+ conv_filter_size=1024,
145
+ conv_kernel_size=5,
146
+ encoder_dropout=0.1,
147
+ use_cln=False,
148
+ cfg=None,
149
+ ):
150
+ super().__init__()
151
+
152
+ self.encoder_layer = (
153
+ encoder_layer if encoder_layer is not None else cfg.encoder_layer
154
+ )
155
+ self.encoder_hidden = (
156
+ encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
157
+ )
158
+ self.encoder_head = (
159
+ encoder_head if encoder_head is not None else cfg.encoder_head
160
+ )
161
+ self.conv_filter_size = (
162
+ conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
163
+ )
164
+ self.conv_kernel_size = (
165
+ conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
166
+ )
167
+ self.encoder_dropout = (
168
+ encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
169
+ )
170
+ self.use_cln = use_cln if use_cln is not None else cfg.use_cln
171
+
172
+ if enc_emb_tokens != None:
173
+ self.use_enc_emb = True
174
+ self.enc_emb_tokens = enc_emb_tokens
175
+ else:
176
+ self.use_enc_emb = False
177
+
178
+ self.position_emb = PositionalEncoding(
179
+ self.encoder_hidden, self.encoder_dropout
180
+ )
181
+
182
+ self.layers = nn.ModuleList([])
183
+ self.layers.extend(
184
+ [
185
+ TransformerEncoderLayer(
186
+ self.encoder_hidden,
187
+ self.encoder_head,
188
+ self.conv_filter_size,
189
+ self.conv_kernel_size,
190
+ self.encoder_dropout,
191
+ self.use_cln,
192
+ )
193
+ for i in range(self.encoder_layer)
194
+ ]
195
+ )
196
+
197
+ if self.use_cln:
198
+ self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
199
+ else:
200
+ self.last_ln = nn.LayerNorm(self.encoder_hidden)
201
+
202
+ def forward(self, x, key_padding_mask, condition=None):
203
+ if len(x.shape) == 2 and self.use_enc_emb:
204
+ x = self.enc_emb_tokens(x)
205
+ x = self.position_emb(x)
206
+ else:
207
+ x = self.position_emb(x) # (B, T, d)
208
+
209
+ for layer in self.layers:
210
+ x = layer(x, key_padding_mask, condition)
211
+
212
+ if self.use_cln:
213
+ x = self.last_ln(x, condition)
214
+ else:
215
+ x = self.last_ln(x)
216
+
217
+ return x