smulelabs committed
Commit c60dea4 · verified · 1 Parent(s): 60e84e1

Upload folder using huggingface_hub

Files changed (7)
  1. LICENSE +21 -0
  2. README.md +24 -3
  3. main.py +113 -0
  4. model.py +437 -0
  5. requirements.txt +4 -0
  6. smule-renaissance-small.pt +3 -0
  7. spectral_ops.py +33 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 smulelabs
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,24 @@
- ---
- license: mit
- ---
+ # Smule Renaissance Small
+
+ A 10.4M-parameter generative audio model that restores degraded vocals in any condition and runs 10.5x faster than real time on an iPhone 12 CPU.
+
+ Technical Report: [![Technical Report](https://img.shields.io/badge/arXiv-2510.21659-blue.svg)](https://arxiv.org/abs/2510.21659)
+
+ HuggingFace Model: [![Hugging Face Model](https://img.shields.io/badge/Hugging%20Face-SmuleRenaissanceSmall-yellow.svg)](https://huggingface.co/smulelabs/Smule-Renaissance-Small)
+
+ Extreme Degradation Bench: [![Hugging Face Dataset](https://img.shields.io/badge/Hugging%20Face-ExtremeDegradationBench-green.svg)](https://huggingface.co/datasets/smulelabs/ExtremeDegradationBench)
+
+ ---
+
+ ## Getting Started
+ ### Setting up the environment
+ ```bash
+ # Create and activate a virtual environment, then install dependencies
+ uv venv cleanup --python=3.10
+ source cleanup/bin/activate
+ uv pip install -r requirements.txt
+ ```
+ ### Running the model
+ ```bash
+ python main.py {path-to-input} -o {path-to-output} -c {path-to-checkpoint}
+ ```
main.py ADDED
@@ -0,0 +1,113 @@
+ import argparse
+ import torch
+ import torchaudio
+ from pathlib import Path
+ from spectral_ops import STFT, iSTFT
+ from model import Renaissance
+
+ def load_and_preprocess_audio(input_path, device, dtype):
+     waveform, sr = torchaudio.load(input_path)
+
+     if waveform.shape[0] > 1:
+         print(f"Converting to mono from {waveform.shape[0]} channels")
+         waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+     if sr != 48000:
+         print(f"Resampling from {sr} Hz to 48000 Hz")
+         resampler = torchaudio.transforms.Resample(sr, 48000)
+         waveform = resampler(waveform)
+
+     # High-pass at 60 Hz to remove DC offset and low-frequency rumble
+     waveform = torchaudio.functional.highpass_biquad(
+         waveform, 48000, cutoff_freq=60.0
+     )
+
+     waveform = waveform.to(device).to(dtype)
+
+     return waveform
+
+ def normalize_audio(audio):
+     # Peak-normalize to [-1, 1]; return the factor so output level can be restored
+     normalization_factor = torch.max(torch.abs(audio))
+     if normalization_factor > 0:
+         normalized_audio = audio / normalization_factor
+     else:
+         normalized_audio = audio
+     return normalized_audio, normalization_factor
+
+
+ def process_audio(model, stft, istft, input_wav, device):
+     input_wav_norm, norm_factor = normalize_audio(input_wav)
+
+     with torch.no_grad():
+         input_stft = stft(input_wav_norm)
+
+         with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
+             enhanced_stft = model(input_stft)
+
+         enhanced_wav = istft(enhanced_stft)
+
+         if norm_factor > 0:
+             enhanced_wav = enhanced_wav * norm_factor
+
+     return enhanced_wav
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Smule Renaissance Vocal Restoration"
+     )
+     parser.add_argument(
+         "input",
+         type=str,
+         help="Input audio file path"
+     )
+     parser.add_argument(
+         "-o", "--output",
+         type=str,
+         default=None,
+         help="Output audio file path (default: input_enhanced.wav)"
+     )
+     parser.add_argument(
+         "-c", "--checkpoint",
+         type=str,
+         required=True,
+         help="Model checkpoint path"
+     )
+
+     args = parser.parse_args()
+
+     if args.output is None:
+         input_path = Path(args.input)
+         args.output = str(input_path.parent / f"{input_path.stem}_enhanced.wav")
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     if torch.cuda.is_available():
+         print("Using device: CUDA with FP16 precision")
+         dtype = torch.float16
+     else:
+         print("Using device: CPU with FP32 precision")
+         dtype = torch.float32
+
+     print(f"Loading model from {args.checkpoint}...")
+     model = Renaissance().to(device).to(dtype)
+     model.load_state_dict(torch.load(args.checkpoint, map_location=device))
+     model.eval()
+
+     stft = STFT(n_fft=4096, hop_length=2048, win_length=4096)
+     istft = iSTFT(n_fft=4096, hop_length=2048, win_length=4096)
+
+     print(f"Loading audio from {args.input}...")
+     input_wav = load_and_preprocess_audio(args.input, device, dtype)
+     print(f"Audio duration: {input_wav.shape[1] / 48000:.2f} seconds")
+
+     print("Processing audio...")
+     enhanced_wav = process_audio(model, stft, istft, input_wav, device)
+
+     print(f"Saving enhanced audio to {args.output}...")
+     enhanced_wav_cpu = enhanced_wav.cpu().to(torch.float32)
+     torchaudio.save(args.output, enhanced_wav_cpu, 48000)
+
+     print("Done!")
+
+
+ if __name__ == "__main__":
+     main()
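For reference, the pipeline in `main.py` can also be driven programmatically. Below is a minimal smoke test of `process_audio`, a sketch that assumes the files from this commit are importable from the working directory; it runs with randomly initialized weights, so load `smule-renaissance-small.pt` via `load_state_dict` for real output.

```python
# Minimal programmatic smoke test (a sketch; assumes main.py, model.py and
# spectral_ops.py from this commit are on the Python path).
import torch
from model import Renaissance
from spectral_ops import STFT, iSTFT
from main import process_audio

model = Renaissance().eval()  # untrained weights; load the checkpoint for real use
stft = STFT(n_fft=4096, hop_length=2048, win_length=4096)
istft = iSTFT(n_fft=4096, hop_length=2048, win_length=4096)

wav = torch.randn(1, 48000)  # stand-in for one second of 48 kHz mono vocals
enhanced = process_audio(model, stft, istft, wav, torch.device("cpu"))
print(enhanced.shape)  # [1, samples] enhanced waveform at 48 kHz
```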
model.py ADDED
@@ -0,0 +1,437 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from typing import Tuple, List
+
+ class RMSNorm(nn.Module):
+     # RMS normalization over the channel dimension (dim=1); statistics are
+     # computed in half precision, matching the model's FP16 inference path.
+     def __init__(self, dimension: int, eps: float = 1e-5):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(dimension))
+         self.eps = eps
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         input_half = input.half()
+         variance = input_half.pow(2).mean(dim=1, keepdim=True)
+         input_norm = input_half * torch.rsqrt(variance + self.eps)
+         return (input_norm * self.weight.unsqueeze(0).unsqueeze(-1)).type_as(input)
+
+
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, dim: int, max_position_embeddings: int = 2048, base: int = 10000):
+         super().__init__()
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+
+         inv_freq = 1. / (self.base ** (torch.arange(0, self.dim, 2).half() / self.dim))
+         self.register_buffer('inv_freq', inv_freq)
+
+         self._set_cos_sin_cache(
+             seq_len=max_position_embeddings,
+             device=self.inv_freq.device,
+             dtype=torch.get_default_dtype()
+         )
+
+     def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+         freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+         emb = torch.cat((freqs, freqs), dim=-1)
+
+         self.register_buffer('cos_cached', emb.cos()[None, None, :, :].to(dtype), persistent=False)
+         self.register_buffer('sin_cached', emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+     def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         if seq_len > self.max_seq_len_cached:
+             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+         return (
+             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+         )
+
+
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
+     # Rotate the last dimension by half: (x1, x2) -> (-x2, x1)
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ class RoformerLayer(nn.Module):
+     # Pre-norm transformer layer with rotary position embeddings and a gated MLP.
+     def __init__(
+         self,
+         feature_dim: int,
+         num_heads: int = 8,
+         max_seq_len: int = 10000,
+         dropout: float = 0.0,
+         mlp_ratio: float = 4.0,
+         rope_base: int = 10000
+     ):
+         super().__init__()
+         assert feature_dim % num_heads == 0, "feature_dim must be divisible by num_heads"
+
+         self.feature_dim = feature_dim
+         self.num_heads = num_heads
+         self.head_dim = feature_dim // num_heads
+
+         self.rotary_emb = RotaryEmbedding(self.head_dim, max_position_embeddings=max_seq_len, base=rope_base)
+         self.dropout = dropout
+
+         self.input_norm = RMSNorm(feature_dim)
+         self.qkv_proj = nn.Linear(feature_dim, feature_dim * 3, bias=False)
+         self.output_proj = nn.Linear(feature_dim, feature_dim, bias=False)
+
+         mlp_hidden_dim = int(feature_dim * mlp_ratio)
+         self.mlp_norm = RMSNorm(feature_dim)
+         self.mlp_up = nn.Linear(feature_dim, mlp_hidden_dim * 2, bias=False)
+         self.mlp_down = nn.Linear(mlp_hidden_dim, feature_dim, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         B, N, T = x.shape
+         x_residual = x
+         x_norm = self.input_norm(x).transpose(1, 2)
+
+         qkv = self.qkv_proj(x_norm)
+         qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
+         qkv = qkv.permute(2, 0, 3, 1, 4)
+         Q, K, V = qkv[0], qkv[1], qkv[2]
+
+         cos, sin = self.rotary_emb(Q, seq_len=T)
+         Q = (Q * cos) + (rotate_half(Q) * sin)
+         K = (K * cos) + (rotate_half(K) * sin)
+
+         attn_output = F.scaled_dot_product_attention(
+             Q, K, V, dropout_p=self.dropout if self.training else 0.0, is_causal=False
+         )
+
+         attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(B, T, N)
+         attn_output = self.output_proj(attn_output).transpose(1, 2)
+
+         x = x_residual + attn_output
+
+         x_residual = x
+         x_norm = self.mlp_norm(x).transpose(1, 2)
+
+         mlp_out = self.mlp_up(x_norm)
+         gate, values = mlp_out.chunk(2, dim=-1)
+         mlp_out = F.silu(gate) * values
+         mlp_out = self.mlp_down(mlp_out)
+
+         output = x_residual + mlp_out.transpose(1, 2)
+
+         return output
+
+
+ class Roformer(nn.Module):
+     # Rotary-embedding self-attention block operating on [B, channels, T] inputs.
+     def __init__(self, input_size, hidden_size, num_head=8, theta=10000, window=10000,
+                  input_drop=0., attention_drop=0., causal=True):
+         super().__init__()
+
+         self.input_size = input_size
+         self.hidden_size = hidden_size // num_head
+         self.num_head = num_head
+         self.theta = theta
+         self.window = window
+         cos_freq, sin_freq = self._calc_rotary_emb()
+         self.register_buffer("cos_freq", cos_freq)
+         self.register_buffer("sin_freq", sin_freq)
+
+         self.attention_drop = attention_drop
+         self.causal = causal
+         self.eps = 1e-5
+
+         self.input_norm = RMSNorm(self.input_size)
+         self.input_drop = nn.Dropout(p=input_drop)
+         self.weight = nn.Conv1d(self.input_size, self.hidden_size*self.num_head*3, 1, bias=False)
+         self.output = nn.Conv1d(self.hidden_size*self.num_head, self.input_size, 1, bias=False)
+
+         self.MLP = nn.Sequential(RMSNorm(self.input_size),
+                                  nn.Conv1d(self.input_size, self.input_size*8, 1, bias=False),
+                                  nn.SiLU()
+                                  )
+         self.MLP_output = nn.Conv1d(self.input_size*4, self.input_size, 1, bias=False)
+
+     def _calc_rotary_emb(self):
+         freq = 1. / (self.theta ** (torch.arange(0, self.hidden_size, 2)[:(self.hidden_size // 2)] / self.hidden_size))
+         freq = freq.reshape(1, -1)
+         pos = torch.arange(0, self.window).reshape(-1, 1)
+         cos_freq = torch.cos(pos*freq)
+         sin_freq = torch.sin(pos*freq)
+         cos_freq = torch.stack([cos_freq]*2, -1).reshape(self.window, self.hidden_size)
+         sin_freq = torch.stack([sin_freq]*2, -1).reshape(self.window, self.hidden_size)
+
+         return cos_freq, sin_freq
+
+     def _add_rotary_emb(self, feature, pos):
+         # Apply the rotary embedding for a single position index.
+         N = feature.shape[-1]
+
+         feature_reshape = feature.reshape(-1, N)
+         pos = min(pos, self.window-1)
+         cos_freq = self.cos_freq[pos]
+         sin_freq = self.sin_freq[pos]
+         reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype)
+         feature_reshape_neg = (torch.flip(feature_reshape.reshape(-1, N//2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape(-1, N)
+         feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0)
+
+         return feature_rope.reshape(feature.shape)
+
+     def _add_rotary_sequence(self, feature):
+         # Apply rotary embeddings across a whole sequence of length T.
+         T, N = feature.shape[-2:]
+         feature_reshape = feature.reshape(-1, T, N)
+
+         cos_freq = self.cos_freq[:T]
+         sin_freq = self.sin_freq[:T]
+         reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype)
+         feature_reshape_neg = (torch.flip(feature_reshape.reshape(-1, N//2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape(-1, T, N)
+         feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0)
+
+         return feature_rope.reshape(feature.shape)
+
+     def forward(self, input):
+         B, _, T = input.shape
+
+         weight = self.weight(self.input_drop(self.input_norm(input))).reshape(B, self.num_head, self.hidden_size*3, T).transpose(-2, -1)
+         Q, K, V = torch.split(weight, self.hidden_size, dim=-1)
+         Q_rot = self._add_rotary_sequence(Q)
+         K_rot = self._add_rotary_sequence(K)
+
+         attention_output = F.scaled_dot_product_attention(Q_rot.contiguous(), K_rot.contiguous(), V.contiguous(), dropout_p=self.attention_drop, is_causal=self.causal)  # B, num_head, T, N
+         attention_output = attention_output.transpose(-2, -1).reshape(B, -1, T)
+         output = self.output(attention_output) + input
+
+         gate, z = self.MLP(output).chunk(2, dim=1)
+         output = output + self.MLP_output(F.silu(gate) * z)
+
+         return output
+
+
+ class ConvBlock(nn.Module):
+     # Depthwise conv -> RMSNorm -> pointwise expansion with GLU -> pointwise projection.
+     def __init__(self, channels: int, kernel_size: int, dilation: int, expansion: int = 4):
+         super().__init__()
+         padding = (kernel_size - 1) * dilation // 2
+
+         self.dwconv = nn.Conv1d(
+             channels, channels, kernel_size, padding=padding, dilation=dilation, groups=channels
+         )
+         self.norm = RMSNorm(channels)
+         self.pwconv1 = nn.Conv1d(channels, channels * expansion, 1)
+         self.act = nn.GLU(dim=1)
+         self.pwconv2 = nn.Conv1d(channels * expansion // 2, channels, 1)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.dwconv(x)
+         x = self.norm(x)
+         x = self.pwconv1(x)
+         x = self.act(x)
+         x = self.pwconv2(x)
+         return x
+
+
+ class ICB(nn.Module):
+     # Three ConvBlocks with per-channel layer scale and a residual connection.
+     def __init__(self, channels: int, kernel_size: int = 7, dilation: int = 1, layer_scale_init_value: float = 1e-6):
+         super().__init__()
+         self.block1 = ConvBlock(channels, kernel_size, 1)
+         self.block2 = ConvBlock(channels, kernel_size, dilation)
+         self.block3 = ConvBlock(channels, kernel_size, 1)
+         self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(channels), requires_grad=True)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         residual = x
+         x = self.block1(x)
+         x = self.block2(x)
+         x = self.block3(x)
+         return x * self.gamma.unsqueeze(0).unsqueeze(-1) + residual
+
+
+ class BSNet(nn.Module):
+     # One band-split stage: attention across frequency bands, then convolution across time.
+     def __init__(
+         self,
+         feature_dim: int,
+         kernel_size: int,
+         dilation_rate: int,
+         num_heads: int,
+         max_bands: int = 512,
+         band_rope_base: int = 10000,
+         layer_scale_init_value: float = 1e-6
+     ):
+         super().__init__()
+         self.band_net = Roformer(feature_dim, feature_dim, num_head=num_heads, window=max_bands, causal=False)
+
+         self.seq_net = ICB(
+             feature_dim,
+             kernel_size=kernel_size,
+             dilation=dilation_rate,
+             layer_scale_init_value=layer_scale_init_value
+         )
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         B, nband, N, T = input.shape
+
+         band_input = input.permute(0, 3, 2, 1).reshape(B * T, N, nband)
+         band_output = self.band_net(band_input)
+         band_output = band_output.view(B, T, N, nband).permute(0, 3, 2, 1)
+
+         seq_input = band_output.reshape(B * nband, N, T)
+         seq_output = self.seq_net(seq_input)
+         output = seq_output.view(B, nband, N, T)
+
+         return output
+
+
+ class Renaissance(nn.Module):
+     # Band-split spectrogram enhancement network: splits the STFT into mel-spaced
+     # bands, processes them with stacked BSNet blocks, and synthesizes an enhanced
+     # complex spectrogram of the same shape as the input.
+     def __init__(
+         self,
+         n_freqs: int = 2049,
+         feature_dim: int = 128,
+         layer: int = 9,
+         sample_rate: int = 48000,
+         dilation_start_layer: int = 3,
+         n_bands: int = 80,
+         num_heads: int = 16,
+         max_seq_len: int = 10000,
+         band_rope_base: int = 10000,
+         temporal_rope_base: int = 10000
+     ):
+         super().__init__()
+         self.enc_dim = n_freqs
+         self.feature_dim = feature_dim
+         self.eps = 1e-7
+         self.dilation_start_layer = dilation_start_layer
+         self.n_bands = n_bands
+         self.sr = sample_rate
+         self.max_seq_len = max_seq_len
+         self.band_rope_base = band_rope_base
+         self.temporal_rope_base = temporal_rope_base
+
+         self.band_width = self._generate_mel_bandwidths()
+         self.nband = len(self.band_width)
+         assert self.enc_dim == sum(self.band_width), "Mel band splitting failed to cover all frequencies."
+
+         self._build_feature_extractor()
+         self._build_main_network(layer, num_heads)
+         self._build_output_synthesis()
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, (nn.Linear, nn.Conv1d)):
+             nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+
+     def _generate_mel_bandwidths(self) -> List[int]:
+         # Partition the enc_dim frequency bins into n_bands mel-spaced bands,
+         # then adjust the widths so they sum exactly to enc_dim.
+         def hz_to_mel(hz): return 2595 * np.log10(1 + hz / 700.)
+         def mel_to_hz(mel): return 700 * (10**(mel / 2595) - 1)
+
+         min_freq, max_freq = 0.1, self.sr / 2
+         min_mel, max_mel = hz_to_mel(min_freq), hz_to_mel(max_freq)
+
+         mel_points = np.linspace(min_mel, max_mel, self.n_bands + 1)
+         hz_points = mel_to_hz(mel_points)
+
+         bin_width = self.sr / 2 / self.enc_dim
+         bw = np.round(np.diff(hz_points) / bin_width).astype(int)
+
+         bw = np.maximum(1, bw)
+
+         remainder = self.enc_dim - np.sum(bw)
+         if remainder != 0:
+             sorted_indices = np.argsort(bw)
+             op = 1 if remainder > 0 else -1
+             indices_to_adjust = sorted_indices if op == 1 else sorted_indices[::-1]
+
+             for i in range(abs(remainder)):
+                 idx = indices_to_adjust[i % len(indices_to_adjust)]
+                 if bw[idx] + op > 0:
+                     bw[idx] += op
+
+         if np.sum(bw) != self.enc_dim:
+             bw[-1] += self.enc_dim - np.sum(bw)
+
+         return bw.tolist()
+
+     def _build_feature_extractor(self):
+         # Per-band encoder: normalized real/imag spectrum plus a log-power feature.
+         self.feature_extractor_layers = nn.ModuleList([
+             nn.Sequential(RMSNorm(bw * 2 + 1), nn.Conv1d(bw * 2 + 1, self.feature_dim, 1))
+             for bw in self.band_width
+         ])
+
+     def _build_main_network(self, num_layers, num_heads):
+         self.net = nn.ModuleList()
+         max_bands = max(512, self.nband * 2)
+
+         layer_scale_init = 1e-6
+
+         for i in range(num_layers):
+             dilation = min(2 ** max(0, i - self.dilation_start_layer + 1), 4)
+             self.net.append(BSNet(
+                 self.feature_dim,
+                 kernel_size=7,
+                 dilation_rate=dilation,
+                 num_heads=num_heads,
+                 max_bands=max_bands,
+                 band_rope_base=self.band_rope_base,
+                 layer_scale_init_value=layer_scale_init
+             ))
+
+     def _build_output_synthesis(self):
+         # Per-band decoder producing real/imag spectra through a GLU head.
+         self.output_layers = nn.ModuleList([
+             nn.Sequential(
+                 RMSNorm(self.feature_dim),
+                 nn.Conv1d(self.feature_dim, self.feature_dim * 2, 1),
+                 nn.SiLU(),
+                 nn.Conv1d(self.feature_dim * 2, bw * 4, kernel_size=1),
+                 nn.GLU(dim=1),
+             ) for bw in self.band_width
+         ])
+
+     def spec_band_split(self, spec: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:
+         subband_spec_ri = []
+         subband_power = []
+         band_idx = 0
+         for width in self.band_width:
+             this_spec_ri = spec[:, band_idx : band_idx + width, :, :]
+             subband_spec_ri.append(this_spec_ri)
+
+             power = (this_spec_ri.pow(2).sum(dim=-1)).sum(dim=1, keepdim=True).add(self.eps).sqrt()
+             subband_power.append(power)
+             band_idx += width
+
+         subband_power = torch.cat(subband_power, 1)
+         return subband_spec_ri, subband_power
+
+     def feature_extraction(self, input_spec: torch.Tensor) -> torch.Tensor:
+         subband_spec_ri, subband_power = self.spec_band_split(input_spec)
+         features = []
+         for i in range(self.nband):
+             power_for_norm = subband_power[:, i:i+1, :].unsqueeze(1)
+             norm_spec_ri = subband_spec_ri[i] / (power_for_norm.transpose(2, 3) + self.eps)
+             B, F_band, T, _ = norm_spec_ri.shape
+             norm_spec_flat = norm_spec_ri.permute(0, 3, 1, 2).reshape(B, F_band*2, T)
+
+             log_power_feature = torch.log(power_for_norm.squeeze(1) + self.eps)
+             feature_input = torch.cat([norm_spec_flat, log_power_feature], dim=1)
+
+             features.append(self.feature_extractor_layers[i](feature_input))
+
+         return torch.stack(features, 1)
+
+     def forward(self, input_spec: torch.Tensor) -> torch.Tensor:
+         # input_spec: [B, n_freqs, T, 2] real/imag STFT view.
+         B, n_freq, T, _ = input_spec.shape
+
+         features = self.feature_extraction(input_spec)
+
+         residual_features = features
+         processed = features
+         for layer in self.net:
+             processed = layer(processed)
+         processed = processed + residual_features
+
+         est_spec_bands = []
+         for i in range(self.nband):
+             band_output = self.output_layers[i](processed[:, i])
+             bw = self.band_width[i]
+             est_spec_band = band_output.view(B, bw, 2, T).permute(0, 1, 3, 2)
+             est_spec_bands.append(est_spec_band)
+         est_spec_full = torch.cat(est_spec_bands, dim=1)
+
+         return est_spec_full
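As a quick sanity check of the interface: `Renaissance` maps a real/imag STFT view of shape `[B, 2049, T, 2]` to an enhanced spectrogram of the same shape, internally splitting the 2049 bins into 80 mel-spaced bands. A minimal sketch, assuming `model.py` is importable:

```python
# Shape-contract check for Renaissance (a sketch; random weights, CPU, float32).
import torch
from model import Renaissance

model = Renaissance().eval()
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")  # README quotes 10.4M
print(len(model.band_width), "mel bands covering", sum(model.band_width), "bins")

spec = torch.randn(1, 2049, 10, 2)  # [batch, freq bins, frames, real/imag]
with torch.no_grad():
    out = model(spec)
print(out.shape)  # expected: torch.Size([1, 2049, 10, 2])
```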
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch==2.7.1
+ torchaudio
+ argparse
+ numpy==2.2.6
smule-renaissance-small.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd73b487ed058fdea66a25915f9f5a50b3d2dd97c03480f268a1d50a00fe2b06
+ size 42064415
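The pointer above records the SHA-256 of the actual weights, so a downloaded checkpoint can be verified locally. A small sketch, assuming the file sits in the working directory:

```python
# Verify the downloaded checkpoint against the LFS pointer above (a sketch).
import hashlib

with open("smule-renaissance-small.pt", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
print(digest == "dd73b487ed058fdea66a25915f9f5a50b3d2dd97c03480f268a1d50a00fe2b06")
```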
spectral_ops.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+
+ class STFT:
+     def __init__(self, n_fft, hop_length, win_length):
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.window = torch.hann_window(win_length)
+
+     def __call__(self, y):
+         # Keep the analysis window on the same device as the input
+         self.window = self.window.to(y.device)
+         stft_matrix = torch.stft(
+             y,
+             n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length,
+             window=self.window, return_complex=True, center=True, pad_mode='reflect'
+         )
+         # Return a real view of shape [..., freq, time, 2] (real/imag)
+         return torch.view_as_real(stft_matrix)
+
+ class iSTFT:
+     def __init__(self, n_fft, hop_length, win_length):
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.window = torch.hann_window(win_length)
+
+     def __call__(self, X):
+         self.window = self.window.to(X.device)
+         # Convert the real/imag view back to a complex tensor for istft
+         X = torch.view_as_complex(X.contiguous())
+         return torch.istft(
+             X,
+             n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length,
+             window=self.window, center=True
+         )
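With a Hann window and 50% overlap (hop = n_fft / 2) these settings satisfy the constant-overlap-add condition, so the `STFT`/`iSTFT` pair should round-trip audio nearly losslessly. A minimal check, a sketch assuming `spectral_ops.py` is importable:

```python
# STFT -> iSTFT round-trip check (a sketch).
import torch
from spectral_ops import STFT, iSTFT

stft = STFT(n_fft=4096, hop_length=2048, win_length=4096)
istft = iSTFT(n_fft=4096, hop_length=2048, win_length=4096)

x = torch.randn(1, 48000)   # one second of audio at 48 kHz
spec = stft(x)              # [1, 2049, frames, 2] real/imag view
y = istft(spec)             # reconstructed waveform, slightly shorter than x

err = (x[..., : y.shape[-1]] - y).abs().max()
print(spec.shape, y.shape, float(err))  # error should be near float32 precision
```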