# vaevox / app.py
# Humair332's picture
# Update app.py
# da16507 verified
import math
import os
import tempfile
from dataclasses import dataclass
from typing import List, Optional, Tuple
import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from pydantic import BaseModel
from scipy.signal import resample as scipy_resample
from torch import nn
from torch.nn.utils import weight_norm
from huggingface_hub import hf_hub_download
# =========================================================
# AudioVAE model definition (single-file, standalone)
# =========================================================
def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` from the given arguments and wrap it in weight norm."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
def WNConvTranspose1d(*args, **kwargs):
    """Build an ``nn.ConvTranspose1d`` from the given arguments and wrap it in weight norm."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
class CausalConv1d(nn.Conv1d):
    """``nn.Conv1d`` that zero-pads only the left side of its input.

    ``padding`` and ``output_padding`` are captured by this class and are NOT
    forwarded to ``nn.Conv1d``; the parent therefore runs with its own default
    padding. The number of zeros prepended on the last axis is
    ``2 * padding - output_padding``.
    """

    def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        # Name-mangled attributes; StreamingVAEDecoder reads them back as
        # ``_CausalConv1d__padding`` / ``_CausalConv1d__output_padding``.
        self.__padding = padding
        self.__output_padding = output_padding

    def forward(self, x):
        """Left-pad ``x`` along its last axis, then apply the convolution."""
        left_pad = 2 * self.__padding - self.__output_padding
        padded = F.pad(x, (left_pad, 0))
        return super().forward(padded)
class CausalTransposeConv1d(nn.ConvTranspose1d):
    """``nn.ConvTranspose1d`` whose output is trimmed on the right.

    ``padding`` and ``output_padding`` are captured by this class and are NOT
    forwarded to ``nn.ConvTranspose1d``. After the transposed convolution the
    last ``2 * padding - output_padding`` samples of the output are dropped.
    """

    def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        # Name-mangled attributes; StreamingVAEDecoder reads them back as
        # ``_CausalTransposeConv1d__padding`` / ``..._output_padding``.
        self.__padding = padding
        self.__output_padding = output_padding

    def forward(self, x):
        """Apply the transposed convolution and trim the tail of the result."""
        out = super().forward(x)
        trim = self.__padding * 2 - self.__output_padding
        # A trim of 0 must keep the whole output: ``out[..., :-0]`` is the
        # empty slice ``out[..., :0]``, which silently returned no samples
        # with the default padding values.
        if trim > 0:
            return out[..., :-trim]
        return out
def WNCausalConv1d(*args, **kwargs):
    """Build a ``CausalConv1d`` from the given arguments and wrap it in weight norm."""
    causal_conv = CausalConv1d(*args, **kwargs)
    return weight_norm(causal_conv)
def WNCausalTransposeConv1d(*args, **kwargs):
    """Build a ``CausalTransposeConv1d`` from the given arguments and wrap it in weight norm."""
    causal_deconv = CausalTransposeConv1d(*args, **kwargs)
    return weight_norm(causal_deconv)
@torch.jit.script
def snake(x, alpha):
    """Apply ``x + sin(alpha * x)**2 / (alpha + 1e-9)`` and keep ``x``'s shape.

    ``x`` is flattened to three axes (first, second, everything else) for the
    computation and reshaped back afterwards, so ``alpha`` must broadcast
    against that three-axis view.
    """
    original_shape = x.shape
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    # The 1e-9 keeps the reciprocal finite when alpha is exactly zero.
    inv_alpha = (alpha + 1e-9).reciprocal()
    flat = flat + inv_alpha * torch.sin(alpha * flat).pow(2)
    return flat.reshape(original_shape)
class Snake1d(nn.Module):
    """Snake activation with one learnable ``alpha`` per channel."""

    def __init__(self, channels):
        super().__init__()
        # Shape (1, channels, 1), initialised to ones.
        initial_alpha = torch.ones(1, channels, 1)
        self.alpha = nn.Parameter(initial_alpha)

    def forward(self, x):
        """Run ``snake`` on ``x`` using this module's ``alpha``."""
        return snake(x, self.alpha)
class CausalResidualUnit(nn.Module):
    """Residual unit: Snake -> dilated causal conv -> Snake -> 1x1 causal conv.

    Args:
        dim: Number of input and output channels.
        dilation: Dilation of the first convolution.
        kernel: Kernel size of the first convolution.
        groups: ``groups`` argument of the first convolution.
    """

    def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
        super().__init__()
        # Derive the padding from the actual kernel size. The previous
        # hard-coded ``7`` was only right for the default kernel and produced
        # the wrong padding for any other ``kernel`` value.
        pad = ((kernel - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNCausalConv1d(
                dim,
                dim,
                kernel_size=kernel,
                dilation=dilation,
                padding=pad,
                groups=groups,
            ),
            Snake1d(dim),
            WNCausalConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        """Return ``x + block(x)``."""
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        # The block is expected to preserve the length of the last axis.
        assert pad == 0
        # Only reachable when assertions are disabled (``python -O``): crop
        # ``x`` symmetrically before the residual add.
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y
class CausalEncoderBlock(nn.Module):
    """Three residual units (dilations 1, 3, 9), a Snake, then a strided causal conv.

    Args:
        output_dim: Channels produced by the final strided convolution.
        input_dim: Channels of the input; falls back to ``output_dim // 2``
            when falsy.
        stride: Stride of the final convolution (its kernel size is ``2 * stride``).
        groups: ``groups`` argument passed to every residual unit.
    """

    def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
        super().__init__()
        input_dim = input_dim or output_dim // 2
        stages = [
            CausalResidualUnit(input_dim, dilation=rate, groups=groups)
            for rate in (1, 3, 9)
        ]
        stages.append(Snake1d(input_dim))
        stages.append(
            WNCausalConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            )
        )
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the block."""
        return self.block(x)
class CausalEncoder(nn.Module):
    """Causal convolutional encoder with ``mu`` / ``logvar`` heads.

    Args:
        d_model: Channel width after the first convolution; it is doubled
            before every encoder block.
        latent_dim: Channels produced by the ``fc_mu`` and ``fc_logvar`` heads.
        strides: Stride of each encoder block, in order. Defaults to
            ``[2, 4, 8, 8]`` when omitted.
        depthwise: When true, every block's residual units use
            ``groups = d_model // 2`` (with ``d_model`` already doubled).
    """

    def __init__(
        self,
        d_model: int = 64,
        latent_dim: int = 32,
        strides: Optional[List[int]] = None,
        depthwise: bool = False,
    ):
        super().__init__()
        # Resolve the default here rather than in the signature so that no
        # mutable list object is shared between calls.
        if strides is None:
            strides = [2, 4, 8, 8]
        # Collect the layers in a local list; only the final nn.Sequential is
        # ever stored on the module.
        layers = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
        for stride in strides:
            d_model *= 2
            groups = d_model // 2 if depthwise else 1
            layers.append(CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups))
        # Registration order (fc_mu, fc_logvar, block) matches the original.
        self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
        self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
        self.block = nn.Sequential(*layers)
        self.enc_dim = d_model

    def forward(self, x):
        """Return a dict with the keys ``hidden_state``, ``mu`` and ``logvar``."""
        hidden_state = self.block(x)
        return {
            "hidden_state": hidden_state,
            "mu": self.fc_mu(hidden_state),
            "logvar": self.fc_logvar(hidden_state),
        }
class NoiseBlock(nn.Module):
    """Add Gaussian noise to the input, scaled by a bias-free 1x1 causal conv of it."""

    def __init__(self, dim):
        super().__init__()
        self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        """Return ``x + noise * linear(x)``; ``x`` must have exactly three axes."""
        batch, _channels, frames = x.shape
        # One noise value per (batch, frame), broadcast across channels.
        noise = torch.randn((batch, 1, frames), device=x.device, dtype=x.dtype)
        scaled_noise = noise * self.linear(x)
        return x + scaled_noise
class CausalDecoderBlock(nn.Module):
    """Snake -> causal transposed conv -> optional NoiseBlock -> three residual units.

    Args:
        input_dim: Channels of the input.
        output_dim: Channels produced by the transposed convolution.
        stride: Stride of the transposed convolution (kernel size ``2 * stride``).
        groups: ``groups`` argument passed to every residual unit.
        use_noise_block: Insert a ``NoiseBlock`` right after the transposed conv.
    """

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        groups=1,
        use_noise_block: bool = False,
    ):
        super().__init__()
        stages = [Snake1d(input_dim)]
        stages.append(
            WNCausalTransposeConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            )
        )
        if use_noise_block:
            stages.append(NoiseBlock(output_dim))
        for rate in (1, 3, 9):
            stages.append(CausalResidualUnit(output_dim, dilation=rate, groups=groups))
        self.block = nn.Sequential(*stages)
        # Read by CausalDecoder to size its sample-rate conditioning layers.
        self.input_channels = input_dim

    def forward(self, x):
        """Run ``x`` through the block."""
        return self.block(x)
class TransposeLastTwoDim(torch.nn.Module):
    """Module that swaps the last two axes of its input."""

    def forward(self, x):
        """Return ``x`` with its last two axes exchanged."""
        return x.transpose(-1, -2)
class SampleRateConditionLayer(nn.Module):
    """Condition a feature map on a sample-rate bucket index.

    Args:
        input_dim: Channels of the feature map being conditioned.
        sr_bin_buckets: Number of rows in each embedding table.
        cond_type: One of ``"scale_bias"`` (scale starts at 1, bias at 0),
            ``"scale_bias_init"`` (scale ~ N(1, 1), bias ~ N(0, 1)),
            ``"add"`` or ``"concat"``.
        cond_dim: Embedding width, used by ``"concat"`` only.
        out_layer: Append a Snake + 1x1 causal conv projecting back to
            ``input_dim``. Must be true for ``"concat"``.

    Raises:
        ValueError: If ``cond_type`` is not one of the values above.
    """

    def __init__(
        self,
        input_dim: int,
        sr_bin_buckets: int = None,
        cond_type: str = "scale_bias",
        cond_dim: int = 128,
        out_layer: bool = False,
    ):
        super().__init__()
        self.cond_type = cond_type
        out_layer_in_dim = input_dim
        if cond_type in ("scale_bias", "scale_bias_init"):
            self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
            self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
            if cond_type == "scale_bias":
                # Identity transform at initialisation.
                nn.init.ones_(self.scale_embed.weight)
                nn.init.zeros_(self.bias_embed.weight)
            else:
                nn.init.normal_(self.scale_embed.weight, mean=1)
                nn.init.normal_(self.bias_embed.weight)
        elif cond_type == "add":
            self.cond_embed = nn.Embedding(sr_bin_buckets, input_dim)
            nn.init.normal_(self.cond_embed.weight)
        elif cond_type == "concat":
            self.cond_embed = nn.Embedding(sr_bin_buckets, cond_dim)
            assert out_layer, "out_layer must be True for concat cond_type"
            # Concatenation widens the channel axis; the output layer projects back.
            out_layer_in_dim = input_dim + cond_dim
        else:
            raise ValueError(f"Invalid cond_type: {cond_type}")

        if not out_layer:
            self.out_layer = nn.Identity()
        else:
            self.out_layer = nn.Sequential(
                Snake1d(out_layer_in_dim),
                WNCausalConv1d(out_layer_in_dim, input_dim, kernel_size=1),
            )

    def forward(self, x, sr_cond):
        """Apply the conditioning for the bucket indices ``sr_cond`` to ``x``."""
        mode = self.cond_type
        if mode == "add":
            x = x + self.cond_embed(sr_cond).unsqueeze(-1)
        elif mode == "concat":
            cond = self.cond_embed(sr_cond).unsqueeze(-1)
            # Tile the embedding along the last axis to match x's length.
            tiled = cond.repeat(1, 1, x.shape[-1])
            x = torch.cat([x, tiled], dim=1)
        elif mode in ("scale_bias", "scale_bias_init"):
            scale = self.scale_embed(sr_cond).unsqueeze(-1)
            bias = self.bias_embed(sr_cond).unsqueeze(-1)
            x = x * scale + bias
        return self.out_layer(x)
class CausalDecoder(nn.Module):
    """Causal convolutional decoder with optional sample-rate conditioning.

    Layer stack: an input convolution (a depthwise + pointwise pair when
    ``depthwise`` is true), one ``CausalDecoderBlock`` per entry of ``rates``
    (channels halve at each block), then Snake -> 7-tap causal conv -> Tanh.

    When ``sr_bin_boundaries`` is given, the layers are kept in a ModuleList
    and a ``SampleRateConditionLayer`` is applied in front of every
    ``CausalDecoderBlock``.

    Args:
        input_channel: Channels of the incoming tensor.
        channels: Channel width after the input convolution.
        rates: Stride of each decoder block, in order.
        depthwise: Use grouped convolutions (see above).
        d_out: Output channels of the final convolution.
        use_noise_block: Forwarded to every ``CausalDecoderBlock``.
        sr_bin_boundaries: Boundaries handed to ``torch.bucketize``; ``None``
            disables conditioning.
        cond_type, cond_dim, cond_out_layer: Forwarded to every
            ``SampleRateConditionLayer``.
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        depthwise: bool = False,
        d_out: int = 1,
        use_noise_block: bool = False,
        sr_bin_boundaries: List[int] = None,
        cond_type: str = "scale_bias",
        cond_dim: int = 128,
        cond_out_layer: bool = False,
    ):
        super().__init__()
        if not depthwise:
            stack = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
        else:
            stack = [
                WNCausalConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
                WNCausalConv1d(input_channel, channels, kernel_size=1),
            ]

        for level, stride in enumerate(rates):
            input_dim = channels // 2**level
            output_dim = channels // 2 ** (level + 1)
            stack.append(
                CausalDecoderBlock(
                    input_dim,
                    output_dim,
                    stride,
                    groups=output_dim if depthwise else 1,
                    use_noise_block=use_noise_block,
                )
            )

        # ``output_dim`` still holds the width of the last decoder block.
        stack.extend(
            [
                Snake1d(output_dim),
                WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
                nn.Tanh(),
            ]
        )

        if sr_bin_boundaries is None:
            self.model = nn.Sequential(*stack)
            self.sr_bin_boundaries = None
            return

        self.model = nn.ModuleList(stack)
        self.register_buffer("sr_bin_boundaries", torch.tensor(sr_bin_boundaries, dtype=torch.int32))
        self.sr_bin_buckets = len(sr_bin_boundaries) + 1
        # One entry per layer of ``self.model``: a conditioning layer in front
        # of every CausalDecoderBlock, ``None`` everywhere else.
        self.sr_cond_model = nn.ModuleList(
            [
                SampleRateConditionLayer(
                    input_dim=layer.input_channels,
                    sr_bin_buckets=self.sr_bin_buckets,
                    cond_type=cond_type,
                    cond_dim=cond_dim,
                    out_layer=cond_out_layer,
                )
                if layer.__class__.__name__ == "CausalDecoderBlock"
                else None
                for layer in self.model
            ]
        )

    def get_sr_idx(self, sr):
        """Map sample rates to bucket indices via ``torch.bucketize``."""
        return torch.bucketize(sr, self.sr_bin_boundaries)

    def forward(self, x, sr_cond=None):
        """Decode ``x``; ``sr_cond`` is only used when conditioning is enabled."""
        if self.sr_bin_boundaries is None:
            return self.model(x)
        bucket_idx = self.get_sr_idx(sr_cond)
        for layer, cond_layer in zip(self.model, self.sr_cond_model):
            if cond_layer is not None:
                x = cond_layer(x, bucket_idx)
            x = layer(x)
        return x
# Hyper-parameters read by AudioVAE.__init__ to build the encoder and decoder.
class AudioVAEConfig(BaseModel):
    # First positional argument of CausalEncoder (its ``d_model``).
    encoder_dim: int = 128
    # Strides of the encoder blocks; their product becomes AudioVAE.hop_length.
    encoder_rates: List[int] = [2, 5, 8, 8]
    # Channels of the encoder's mu/logvar heads and of the decoder input.
    latent_dim: int = 64
    # ``channels`` argument of CausalDecoder.
    decoder_dim: int = 2048
    # Strides of the decoder blocks.
    decoder_rates: List[int] = [8, 6, 5, 2, 2, 2]
    # Passed as ``depthwise`` to both CausalEncoder and CausalDecoder.
    depthwise: bool = True
    # Rate that AudioVAE.preprocess asserts the incoming audio to have.
    sample_rate: int = 16000
    # Used by AudioVAE.decode as the default ``sr_cond`` value.
    out_sample_rate: int = 48000
    # Passed as ``use_noise_block`` to CausalDecoder.
    use_noise_block: bool = False
    # Bucket boundaries for the decoder's sample-rate conditioning;
    # None disables the conditioning path in CausalDecoder.
    sr_bin_boundaries: Optional[List[int]] = [20000, 30000, 40000]
    # The next three are forwarded to every SampleRateConditionLayer.
    cond_type: str = "scale_bias"
    cond_dim: int = 128
    cond_out_layer: bool = False
class AudioVAE(nn.Module):
    """Causal audio VAE: a ``CausalEncoder`` paired with a ``CausalDecoder``.

    Args:
        config: Model hyper-parameters; a default ``AudioVAEConfig()`` is used
            when omitted.
    """

    def __init__(self, config: AudioVAEConfig = None):
        if config is None:
            config = AudioVAEConfig()
        super().__init__()

        # Plain copies of the configuration values.
        self.encoder_dim = config.encoder_dim
        self.encoder_rates = config.encoder_rates
        self.decoder_dim = config.decoder_dim
        self.decoder_rates = config.decoder_rates
        self.depthwise = config.depthwise
        self.use_noise_block = config.use_noise_block

        latent_dim = config.latent_dim
        if latent_dim is None:
            latent_dim = config.encoder_dim * (2 ** len(config.encoder_rates))
        self.latent_dim = latent_dim
        # Product of the encoder strides.
        self.hop_length = int(np.prod(config.encoder_rates))

        self.encoder = CausalEncoder(
            config.encoder_dim,
            latent_dim,
            config.encoder_rates,
            depthwise=config.depthwise,
        )
        self.decoder = CausalDecoder(
            latent_dim,
            config.decoder_dim,
            config.decoder_rates,
            depthwise=config.depthwise,
            use_noise_block=config.use_noise_block,
            sr_bin_boundaries=config.sr_bin_boundaries,
            cond_type=config.cond_type,
            cond_dim=config.cond_dim,
            cond_out_layer=config.cond_out_layer,
        )

        self.sample_rate = config.sample_rate
        self.out_sample_rate = config.out_sample_rate
        self.sr_bin_boundaries = config.sr_bin_boundaries
        self.chunk_size = math.prod(config.encoder_rates)
        self.decode_chunk_size = math.prod(config.decoder_rates)

    def preprocess(self, audio_data, sample_rate):
        """Right-pad ``audio_data`` with zeros to a multiple of ``hop_length``.

        ``sample_rate`` defaults to the model's rate when ``None`` and is
        asserted to equal it.
        """
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate
        hop = self.hop_length
        n_samples = audio_data.shape[-1]
        right_pad = math.ceil(n_samples / hop) * hop - n_samples
        return nn.functional.pad(audio_data, (0, right_pad))

    def decode(self, z: torch.Tensor, sr_cond: torch.Tensor = None):
        """Decode latents ``z``.

        When conditioning is enabled and ``sr_cond`` is omitted, a one-element
        int32 tensor holding ``out_sample_rate`` is used instead.
        """
        needs_default_cond = self.sr_bin_boundaries is not None and sr_cond is None
        if needs_default_cond:
            sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
        return self.decoder(z, sr_cond)

    def streaming_decode(self):
        """Return a ``StreamingVAEDecoder`` bound to this model."""
        return StreamingVAEDecoder(self)

    def encode(self, audio_data: torch.Tensor, sample_rate: int):
        """Return the encoder's ``mu`` output for ``audio_data``.

        A two-axis input gets a channel axis inserted at position 1 first.
        """
        if audio_data.ndim == 2:
            audio_data = audio_data.unsqueeze(1)
        padded = self.preprocess(audio_data, sample_rate)
        return self.encoder(padded)["mu"]
class StreamingVAEDecoder:
    """Context manager for chunk-by-chunk decoding with an ``AudioVAE``.

    On entry, the ``forward`` of every ``CausalConv1d`` and
    ``CausalTransposeConv1d`` inside ``vae.decoder`` is replaced by a variant
    that prepends the tail of the previous chunk's input in place of zero
    padding. On exit the original ``forward`` callables are put back and the
    cached state is dropped.

    Usage::

        with vae.streaming_decode() as dec:
            audio_chunk = dec.decode_chunk(z_chunk)
    """

    def __init__(self, vae: AudioVAE):
        self._vae = vae
        # Per-module input tail from the previous chunk, keyed by id(module).
        self._states: dict = {}
        # (module, original forward) pairs, in patching order.
        self._originals: list = []

    def __enter__(self):
        self._states.clear()
        try:
            self._install()
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo any
            # patches that were already applied before re-raising.
            self._restore()
            self._states.clear()
            raise
        return self

    def __exit__(self, *exc):
        self._restore()
        self._states.clear()

    def decode_chunk(self, z_chunk: torch.Tensor) -> torch.Tensor:
        """Decode one latent chunk via ``vae.decode`` with the patches active."""
        return self._vae.decode(z_chunk)

    def _install(self):
        """Patch every causal (transposed) convolution of the decoder."""
        for mod in self._vae.decoder.modules():
            if isinstance(mod, CausalConv1d):
                pad = mod._CausalConv1d__padding * 2 - mod._CausalConv1d__output_padding
                if pad > 0:
                    self._patch_causal_conv(mod, pad)
            elif isinstance(mod, CausalTransposeConv1d):
                trim = mod._CausalTransposeConv1d__padding * 2 - mod._CausalTransposeConv1d__output_padding
                # Number of past input frames kept as context.
                ctx = (mod.kernel_size[0] - 1) // mod.stride[0]
                if ctx > 0:
                    self._patch_transpose_conv(mod, ctx, trim)

    def _patch_causal_conv(self, mod, pad_size):
        """Replace ``mod.forward`` with a variant that carries ``pad_size`` samples."""
        states = self._states
        key = id(mod)
        orig = mod.forward

        def fwd(x, _k=key, _p=pad_size, _m=mod):
            x_pad = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_p, 0))
            # The tail of the padded input is the context for the next chunk.
            # It equals the last ``_p`` samples of ``x`` when ``x`` is long
            # enough, and otherwise mixes in the previous state (or zeros).
            states[_k] = x_pad[:, :, -_p:].detach()
            return nn.Conv1d.forward(_m, x_pad)

        mod.forward = fwd
        self._originals.append((mod, orig))

    def _patch_transpose_conv(self, mod, ctx, trim):
        """Replace ``mod.forward`` with a variant that carries ``ctx`` input frames."""
        states = self._states
        key = id(mod)
        orig = mod.forward

        def fwd(x, _k=key, _c=ctx, _t=trim, _m=mod):
            x_full = torch.cat([states[_k], x], dim=-1) if _k in states else F.pad(x, (_c, 0))
            # Keep exactly ``_c`` frames of context, as the conv patch does.
            # Slicing ``x`` alone stored too few frames whenever a chunk was
            # shorter than ``_c``, after which the fixed ``left`` trim below
            # removed real output from the following chunk.
            states[_k] = x_full[:, :, -_c:].detach()
            out = nn.ConvTranspose1d.forward(_m, x_full)
            # Drop the output samples produced by the prepended context.
            left = _c * _m.stride[0]
            return out[..., left:-_t] if _t > 0 else out[..., left:]

        mod.forward = fwd
        self._originals.append((mod, orig))

    def _restore(self):
        """Put back every original ``forward`` and forget the bookkeeping."""
        # Undo in reverse order so that a module patched more than once ends
        # up with its oldest saved ``forward``.
        for mod, orig in reversed(self._originals):
            mod.forward = orig
        self._originals.clear()
# =========================================================
# Loading utilities
# =========================================================
# Hub repository and checkpoint file name; both can be overridden through
# the AUDIOVAE_REPO / AUDIOVAE_WEIGHTS environment variables.
REPO_ID = os.getenv("AUDIOVAE_REPO", "openbmb/VoxCPM2")
WEIGHTS_NAME = os.getenv("AUDIOVAE_WEIGHTS", "audiovae.pth")
# Prefer CUDA whenever torch reports it as available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TARGET_SR = 16000
@dataclass
class LoadedCodec:
    """An ``AudioVAE`` together with the device string it was moved to."""

    model: AudioVAE
    device: str

    @property
    def sample_rate(self) -> int:
        """The model's ``sample_rate`` as an int."""
        return int(self.model.sample_rate)

    @property
    def out_sample_rate(self) -> int:
        """The model's ``out_sample_rate`` as an int."""
        return int(self.model.out_sample_rate)

    @property
    def hop_length(self) -> int:
        """The model's ``hop_length`` as an int."""
        return int(self.model.hop_length)

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        """Encode ``wav``, declaring it to be at the model's own sample rate."""
        rate = self.sample_rate
        return self.model.encode(wav, rate)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latents ``z`` with the model's default conditioning."""
        return self.model.decode(z)
def _pick_state_dict(obj):
if isinstance(obj, dict):
for key in ("state_dict", "model", "vae", "audio_vae", "module"):
if key in obj and isinstance(obj[key], dict):
return obj[key]
return obj
@torch.inference_mode()
def load_codec(repo_id: str = REPO_ID, filename: str = WEIGHTS_NAME, device: str = DEVICE) -> LoadedCodec:
    """Download a checkpoint from the Hub and load it into a default ``AudioVAE``.

    Args:
        repo_id: Hub repository to download from.
        filename: Checkpoint file inside that repository.
        device: Device the model is moved to after loading.

    Returns:
        A ``LoadedCodec`` wrapping the model (in eval mode) and ``device``.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # NOTE(review): torch.load unpickles the downloaded file. Consider passing
    # weights_only=True if the checkpoint format and torch version allow it.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state = _pick_state_dict(checkpoint)

    model = AudioVAE()
    # Non-strict: key mismatches are only reported below, never raised.
    load_report = model.load_state_dict(state, strict=False)
    model.to(device).eval()

    print(f"[load] repo={repo_id} file={filename} device={device}")
    if load_report.missing_keys:
        print(f"[load] missing keys: {len(load_report.missing_keys)}")
    if load_report.unexpected_keys:
        print(f"[load] unexpected keys: {len(load_report.unexpected_keys)}")
    return LoadedCodec(model=model, device=device)


# Loaded once at import time; the Gradio callbacks below share this instance.
codec = load_codec()
# =========================================================
# Audio helpers
# =========================================================
def load_audio_file(path: str) -> Tuple[np.ndarray, int]:
    """Read an audio file as float32 and return ``(samples, sample_rate)``.

    Input with more than one axis is averaged over axis 1.
    """
    samples, sample_rate = sf.read(path, dtype="float32")
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    return samples.astype(np.float32), int(sample_rate)
def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` with ``scipy.signal.resample``.

    The input array is returned untouched when both rates are equal;
    otherwise the result is a new float32 array.
    """
    if orig_sr == target_sr:
        return audio
    target_len = int(round(len(audio) * target_sr / orig_sr))
    resampled = scipy_resample(audio, target_len)
    return resampled.astype(np.float32)
def to_tensor(audio: np.ndarray, device: str) -> torch.Tensor:
    """Wrap ``audio`` as a tensor with two leading size-1 axes on ``device``."""
    tensor = torch.from_numpy(audio)
    tensor = tensor.unsqueeze(0).unsqueeze(0)
    return tensor.to(device)
def save_wav_temp(wav: np.ndarray, sr: int) -> str:
    """Write ``wav`` as float32 to a new temporary ``.wav`` file and return its path."""
    handle, wav_path = tempfile.mkstemp(suffix=".wav")
    # mkstemp returns an open descriptor; close it before writing via the path.
    os.close(handle)
    samples = wav.astype(np.float32)
    sf.write(wav_path, samples, sr)
    return wav_path
def fmt_stats(kv: dict) -> str:
    """Render ``kv`` as a two-column Markdown table, one row per item."""
    header = ["| Property | Value |", "|---|---|"]
    rows = [f"| {name} | `{value}` |" for name, value in kv.items()]
    return "\n".join(header + rows)
# =========================================================
# Encode / Decode
# =========================================================
def encode_audio(file_path):
    """Gradio callback: encode the uploaded file into a latent.

    Returns a 3-tuple ``(latent for the State, latent for the JSON preview,
    Markdown stats)``. Both latents are ``None`` when no file was given.
    """
    if file_path is None:
        return None, None, "Upload an audio file first."

    raw_audio, sr = load_audio_file(file_path)
    orig_len = len(raw_audio)
    audio = resample_audio(raw_audio, sr, codec.sample_rate)
    wav = to_tensor(audio, codec.device)

    with torch.inference_mode():
        z = codec.encode(wav)
        # Swap axes 1 and 2 so the latent is listed frame by frame.
        z_btd = z.transpose(1, 2).contiguous()
    latent = z_btd.squeeze(0).detach().cpu().numpy()

    n_frames = latent.shape[0]
    duration = n_frames * codec.hop_length / codec.sample_rate
    stats = {
        "Original SR": f"{sr} Hz",
        "Model input SR": f"{codec.sample_rate} Hz",
        "Model output SR": f"{codec.out_sample_rate} Hz",
        "Original samples": f"{orig_len:,}",
        "Resampled samples": f"{len(audio):,}",
        "Latent shape": str(tuple(latent.shape)),
        "Latent dim": f"{latent.shape[-1]}",
        "Frames": f"{n_frames}",
        "Hop length": f"{codec.hop_length} samples",
        "Approx duration": f"{duration:.4f} s",
        "Latent min/max": f"{latent.min():.4f} / {latent.max():.4f}",
        "Latent mean/std": f"{latent.mean():.4f} / {latent.std():.4f}",
    }
    return latent.tolist(), latent.tolist(), fmt_stats(stats)
def decode_audio(latent_list, current_stats):
    """Gradio callback: decode a stored latent back into audio.

    Returns ``((sample_rate, samples), stats_markdown)``; the first element is
    ``None`` when there is no latent or it cannot be turned into a tensor.
    """
    previous_stats = current_stats or ""
    if latent_list is None:
        return None, previous_stats + "\n\nNo latent found. Encode first."

    try:
        z = torch.tensor(latent_list, dtype=torch.float32, device=codec.device)
        if z.ndim == 2:
            # Add the leading batch axis that encode_audio squeezed away.
            z = z.unsqueeze(0)
        # Undo the axis swap done in encode_audio.
        z = z.transpose(1, 2).contiguous()
    except Exception as e:
        return None, f"Invalid latent: {e}"

    with torch.inference_mode():
        audio = codec.decode(z)

    wav = audio.squeeze().detach().cpu().numpy()
    # Replace NaN/inf, then clamp to [-1, 1].
    wav = np.clip(np.nan_to_num(wav), -1.0, 1.0)

    # Report and play back at the codec's output rate, not its input rate.
    out_sr = codec.out_sample_rate
    stats = {
        "Decoded samples": f"{len(wav):,}",
        "Output SR": f"{out_sr} Hz",
        "Duration": f"{len(wav) / out_sr:.4f} s",
        "Wave min/max": f"{wav.min():.4f} / {wav.max():.4f}",
    }
    merged = previous_stats + "\n\n### Decode Stats\n" + fmt_stats(stats)
    return (out_sr, wav), merged
# =========================================================
# UI
# =========================================================
# Dark theme overrides handed to gr.Blocks.
CSS = """
body, .gradio-container {
background: #0d0d0d !important;
color: #eaeaea !important;
}
h1, h2, h3 { color: #00e5a0 !important; }
.gr-button {
background: #00e5a0 !important;
color: #000 !important;
font-weight: 700 !important;
border: none !important;
}
.gr-box, .gr-panel { background: #151515 !important; border: 1px solid #2a2a2a !important; }
code { background: #1e1e1e; padding: 2px 6px; border-radius: 2px; }
"""

with gr.Blocks(css=CSS, title="AudioVAE Encode / Decode") as demo:
    # Page header showing the loaded codec's key properties.
    gr.Markdown(
        f"""
# AudioVAE Encode / Decode
Standalone one-file app for `audiovae.pth`.
**Repo:** `{REPO_ID}`
**Model input SR:** `{codec.sample_rate} Hz`
**Model output SR:** `{codec.out_sample_rate} Hz`
**Hop length:** `{codec.hop_length}`
"""
    )

    # Carries the encoded latent from the Encode click to the Decode click.
    latent_state = gr.State()

    with gr.Row():
        audio_in = gr.Audio(type="filepath", label="Input Audio")
        audio_out = gr.Audio(label="Reconstructed Audio", interactive=False)

    with gr.Row():
        encode_btn = gr.Button("Encode")
        decode_btn = gr.Button("Decode")

    stats_out = gr.Markdown(value="Upload an audio file and press Encode.")

    with gr.Accordion("Raw latent preview", open=False):
        latent_preview = gr.JSON(label="Latent JSON")

    # Encode fills the state, the JSON preview and the stats panel.
    encode_btn.click(
        fn=encode_audio,
        inputs=audio_in,
        outputs=[latent_state, latent_preview, stats_out],
    )
    # Decode reads the state and the current stats, then appends its own stats.
    decode_btn.click(
        fn=decode_audio,
        inputs=[latent_state, stats_out],
        outputs=[audio_out, stats_out],
    )

if __name__ == "__main__":
    demo.launch()