Humair332 committed (verified)
Commit f8f2ee2 · Parent(s): 7140878

Update app.py

Files changed (1): app.py +645 -215
app.py CHANGED
@@ -1,291 +1,724 @@
 import gradio as gr
-import torch
 import numpy as np
 import soundfile as sf
 from scipy.signal import resample as scipy_resample
-from dataclasses import dataclass, field
 from huggingface_hub import hf_hub_download
-import time
-import json
 
-# =============================
-# DACVAE WRAPPER
-# =============================
 
-@dataclass
-class SimpleDACCodec:
-    model: torch.nn.Module
-    sample_rate: int
-    hop_size: int  # encoder stride in samples — probed at load time
-    device: torch.device
-
-    @classmethod
-    def load(cls, repo_id="Aratako/Semantic-DACVAE-Japanese-32dim", device="cpu"):
-        from dacvae import DACVAE
-        weights_path = hf_hub_download(repo_id=repo_id, filename="weights.pth")
-        model = DACVAE.load(weights_path).eval().to(device)
-        sr = int(model.sample_rate)
-
-        # ── Probe the real hop size ──────────────────────────────────────
-        # We feed a known-length signal and measure how many frames come out.
-        # This is the only correct way — no magic constants needed.
-        # hop = input_samples / output_frames (for a signal long enough to
-        # avoid edge effects we use 1 second = sr samples)
-        probe_len = sr  # exactly 1 second of silence
-        dummy = torch.zeros(1, 1, probe_len, device=device,
-                            dtype=next(model.parameters()).dtype)
-        with torch.inference_mode():
-            z = model.encode(dummy)  # (1, D, T_latent)
-        t_latent = z.shape[2]
-        hop = probe_len // t_latent  # integer hop in samples
-
-        print(f"[codec] sample_rate={sr} probe_frames={t_latent} "
-              f"hop={hop} frame_rate={sr/hop:.4f} Hz", flush=True)
-
-        return cls(
-            model=model,
-            sample_rate=sr,
-            hop_size=hop,
-            device=torch.device(device),
         )
 
     @property
-    def frame_rate(self) -> float:
-        """Latent frames per second."""
-        return self.sample_rate / self.hop_size
 
-    def frames_to_seconds(self, num_frames: int) -> float:
-        """Convert latent frame count -> audio duration in seconds."""
-        return num_frames * self.hop_size / self.sample_rate
 
-    @torch.inference_mode()
-    def encode(self, audio: torch.Tensor) -> torch.Tensor:
-        """audio: (1, 1, T) -> latent: (1, T_latent, D)"""
-        z = self.model.encode(audio)  # (B, D, T)
-        return z.transpose(1, 2)      # (B, T, D)
 
-    @torch.inference_mode()
-    def decode(self, latent: torch.Tensor) -> torch.Tensor:
-        """latent: (B, T_latent, D) -> audio: (B, 1, T)"""
-        return self.model.decode(latent.transpose(1, 2))
 
 
-# =============================
-# INIT
-# =============================
 
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"[init] Using device: {DEVICE}")
-codec = SimpleDACCodec.load(device=DEVICE)
-print(f"[init] Codec ready. Frame rate = {codec.frame_rate:.4f} Hz "
-      f"(hop={codec.hop_size}, sr={codec.sample_rate})")
 
 
-# =============================
-# AUDIO UTILS
-# =============================
 
-def load_audio(path: str) -> tuple[np.ndarray, int]:
     audio, sr = sf.read(path, dtype="float32")
     if audio.ndim > 1:
-        audio = np.mean(audio, axis=1)
-    return audio, sr
 
 
 def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
     if orig_sr == target_sr:
         return audio
-    num_samples = int(len(audio) * target_sr / orig_sr)
-    return scipy_resample(audio, num_samples)
 
 
-def to_tensor(audio: np.ndarray) -> torch.Tensor:
-    return torch.from_numpy(audio).unsqueeze(0).unsqueeze(0)  # (1, 1, T)
 
 
-def format_stats(stats: dict) -> str:
-    """Render stats dict as a clean markdown table for display."""
-    lines = ["| Property | Value |", "|---|---|"]
-    for k, v in stats.items():
-        lines.append(f"| {k} | `{v}` |")
-    return "\n".join(lines)
 
 
-# =============================
-# ENCODE
-# =============================
 
-def encode_audio(file):
-    if file is None:
-        return None, None, "⚠️ Please upload an audio file first."
 
-    t0 = time.perf_counter()
 
-    # Load + resample
-    audio_orig, sr_orig = load_audio(file)
-    orig_samples = len(audio_orig)
-    orig_duration = orig_samples / sr_orig
 
-    audio_resampled = resample_audio(audio_orig, sr_orig, codec.sample_rate)
-    resampled_samples = len(audio_resampled)
 
-    wav = to_tensor(audio_resampled).to(DEVICE)
 
-    # Encode
-    latent = codec.encode(wav)  # (1, T_latent, D)
-    t_enc = time.perf_counter() - t0
 
-    num_frames = latent.shape[1]
-    latent_dim = latent.shape[2]
-    calc_dur = codec.frames_to_seconds(num_frames)
 
-    latent_np = latent.squeeze(0).detach().cpu().numpy()  # (T, D)
-    latent_list = latent_np.tolist()
 
-    # Stats
     stats = {
-        "📏 Original sample rate": f"{sr_orig} Hz",
-        "🎵 Codec sample rate": f"{codec.sample_rate} Hz",
-        "⏱ Original duration": f"{orig_duration:.4f} s ({orig_samples:,} samples)",
-        "⏱ Resampled duration": f"{resampled_samples / codec.sample_rate:.4f} s ({resampled_samples:,} samples)",
-        "🔢 Latent frames (T)": f"{num_frames}",
-        "📐 Latent dim (D)": f"{latent_dim}",
-        "📏 Encoder hop size": f"{codec.hop_size} samples",
-        "🔄 Latent frame rate": f"{codec.frame_rate:.4f} Hz",
-        "⏳ Duration from latent": f"{calc_dur:.4f} s (T × hop / sr = {num_frames} × {codec.hop_size} / {codec.sample_rate})",
-        "✅ Duration match": f"{'✓ exact' if abs(calc_dur - resampled_samples / codec.sample_rate) < 0.05 else '⚠ mismatch'}",
-        "⚡ Encode time": f"{t_enc*1000:.1f} ms",
-        "💾 Latent tensor size": f"{latent_np.nbytes / 1024:.1f} KB (float32)",
-        "📊 Latent value range": f"[{latent_np.min():.4f}, {latent_np.max():.4f}]",
-        "📊 Latent mean / std": f"{latent_np.mean():.4f} / {latent_np.std():.4f}",
     }
 
-    stats_md = format_stats(stats)
-    return latent_list, latent_list, stats_md
 
 
-# =============================
-# DECODE
-# =============================
 
-def decode_audio(latent_list, stats_md_current):
     if latent_list is None:
-        return None, (stats_md_current or "") + "\n\n⚠️ No latent found. Encode first."
-
-    t0 = time.perf_counter()
 
     try:
-        latent = torch.tensor(latent_list, dtype=torch.float32, device=DEVICE)
     except Exception as e:
-        return None, f"⚠️ Invalid latent: {e}"
-
-    if latent.ndim == 2:
-        latent = latent.unsqueeze(0)  # (1, T, D)
-
-    audio = codec.decode(latent)  # (B, 1, T_out)
-    t_dec = time.perf_counter() - t0
-
-    audio_np = audio.squeeze().detach().cpu().numpy()
-    audio_np = np.nan_to_num(audio_np)
-    audio_np = np.clip(audio_np, -1.0, 1.0)
-
-    num_frames = latent.shape[1]
-    out_samples = len(audio_np)
-    actual_dur = out_samples / codec.sample_rate
-    calc_dur = codec.frames_to_seconds(num_frames)
-    actual_hop = out_samples // num_frames
-
-    decode_stats = {
-        "🔢 Latent frames decoded": f"{num_frames}",
-        "🔊 Output samples": f"{out_samples:,}",
-        "⏱ Reconstructed duration": f"{actual_dur:.4f} s",
-        "⏳ Duration from latent": f"{calc_dur:.4f} s",
-        "🔁 Actual output hop": f"{actual_hop} samples/frame (expected {codec.hop_size})",
-        "✅ Formula confirmation": f"T={num_frames} × hop={actual_hop} / sr={codec.sample_rate} = {num_frames * actual_hop / codec.sample_rate:.4f} s",
-        "⚡ Decode time": f"{t_dec*1000:.1f} ms",
-        "📊 Output value range": f"[{audio_np.min():.4f}, {audio_np.max():.4f}]",
-    }
 
-    decode_md = format_stats(decode_stats)
-    combined = (stats_md_current or "") + "\n\n### Decode Stats\n" + decode_md
 
-    return (codec.sample_rate, audio_np), combined
 
 
-# =============================
 # UI
-# =============================
 
-css = """
 body, .gradio-container {
     background: #0d0d0d !important;
-    font-family: 'IBM Plex Mono', monospace !important;
-    color: #e0e0e0 !important;
 }
-h1, h2, h3 { color: #00e5a0 !important; letter-spacing: 0.08em; }
 .gr-button {
     background: #00e5a0 !important;
     color: #000 !important;
     font-weight: 700 !important;
-    border-radius: 2px !important;
     border: none !important;
-    font-family: 'IBM Plex Mono', monospace !important;
-    letter-spacing: 0.05em;
 }
-.gr-button:hover { background: #00ffa8 !important; }
 .gr-box, .gr-panel { background: #151515 !important; border: 1px solid #2a2a2a !important; }
-table { width: 100%; border-collapse: collapse; font-size: 0.82em; }
-th { color: #00e5a0; border-bottom: 1px solid #2a2a2a; padding: 4px 8px; text-align: left; }
-td { padding: 4px 8px; border-bottom: 1px solid #1a1a1a; }
-td code { background: #1e1e1e; padding: 2px 6px; border-radius: 2px; color: #a8ff78; }
 """
 
-with gr.Blocks(css=css, title="DACVAE Inspector") as demo:
-
-    gr.HTML("""
-    <link href="https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;700&display=swap" rel="stylesheet">
-    <div style="padding: 24px 0 8px 0;">
-      <h1 style="font-size:1.6em; margin:0; letter-spacing:0.12em;">
-        ◈ DACVAE CODEC INSPECTOR
-      </h1>
-      <p style="color:#666; margin:4px 0 0 0; font-size:0.78em; letter-spacing:0.06em;">
-        Aratako/Semantic-DACVAE-Japanese-32dim &nbsp;·&nbsp;
-        sr={sr} Hz &nbsp;·&nbsp; hop={hop} &nbsp;·&nbsp; frame_rate={fr:.4f} Hz
-      </p>
-    </div>
-    """.format(sr=codec.sample_rate, hop=codec.hop_size, fr=codec.frame_rate))
 
     latent_state = gr.State()
 
     with gr.Row():
-        # ── Left column ───────────────────────────────
-        with gr.Column(scale=1):
-            audio_in = gr.Audio(type="filepath", label="Input Audio")
-            with gr.Row():
-                encode_btn = gr.Button("▶ ENCODE", variant="primary")
-                decode_btn = gr.Button("◀ DECODE", variant="primary")
-            audio_out = gr.Audio(label="Reconstructed Audio", interactive=False)
-
-        # ── Right column ──────────────────────────────
-        with gr.Column(scale=1):
-            stats_out = gr.Markdown(
-                value="*Stats will appear here after encoding.*",
-                label="Stats"
-            )
 
-            with gr.Accordion("Raw Latent JSON (first 3 frames)", open=False):
-                latent_preview = gr.JSON(label="Latent preview")
 
-    # ── Wire up ───────────────────────────────────────
-    def encode_and_preview(file):
-        latent_list, _, stats_md = encode_audio(file)
-        if latent_list is None:
-            return None, None, stats_md
-        preview = latent_list[:3] if latent_list else []
-        return latent_list, preview, stats_md
 
     encode_btn.click(
-        fn=encode_and_preview,
         inputs=audio_in,
         outputs=[latent_state, latent_preview, stats_out],
     )
@@ -296,9 +729,6 @@ with gr.Blocks(css=css, title="DACVAE Inspector") as demo:
         outputs=[audio_out, stats_out],
     )
 
-    # =============================
-    # RUN
-    # =============================
 
 if __name__ == "__main__":
-    demo.launch(share=True)
+import math
+import os
+import tempfile
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
 import gradio as gr
 import numpy as np
 import soundfile as sf
+import torch
+import torch.nn.functional as F
+from pydantic import BaseModel
 from scipy.signal import resample as scipy_resample
+from torch import nn
+from torch.nn.utils import weight_norm
 from huggingface_hub import hf_hub_download
 
 
+# =========================================================
+# AudioVAE model definition (single-file, standalone)
+# =========================================================
+
+def WNConv1d(*args, **kwargs):
+    return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
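+# The two classes below make nn.Conv1d / nn.ConvTranspose1d causal:
+# CausalConv1d left-pads its input with (2 * padding - output_padding)
+# zeros so each output sample depends only on current and past inputs, and
+# CausalTransposeConv1d trims that same amount from the right of its
+# output (the trim is assumed to be positive).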
+class CausalConv1d(nn.Conv1d):
+    def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.__padding = padding
+        self.__output_padding = output_padding
+
+    def forward(self, x):
+        x_pad = F.pad(x, (self.__padding * 2 - self.__output_padding, 0))
+        return super().forward(x_pad)
+
+
+class CausalTransposeConv1d(nn.ConvTranspose1d):
+    def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.__padding = padding
+        self.__output_padding = output_padding
+
+    def forward(self, x):
+        return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
+
+
+def WNCausalConv1d(*args, **kwargs):
+    return weight_norm(CausalConv1d(*args, **kwargs))
+
+
+def WNCausalTransposeConv1d(*args, **kwargs):
+    return weight_norm(CausalTransposeConv1d(*args, **kwargs))
+
+
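+# Snake activation, x + sin^2(alpha * x) / alpha, with alpha learned per
+# channel; the periodic term biases the network toward oscillatory,
+# audio-like signals, and 1e-9 guards against division by zero.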
+@torch.jit.script
+def snake(x, alpha):
+    shape = x.shape
+    x = x.reshape(shape[0], shape[1], -1)
+    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+    x = x.reshape(shape)
+    return x
+
+
+class Snake1d(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+    def forward(self, x):
+        return snake(x, self.alpha)
+
+
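+# Residual unit: Snake -> dilated causal conv -> Snake -> 1x1 conv, plus a
+# skip connection. Dilations of 1, 3, and 9 are stacked in the blocks below
+# to grow the receptive field cheaply.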
+class CausalResidualUnit(nn.Module):
+    def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
+        super().__init__()
+        pad = ((kernel - 1) * dilation) // 2
+        self.block = nn.Sequential(
+            Snake1d(dim),
+            WNCausalConv1d(
+                dim,
+                dim,
+                kernel_size=kernel,
+                dilation=dilation,
+                padding=pad,
+                groups=groups,
+            ),
+            Snake1d(dim),
+            WNCausalConv1d(dim, dim, kernel_size=1),
+        )
+
+    def forward(self, x):
+        y = self.block(x)
+        pad = (x.shape[-1] - y.shape[-1]) // 2
+        assert pad == 0
+        if pad > 0:
+            x = x[..., pad:-pad]
+        return x + y
+
+
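+# Encoder block: three residual units, then a causal conv with kernel
+# 2 * stride that downsamples by `stride` while doubling the channel count.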
+class CausalEncoderBlock(nn.Module):
+    def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
+        super().__init__()
+        input_dim = input_dim or output_dim // 2
+        self.block = nn.Sequential(
+            CausalResidualUnit(input_dim, dilation=1, groups=groups),
+            CausalResidualUnit(input_dim, dilation=3, groups=groups),
+            CausalResidualUnit(input_dim, dilation=9, groups=groups),
+            Snake1d(input_dim),
+            WNCausalConv1d(
+                input_dim,
+                output_dim,
+                kernel_size=2 * stride,
+                stride=stride,
+                padding=math.ceil(stride / 2),
+                output_padding=stride % 2,
+            ),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
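+# Full encoder: an input conv, one downsampling block per stride, then two
+# causal conv heads that emit per-frame Gaussian parameters (mu, logvar);
+# inference uses only mu (see AudioVAE.encode below).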
+class CausalEncoder(nn.Module):
+    def __init__(
+        self,
+        d_model: int = 64,
+        latent_dim: int = 32,
+        strides: list = [2, 4, 8, 8],
+        depthwise: bool = False,
+    ):
+        super().__init__()
+        self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
+
+        for stride in strides:
+            d_model *= 2
+            groups = d_model // 2 if depthwise else 1
+            self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
+
+        self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
+        self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
+
+        self.block = nn.Sequential(*self.block)
+        self.enc_dim = d_model
+
+    def forward(self, x):
+        hidden_state = self.block(x)
+        return {
+            "hidden_state": hidden_state,
+            "mu": self.fc_mu(hidden_state),
+            "logvar": self.fc_logvar(hidden_state),
+        }
+
+
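+# NoiseBlock adds white noise scaled per channel and per time step by a
+# learned 1x1 projection of the activations; the decoder uses it only when
+# use_noise_block is enabled.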
+class NoiseBlock(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
+
+    def forward(self, x):
+        B, C, T = x.shape
+        noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
+        h = self.linear(x)
+        n = noise * h
+        return x + n
+
+
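+# Decoder block: mirror of the encoder block. A transposed causal conv
+# upsamples by `stride` and halves the channels, optionally followed by a
+# NoiseBlock and three dilated residual units.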
+class CausalDecoderBlock(nn.Module):
+    def __init__(
+        self,
+        input_dim: int = 16,
+        output_dim: int = 8,
+        stride: int = 1,
+        groups=1,
+        use_noise_block: bool = False,
+    ):
+        super().__init__()
+        layers = [
+            Snake1d(input_dim),
+            WNCausalTransposeConv1d(
+                input_dim,
+                output_dim,
+                kernel_size=2 * stride,
+                stride=stride,
+                padding=math.ceil(stride / 2),
+                output_padding=stride % 2,
+            ),
+        ]
+        if use_noise_block:
+            layers.append(NoiseBlock(output_dim))
+        layers.extend(
+            [
+                CausalResidualUnit(output_dim, dilation=1, groups=groups),
+                CausalResidualUnit(output_dim, dilation=3, groups=groups),
+                CausalResidualUnit(output_dim, dilation=9, groups=groups),
+            ]
+        )
+        self.block = nn.Sequential(*layers)
+        self.input_channels = input_dim
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class TransposeLastTwoDim(torch.nn.Module):
+    def forward(self, x):
+        return torch.transpose(x, -1, -2)
+
+
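+# SampleRateConditionLayer conditions decoder activations on a bucketed
+# target sample rate: "scale_bias" / "scale_bias_init" apply learned
+# per-bucket scale and bias embeddings, "add" adds an embedding, and
+# "concat" appends embedding channels and projects back to the input width.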
+class SampleRateConditionLayer(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        sr_bin_buckets: int = None,
+        cond_type: str = "scale_bias",
+        cond_dim: int = 128,
+        out_layer: bool = False,
+    ):
+        super().__init__()
+
+        self.cond_type, out_layer_in_dim = cond_type, input_dim
+
+        if cond_type == "scale_bias":
+            self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
+            self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
+            nn.init.ones_(self.scale_embed.weight)
+            nn.init.zeros_(self.bias_embed.weight)
+        elif cond_type == "scale_bias_init":
+            self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
+            self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
+            nn.init.normal_(self.scale_embed.weight, mean=1)
+            nn.init.normal_(self.bias_embed.weight)
+        elif cond_type == "add":
+            self.cond_embed = nn.Embedding(sr_bin_buckets, input_dim)
+            nn.init.normal_(self.cond_embed.weight)
+        elif cond_type == "concat":
+            self.cond_embed = nn.Embedding(sr_bin_buckets, cond_dim)
+            assert out_layer, "out_layer must be True for concat cond_type"
+            out_layer_in_dim = input_dim + cond_dim
+        else:
+            raise ValueError(f"Invalid cond_type: {cond_type}")
+
+        if out_layer:
+            self.out_layer = nn.Sequential(
+                Snake1d(out_layer_in_dim),
+                WNCausalConv1d(out_layer_in_dim, input_dim, kernel_size=1),
+            )
+        else:
+            self.out_layer = nn.Identity()
+
+    def forward(self, x, sr_cond):
+        if self.cond_type in ("scale_bias", "scale_bias_init"):
+            x = x * self.scale_embed(sr_cond).unsqueeze(-1) + self.bias_embed(sr_cond).unsqueeze(-1)
+        elif self.cond_type == "add":
+            x = x + self.cond_embed(sr_cond).unsqueeze(-1)
+        elif self.cond_type == "concat":
+            x = torch.cat([x, self.cond_embed(sr_cond).unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
+
+        return self.out_layer(x)
+
+
+class CausalDecoder(nn.Module):
+    def __init__(
+        self,
+        input_channel,
+        channels,
+        rates,
+        depthwise: bool = False,
+        d_out: int = 1,
+        use_noise_block: bool = False,
+        sr_bin_boundaries: List[int] = None,
+        cond_type: str = "scale_bias",
+        cond_dim: int = 128,
+        cond_out_layer: bool = False,
+    ):
+        super().__init__()
+
+        if depthwise:
+            layers = [
+                WNCausalConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
+                WNCausalConv1d(input_channel, channels, kernel_size=1),
+            ]
+        else:
+            layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
+
+        for i, stride in enumerate(rates):
+            input_dim = channels // 2**i
+            output_dim = channels // 2 ** (i + 1)
+            groups = output_dim if depthwise else 1
+            layers += [
+                CausalDecoderBlock(
+                    input_dim,
+                    output_dim,
+                    stride,
+                    groups=groups,
+                    use_noise_block=use_noise_block,
+                )
+            ]
+
+        layers += [
+            Snake1d(output_dim),
+            WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
+            nn.Tanh(),
+        ]
+
+        if sr_bin_boundaries is None:
+            self.model = nn.Sequential(*layers)
+            self.sr_bin_boundaries = None
+        else:
+            self.model = nn.ModuleList(layers)
+            self.register_buffer("sr_bin_boundaries", torch.tensor(sr_bin_boundaries, dtype=torch.int32))
+            self.sr_bin_buckets = len(sr_bin_boundaries) + 1
+
+            cond_layers = []
+            for layer in self.model:
+                if layer.__class__.__name__ == "CausalDecoderBlock":
+                    cond_layers.append(
+                        SampleRateConditionLayer(
+                            input_dim=layer.input_channels,
+                            sr_bin_buckets=self.sr_bin_buckets,
+                            cond_type=cond_type,
+                            cond_dim=cond_dim,
+                            out_layer=cond_out_layer,
+                        )
+                    )
+                else:
+                    cond_layers.append(None)
+            self.sr_cond_model = nn.ModuleList(cond_layers)
+
+    def get_sr_idx(self, sr):
+        return torch.bucketize(sr, self.sr_bin_boundaries)
+
+    def forward(self, x, sr_cond=None):
+        if self.sr_bin_boundaries is not None:
+            sr_cond = self.get_sr_idx(sr_cond)
+            for layer, sr_cond_layer in zip(self.model, self.sr_cond_model):
+                if sr_cond_layer is not None:
+                    x = sr_cond_layer(x, sr_cond)
+                x = layer(x)
+            return x
+        return self.model(x)
+
+
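+# CausalDecoder: input conv, one upsampling block per rate, then Snake, an
+# output conv, and tanh. When sr_bin_boundaries is set, a condition layer
+# runs before every decoder block so one decoder can target several output
+# sample rates.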
350
+ class AudioVAEConfig(BaseModel):
351
+ encoder_dim: int = 128
352
+ encoder_rates: List[int] = [2, 5, 8, 8]
353
+ latent_dim: int = 64
354
+ decoder_dim: int = 2048
355
+ decoder_rates: List[int] = [8, 6, 5, 2, 2, 2]
356
+ depthwise: bool = True
357
+ sample_rate: int = 16000
358
+ out_sample_rate: int = 48000
359
+ use_noise_block: bool = False
360
+ sr_bin_boundaries: Optional[List[int]] = [20000, 30000, 40000]
361
+ cond_type: str = "scale_bias"
362
+ cond_dim: int = 128
363
+ cond_out_layer: bool = False
364
+
365
+
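+# With these defaults the encoder hop is prod([2, 5, 8, 8]) = 640 samples
+# at 16 kHz (25 latent frames per second), while the decoder upsamples each
+# frame by prod([8, 6, 5, 2, 2, 2]) = 1920 samples, so reconstruction comes
+# out at 25 * 1920 = 48000 Hz (out_sample_rate), 3x the encoder rate.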
+class AudioVAE(nn.Module):
+    def __init__(self, config: AudioVAEConfig = None):
+        if config is None:
+            config = AudioVAEConfig()
+
+        super().__init__()
+
+        self.encoder_dim = config.encoder_dim
+        self.encoder_rates = config.encoder_rates
+        self.decoder_dim = config.decoder_dim
+        self.decoder_rates = config.decoder_rates
+        self.depthwise = config.depthwise
+        self.use_noise_block = config.use_noise_block
+
+        latent_dim = config.latent_dim
+        if latent_dim is None:
+            latent_dim = config.encoder_dim * (2 ** len(config.encoder_rates))
+
+        self.latent_dim = latent_dim
+        self.hop_length = int(np.prod(config.encoder_rates))
+
+        self.encoder = CausalEncoder(
+            config.encoder_dim,
+            latent_dim,
+            config.encoder_rates,
+            depthwise=config.depthwise,
+        )
+
+        self.decoder = CausalDecoder(
+            latent_dim,
+            config.decoder_dim,
+            config.decoder_rates,
+            depthwise=config.depthwise,
+            use_noise_block=config.use_noise_block,
+            sr_bin_boundaries=config.sr_bin_boundaries,
+            cond_type=config.cond_type,
+            cond_dim=config.cond_dim,
+            cond_out_layer=config.cond_out_layer,
         )
 
+        self.sample_rate = config.sample_rate
+        self.out_sample_rate = config.out_sample_rate
+        self.sr_bin_boundaries = config.sr_bin_boundaries
+        self.chunk_size = math.prod(config.encoder_rates)
+        self.decode_chunk_size = math.prod(config.decoder_rates)
+
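+    # preprocess() right-pads the waveform to a whole number of encoder
+    # hops so the latent frame count is integral.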
+    def preprocess(self, audio_data, sample_rate):
+        if sample_rate is None:
+            sample_rate = self.sample_rate
+        assert sample_rate == self.sample_rate
+        pad_to = self.hop_length
+        length = audio_data.shape[-1]
+        right_pad = math.ceil(length / pad_to) * pad_to - length
+        audio_data = nn.functional.pad(audio_data, (0, right_pad))
+        return audio_data
+
+    def decode(self, z: torch.Tensor, sr_cond: torch.Tensor = None):
+        if self.sr_bin_boundaries is not None and sr_cond is None:
+            sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
+        return self.decoder(z, sr_cond)
+
+    def streaming_decode(self):
+        return StreamingVAEDecoder(self)
+
+    def encode(self, audio_data: torch.Tensor, sample_rate: int):
+        if audio_data.ndim == 2:
+            audio_data = audio_data.unsqueeze(1)
+        audio_data = self.preprocess(audio_data, sample_rate)
+        return self.encoder(audio_data)["mu"]
+
+
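+# StreamingVAEDecoder patches the decoder's causal conv layers so each one
+# caches its trailing input samples between calls; prepending the cache to
+# the next chunk (instead of zero padding) makes chunk-by-chunk decoding
+# match decoding the whole latent sequence at once.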
+class StreamingVAEDecoder:
+    def __init__(self, vae: AudioVAE):
+        self._vae = vae
+        self._states: dict = {}
+        self._originals: list = []
+
+    def __enter__(self):
+        self._states.clear()
+        self._install()
+        return self
+
+    def __exit__(self, *exc):
+        self._restore()
+        self._states.clear()
+
+    def decode_chunk(self, z_chunk: torch.Tensor) -> torch.Tensor:
+        return self._vae.decode(z_chunk)
+
+    def _install(self):
+        for _, mod in self._vae.decoder.named_modules():
+            if isinstance(mod, CausalConv1d):
+                pad = mod._CausalConv1d__padding * 2 - mod._CausalConv1d__output_padding
+                if pad > 0:
+                    self._patch_causal_conv(mod, pad)
+            elif isinstance(mod, CausalTransposeConv1d):
+                trim = mod._CausalTransposeConv1d__padding * 2 - mod._CausalTransposeConv1d__output_padding
+                ctx = (mod.kernel_size[0] - 1) // mod.stride[0]
+                if ctx > 0:
+                    self._patch_transpose_conv(mod, ctx, trim)
+
+    def _patch_causal_conv(self, mod, pad_size):
+        states = self._states
+        key = id(mod)
+        orig = mod.forward
+
+        def fwd(x, _k=key, _p=pad_size, _m=mod):
+            x_pad = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_p, 0))
+            if x.shape[-1] >= _p:
+                states[_k] = x[:, :, -_p:].detach()
+            else:
+                prev = states.get(_k, torch.zeros(x.shape[0], x.shape[1], _p, device=x.device, dtype=x.dtype))
+                states[_k] = torch.cat([prev, x], dim=-1)[:, :, -_p:].detach()
+            return nn.Conv1d.forward(_m, x_pad)
+
+        mod.forward = fwd
+        self._originals.append((mod, orig))
+
+    def _patch_transpose_conv(self, mod, ctx, trim):
+        states = self._states
+        key = id(mod)
+        orig = mod.forward
+
+        def fwd(x, _k=key, _c=ctx, _t=trim, _m=mod):
+            x_full = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_c, 0))
+            states[_k] = x[:, :, -_c:].detach()
+            out = nn.ConvTranspose1d.forward(_m, x_full)
+            left = _c * _m.stride[0]
+            return out[..., left:-_t] if _t > 0 else out[..., left:]
+
+        mod.forward = fwd
+        self._originals.append((mod, orig))
+
+    def _restore(self):
+        for mod, orig in self._originals:
+            mod.forward = orig
+        self._originals.clear()
+
+
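+# Hypothetical usage sketch (not exercised by this app): decode a latent
+# sequence z of shape (B, D, T) in 25-frame chunks, i.e. one second of
+# audio per chunk at the default 25 Hz frame rate.
+#
+#   with vae.streaming_decode() as dec:
+#       parts = [dec.decode_chunk(z[..., t:t + 25])
+#                for t in range(0, z.shape[-1], 25)]
+#       wav = torch.cat(parts, dim=-1)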
505
+ # =========================================================
506
+ # Loading utilities
507
+ # =========================================================
508
+
509
+ REPO_ID = os.environ.get("AUDIOVAE_REPO", "openbmb/VoxCPM2")
510
+ WEIGHTS_NAME = os.environ.get("AUDIOVAE_WEIGHTS", "audiovae.pth")
511
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
512
+ TARGET_SR = 16000
513
+
514
+
+@dataclass
+class LoadedCodec:
+    model: AudioVAE
+    device: str
 
     @property
+    def sample_rate(self) -> int:
+        return int(self.model.sample_rate)
+
+    @property
+    def hop_length(self) -> int:
+        return int(self.model.hop_length)
+
+    def encode(self, wav: torch.Tensor) -> torch.Tensor:
+        return self.model.encode(wav, self.sample_rate)
+
+    def decode(self, z: torch.Tensor) -> torch.Tensor:
+        return self.model.decode(z)
+
+
+def _pick_state_dict(obj):
+    if isinstance(obj, dict):
+        for key in ("state_dict", "model", "vae", "audio_vae", "module"):
+            if key in obj and isinstance(obj[key], dict):
+                return obj[key]
+    return obj
+
+
+@torch.inference_mode()
+def load_codec(repo_id: str = REPO_ID, filename: str = WEIGHTS_NAME, device: str = DEVICE) -> LoadedCodec:
+    path = hf_hub_download(repo_id=repo_id, filename=filename)
+    ckpt = torch.load(path, map_location="cpu")
+    state = _pick_state_dict(ckpt)
 
+    model = AudioVAE()
+    missing, unexpected = model.load_state_dict(state, strict=False)
 
+    model.to(device).eval()
+    print(f"[load] repo={repo_id} file={filename} device={device}")
+    if missing:
+        print(f"[load] missing keys: {len(missing)}")
+    if unexpected:
+        print(f"[load] unexpected keys: {len(unexpected)}")
 
+    return LoadedCodec(model=model, device=device)
+
+
+codec = load_codec()
+
+
+# =========================================================
+# Audio helpers
+# =========================================================
+
+def load_audio_file(path: str) -> Tuple[np.ndarray, int]:
     audio, sr = sf.read(path, dtype="float32")
     if audio.ndim > 1:
+        audio = audio.mean(axis=1)
+    return audio.astype(np.float32), int(sr)
+
 
 
 def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
     if orig_sr == target_sr:
         return audio
+    num_samples = int(round(len(audio) * target_sr / orig_sr))
+    return scipy_resample(audio, num_samples).astype(np.float32)
 
 
+def to_tensor(audio: np.ndarray, device: str) -> torch.Tensor:
+    return torch.from_numpy(audio).unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, T)
 
 
+def save_wav_temp(wav: np.ndarray, sr: int) -> str:
+    # Write a waveform to a temporary .wav file and return its path.
+    fd, path = tempfile.mkstemp(suffix=".wav")
+    os.close(fd)
+    sf.write(path, wav.astype(np.float32), sr)
+    return path
 
 
+def fmt_stats(kv: dict) -> str:
+    lines = ["| Property | Value |", "|---|---|"]
+    for k, v in kv.items():
+        lines.append(f"| {k} | `{v}` |")
+    return "\n".join(lines)
 
 
+# =========================================================
+# Encode / Decode
+# =========================================================
 
+def encode_audio(file_path):
+    if file_path is None:
+        return None, None, "Upload an audio file first."
 
+    audio, sr = load_audio_file(file_path)
+    orig_len = len(audio)
+    audio = resample_audio(audio, sr, codec.sample_rate)
+    wav = to_tensor(audio, codec.device)
 
+    with torch.inference_mode():
+        z = codec.encode(wav)  # (B, D, T)
 
+    z_btd = z.transpose(1, 2).contiguous()  # (B, T, D)
+    latent = z_btd.squeeze(0).detach().cpu().numpy()
 
     stats = {
+        "Original SR": f"{sr} Hz",
+        "Model SR": f"{codec.sample_rate} Hz",
+        "Original samples": f"{orig_len:,}",
+        "Resampled samples": f"{len(audio):,}",
+        "Latent shape": str(tuple(latent.shape)),
+        "Latent dim": f"{latent.shape[-1]}",
+        "Frames": f"{latent.shape[0]}",
+        "Hop length": f"{codec.hop_length} samples",
+        "Approx duration": f"{latent.shape[0] * codec.hop_length / codec.sample_rate:.4f} s",
+        "Latent min/max": f"{latent.min():.4f} / {latent.max():.4f}",
+        "Latent mean/std": f"{latent.mean():.4f} / {latent.std():.4f}",
     }
 
+    # Full latent goes to the state; only the first 3 frames go to the
+    # JSON preview so the UI does not render a huge payload.
+    return latent.tolist(), latent[:3].tolist(), fmt_stats(stats)
 
 
+def decode_audio(latent_list, current_stats):
     if latent_list is None:
+        return None, (current_stats or "") + "\n\nNo latent found. Encode first."
 
     try:
+        z = torch.tensor(latent_list, dtype=torch.float32, device=codec.device)
+        if z.ndim == 2:
+            z = z.unsqueeze(0)  # (B, T, D)
+        z = z.transpose(1, 2).contiguous()  # (B, D, T)
     except Exception as e:
+        return None, f"Invalid latent: {e}"
 
+    with torch.inference_mode():
+        audio = codec.decode(z)
 
+    wav = audio.squeeze().detach().cpu().numpy()
+    wav = np.nan_to_num(wav)
+    wav = np.clip(wav, -1.0, 1.0)
 
+    # The decoder reconstructs at the model's out_sample_rate (48 kHz with
+    # the default config), not at the 16 kHz encoder rate, so report and
+    # play the waveform at that rate.
+    out_sr = int(codec.model.out_sample_rate)
+
+    stats = {
+        "Decoded samples": f"{len(wav):,}",
+        "Output SR": f"{out_sr} Hz",
+        "Duration": f"{len(wav) / out_sr:.4f} s",
+        "Wave min/max": f"{wav.min():.4f} / {wav.max():.4f}",
+    }
 
+    merged = (current_stats or "") + "\n\n### Decode Stats\n" + fmt_stats(stats)
+    return (out_sr, wav), merged
+
+
+# =========================================================
 # UI
+# =========================================================
 
+CSS = """
 body, .gradio-container {
     background: #0d0d0d !important;
+    color: #eaeaea !important;
 }
+h1, h2, h3 { color: #00e5a0 !important; }
 .gr-button {
     background: #00e5a0 !important;
     color: #000 !important;
     font-weight: 700 !important;
     border: none !important;
 }
 .gr-box, .gr-panel { background: #151515 !important; border: 1px solid #2a2a2a !important; }
+code { background: #1e1e1e; padding: 2px 6px; border-radius: 2px; }
 """
 
+with gr.Blocks(css=CSS, title="AudioVAE Encode / Decode") as demo:
+    gr.Markdown(
+        f"""
+# AudioVAE Encode / Decode
+Standalone one-file app for `audiovae.pth`.
+
+**Repo:** `{REPO_ID}`
+**Model SR:** `{codec.sample_rate} Hz`
+**Hop length:** `{codec.hop_length}`
+"""
+    )
 
     latent_state = gr.State()
 
     with gr.Row():
+        audio_in = gr.Audio(type="filepath", label="Input Audio")
+        audio_out = gr.Audio(label="Reconstructed Audio", interactive=False)
+
+    with gr.Row():
+        encode_btn = gr.Button("Encode")
+        decode_btn = gr.Button("Decode")
 
+    stats_out = gr.Markdown(value="Upload an audio file and press Encode.")
 
+    with gr.Accordion("Raw latent preview", open=False):
+        latent_preview = gr.JSON(label="Latent JSON")
 
     encode_btn.click(
+        fn=encode_audio,
         inputs=audio_in,
         outputs=[latent_state, latent_preview, stats_out],
     )
         outputs=[audio_out, stats_out],
     )
 
 
 if __name__ == "__main__":
+    demo.launch()