Spaces:
Sleeping
Sleeping
feat: implement SNAC codec and integrate into the codec registry
Browse files- app.py +130 -88
- compare_codec/__init__.py +1 -0
- compare_codec/snac_codec.py +76 -0
- compare_codec/spectrogram.py +27 -0
- pyproject.toml +7 -0
- tests/test_snac_codec.py +64 -0
- uv.lock +64 -1
app.py
CHANGED
|
@@ -6,120 +6,162 @@ import time
|
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
import gradio as gr
|
|
|
|
| 9 |
|
| 10 |
from compare_codec import CodecConfig, get_all
|
|
|
|
| 11 |
|
| 12 |
MAX_DURATION_S = 30.0
|
| 13 |
|
| 14 |
|
| 15 |
-
def
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
return choices
|
| 22 |
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
codec = get_all()[codec_name]
|
| 28 |
-
for cfg in codec.configs():
|
| 29 |
-
if cfg.name == cfg_name:
|
| 30 |
-
sr = cfg.params.get("sample_rate", codec.sample_rate)
|
| 31 |
-
return codec_name, cfg, sr
|
| 32 |
-
raise ValueError(f"Unknown config: {label}")
|
| 33 |
|
|
|
|
| 34 |
|
| 35 |
-
def compare(audio_path: str | None, selected: list[str]) -> list[dict]:
|
| 36 |
-
if audio_path is None or not selected:
|
| 37 |
-
return []
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
codec_name, cfg, sr = _resolve_selection(label)
|
| 42 |
-
codec = get_all()[codec_name]
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
elapsed = time.perf_counter() - t0
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
|
| 53 |
-
return results
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
audio_in = gr.Audio(
|
| 70 |
-
sources=["upload", "microphone"],
|
| 71 |
-
type="filepath",
|
| 72 |
-
label="Input audio",
|
| 73 |
)
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
| 77 |
)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
for i in range(max_codecs):
|
| 96 |
-
if i < len(results):
|
| 97 |
-
r = results[i]
|
| 98 |
-
updates.append(gr.update(visible=True))
|
| 99 |
-
updates.append(
|
| 100 |
-
gr.update(value=f"**{r['label']}** — {r['time']:.2f}s")
|
| 101 |
-
)
|
| 102 |
-
updates.append(gr.update(value=r["audio"], label=r["label"]))
|
| 103 |
-
else:
|
| 104 |
-
updates.append(gr.update(visible=False))
|
| 105 |
-
updates.append(gr.update(value=""))
|
| 106 |
-
updates.append(gr.update(value=None))
|
| 107 |
-
return updates
|
| 108 |
-
|
| 109 |
-
all_outputs: list[gr.Component] = []
|
| 110 |
-
for group, md, audio_out in output_components:
|
| 111 |
-
all_outputs.extend([group, md, audio_out])
|
| 112 |
-
|
| 113 |
-
run_btn.click(
|
| 114 |
-
on_compare,
|
| 115 |
-
inputs=[audio_in, codec_select],
|
| 116 |
outputs=all_outputs,
|
| 117 |
)
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
return demo
|
| 120 |
|
| 121 |
|
| 122 |
demo = build_ui()
|
| 123 |
|
| 124 |
if __name__ == "__main__":
|
| 125 |
-
demo.launch()
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
+
import numpy as np
|
| 10 |
|
| 11 |
from compare_codec import CodecConfig, get_all
|
| 12 |
+
from compare_codec.spectrogram import generate as make_spectrogram
|
| 13 |
|
| 14 |
MAX_DURATION_S = 30.0
|
| 15 |
|
| 16 |
|
| 17 |
+
def _encode_decode_one(
|
| 18 |
+
audio_path: Path, codec_name: str, cfg: CodecConfig
|
| 19 |
+
) -> tuple[np.ndarray, int, float]:
|
| 20 |
+
"""Run one codec config and return (audio_array, sample_rate, elapsed_seconds)."""
|
| 21 |
+
codec = get_all()[codec_name]
|
| 22 |
+
sr = cfg.params.get("sample_rate", codec.sample_rate)
|
|
|
|
| 23 |
|
| 24 |
+
t0 = time.perf_counter()
|
| 25 |
+
audio_out = codec.encode_decode(audio_path, cfg)
|
| 26 |
+
elapsed = time.perf_counter() - t0
|
| 27 |
|
| 28 |
+
max_samples = int(MAX_DURATION_S * sr)
|
| 29 |
+
if len(audio_out) > max_samples:
|
| 30 |
+
audio_out = audio_out[:max_samples]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
return audio_out, sr, elapsed
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
+
def build_ui() -> gr.Blocks:
|
| 36 |
+
codecs = get_all()
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
with gr.Blocks(title="compare-codec") as demo:
|
| 39 |
+
gr.Markdown("# compare-codec")
|
|
|
|
| 40 |
|
| 41 |
+
audio_in = gr.Audio(
|
| 42 |
+
sources=["upload", "microphone"],
|
| 43 |
+
type="filepath",
|
| 44 |
+
label="Input audio",
|
| 45 |
+
)
|
| 46 |
|
| 47 |
+
tab_components: dict[str, dict] = {}
|
|
|
|
| 48 |
|
| 49 |
+
with gr.Tabs():
|
| 50 |
+
for codec_name, codec in codecs.items():
|
| 51 |
+
configs = codec.configs()
|
| 52 |
+
config_labels = [c.name for c in configs]
|
| 53 |
|
| 54 |
+
with gr.Tab(label=codec_name) as tab:
|
| 55 |
+
config_dd = gr.Dropdown(
|
| 56 |
+
choices=config_labels,
|
| 57 |
+
value=config_labels[0],
|
| 58 |
+
label="Configuration",
|
| 59 |
+
)
|
| 60 |
+
audio_out = gr.Audio(
|
| 61 |
+
label="Reconstructed",
|
| 62 |
+
type="numpy",
|
| 63 |
+
interactive=False,
|
| 64 |
+
)
|
| 65 |
+
stats_md = gr.Markdown(value="*Upload audio to compare.*")
|
| 66 |
+
spec_img = gr.Image(
|
| 67 |
+
label="Spectrogram",
|
| 68 |
+
type="filepath",
|
| 69 |
+
interactive=False,
|
| 70 |
+
)
|
| 71 |
|
| 72 |
+
tab_components[codec_name] = {
|
| 73 |
+
"tab": tab,
|
| 74 |
+
"config_dd": config_dd,
|
| 75 |
+
"audio_out": audio_out,
|
| 76 |
+
"stats_md": stats_md,
|
| 77 |
+
"spec_img": spec_img,
|
| 78 |
+
"configs": configs,
|
| 79 |
+
"config_labels": config_labels,
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
active_tab = gr.State(value=list(codecs.keys())[0])
|
| 83 |
+
ordered_names = list(codecs.keys())
|
| 84 |
+
|
| 85 |
+
all_outputs = []
|
| 86 |
+
for name in ordered_names:
|
| 87 |
+
c = tab_components[name]
|
| 88 |
+
all_outputs.extend([c["audio_out"], c["stats_md"], c["spec_img"]])
|
| 89 |
+
|
| 90 |
+
def process_all(audio_path: str | None, current_tab: str, *dropdown_values):
|
| 91 |
+
"""Generator that yields updates tab-by-tab, active tab first."""
|
| 92 |
+
if audio_path is None:
|
| 93 |
+
return
|
| 94 |
+
|
| 95 |
+
dd_map = dict(zip(ordered_names, dropdown_values))
|
| 96 |
+
order = [current_tab] + [n for n in ordered_names if n != current_tab]
|
| 97 |
+
results: dict[str, tuple] = {}
|
| 98 |
+
for name in ordered_names:
|
| 99 |
+
results[name] = (
|
| 100 |
+
gr.update(value=None),
|
| 101 |
+
gr.update(value="*Processing...*"),
|
| 102 |
+
gr.update(value=None),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
for name in order:
|
| 106 |
+
comps = tab_components[name]
|
| 107 |
+
cfg_label = dd_map[name]
|
| 108 |
+
cfg = next(c for c in comps["configs"] if c.name == cfg_label)
|
| 109 |
|
| 110 |
+
audio_out, sr, elapsed = _encode_decode_one(
|
| 111 |
+
Path(audio_path), name, cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
)
|
| 113 |
+
spec_path = make_spectrogram(audio_out, sr)
|
| 114 |
+
|
| 115 |
+
stats_text = (
|
| 116 |
+
f"**{elapsed:.2f}s** | "
|
| 117 |
+
f"{sr / 1000:.0f} kHz | "
|
| 118 |
+
f"{cfg_label}"
|
| 119 |
)
|
| 120 |
+
|
| 121 |
+
results[name] = (
|
| 122 |
+
gr.update(value=(sr, audio_out)),
|
| 123 |
+
gr.update(value=stats_text),
|
| 124 |
+
gr.update(value=str(spec_path)),
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
flat = []
|
| 128 |
+
for n in ordered_names:
|
| 129 |
+
flat.extend(results[n])
|
| 130 |
+
yield flat
|
| 131 |
+
|
| 132 |
+
all_dropdowns = [tab_components[n]["config_dd"] for n in ordered_names]
|
| 133 |
+
|
| 134 |
+
upload_event = audio_in.change(
|
| 135 |
+
fn=process_all,
|
| 136 |
+
inputs=[audio_in, active_tab] + all_dropdowns,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
outputs=all_outputs,
|
| 138 |
)
|
| 139 |
|
| 140 |
+
dropdown_events = []
|
| 141 |
+
for name in ordered_names:
|
| 142 |
+
evt = tab_components[name]["config_dd"].change(
|
| 143 |
+
fn=process_all,
|
| 144 |
+
inputs=[audio_in, active_tab] + all_dropdowns,
|
| 145 |
+
outputs=all_outputs,
|
| 146 |
+
)
|
| 147 |
+
dropdown_events.append(evt)
|
| 148 |
+
|
| 149 |
+
for codec_name, comps in tab_components.items():
|
| 150 |
+
comps["tab"].select(
|
| 151 |
+
fn=lambda name=codec_name: name,
|
| 152 |
+
inputs=[],
|
| 153 |
+
outputs=[active_tab],
|
| 154 |
+
).then(
|
| 155 |
+
fn=process_all,
|
| 156 |
+
inputs=[audio_in, active_tab] + all_dropdowns,
|
| 157 |
+
outputs=all_outputs,
|
| 158 |
+
cancels=[upload_event] + dropdown_events,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
return demo
|
| 162 |
|
| 163 |
|
| 164 |
demo = build_ui()
|
| 165 |
|
| 166 |
if __name__ == "__main__":
|
| 167 |
+
demo.launch()
|
compare_codec/__init__.py
CHANGED
|
@@ -47,3 +47,4 @@ def get_all() -> dict[str, AudioCodec]:
|
|
| 47 |
|
| 48 |
# Import codec modules so they self-register on startup.
|
| 49 |
from compare_codec import dac as _dac # noqa: E402, F401
|
|
|
|
|
|
| 47 |
|
| 48 |
# Import codec modules so they self-register on startup.
|
| 49 |
from compare_codec import dac as _dac # noqa: E402, F401
|
| 50 |
+
from compare_codec import snac_codec as _snac # noqa: E402, F401
|
compare_codec/snac_codec.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SNAC (Multi-Scale Neural Audio Codec) — wraps the snac package."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torchaudio
|
| 10 |
+
|
| 11 |
+
from compare_codec import CodecConfig, register
|
| 12 |
+
|
| 13 |
+
_MODELS = [
|
| 14 |
+
("hubertsiuzdak/snac_24khz", 24_000),
|
| 15 |
+
("hubertsiuzdak/snac_32khz", 32_000),
|
| 16 |
+
("hubertsiuzdak/snac_44khz", 44_100),
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SNACCodec:
|
| 21 |
+
"""SNAC codec with lazy model loading."""
|
| 22 |
+
|
| 23 |
+
def __init__(self) -> None:
|
| 24 |
+
self._models: dict[str, object] = {}
|
| 25 |
+
|
| 26 |
+
@property
|
| 27 |
+
def name(self) -> str:
|
| 28 |
+
return "SNAC"
|
| 29 |
+
|
| 30 |
+
@property
|
| 31 |
+
def sample_rate(self) -> int:
|
| 32 |
+
return 24_000
|
| 33 |
+
|
| 34 |
+
def configs(self) -> list[CodecConfig]:
|
| 35 |
+
configs = []
|
| 36 |
+
for model_id, sr in _MODELS:
|
| 37 |
+
label = f"{sr // 1000}kHz"
|
| 38 |
+
configs.append(
|
| 39 |
+
CodecConfig(
|
| 40 |
+
name=label,
|
| 41 |
+
params={"model_id": model_id, "sample_rate": sr},
|
| 42 |
+
)
|
| 43 |
+
)
|
| 44 |
+
return configs
|
| 45 |
+
|
| 46 |
+
def _get_model(self, model_id: str) -> object:
|
| 47 |
+
if model_id not in self._models:
|
| 48 |
+
from snac import SNAC
|
| 49 |
+
|
| 50 |
+
model = SNAC.from_pretrained(model_id)
|
| 51 |
+
self._models[model_id] = model
|
| 52 |
+
return self._models[model_id]
|
| 53 |
+
|
| 54 |
+
@torch.no_grad()
|
| 55 |
+
def encode_decode(self, audio_path: Path, config: CodecConfig) -> np.ndarray:
|
| 56 |
+
model_id: str = config.params["model_id"]
|
| 57 |
+
target_sr: int = config.params["sample_rate"]
|
| 58 |
+
|
| 59 |
+
model = self._get_model(model_id)
|
| 60 |
+
|
| 61 |
+
wav, sr = torchaudio.load(str(audio_path))
|
| 62 |
+
# Mix to mono if needed.
|
| 63 |
+
if wav.shape[0] > 1:
|
| 64 |
+
wav = wav.mean(dim=0, keepdim=True)
|
| 65 |
+
# Resample if needed.
|
| 66 |
+
if sr != target_sr:
|
| 67 |
+
wav = torchaudio.functional.resample(wav, sr, target_sr)
|
| 68 |
+
# SNAC expects (B, 1, T).
|
| 69 |
+
wav = wav.unsqueeze(0)
|
| 70 |
+
|
| 71 |
+
audio_hat, _ = model(wav)
|
| 72 |
+
|
| 73 |
+
return audio_hat.squeeze(0).squeeze(0).cpu().numpy()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
register(SNACCodec())
|
compare_codec/spectrogram.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate mel spectrogram images from audio arrays."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import tempfile
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import librosa
|
| 9 |
+
import librosa.display
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def generate(audio: np.ndarray, sr: int) -> Path:
|
| 15 |
+
"""Render a mel spectrogram to a temporary PNG and return its path."""
|
| 16 |
+
fig, ax = plt.subplots(1, 1, figsize=(8, 3))
|
| 17 |
+
S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128)
|
| 18 |
+
S_dB = librosa.power_to_db(S, ref=np.max)
|
| 19 |
+
librosa.display.specshow(S_dB, sr=sr, x_axis="time", y_axis="mel", ax=ax)
|
| 20 |
+
ax.set(title=None, xlabel=None, ylabel=None)
|
| 21 |
+
ax.tick_params(labelsize=8)
|
| 22 |
+
fig.tight_layout(pad=0.5)
|
| 23 |
+
|
| 24 |
+
path = Path(tempfile.mktemp(suffix=".png"))
|
| 25 |
+
fig.savefig(path, dpi=100, bbox_inches="tight")
|
| 26 |
+
plt.close(fig)
|
| 27 |
+
return path
|
pyproject.toml
CHANGED
|
@@ -9,13 +9,20 @@ dependencies = [
|
|
| 9 |
"descript-audio-codec>=1.0.0",
|
| 10 |
"encodec>=0.1.1",
|
| 11 |
"gradio",
|
|
|
|
|
|
|
| 12 |
"numpy",
|
| 13 |
"snac>=1.2.1",
|
| 14 |
"torchaudio>=2.11.0",
|
|
|
|
| 15 |
"transformers>=5.5.3",
|
| 16 |
]
|
| 17 |
|
| 18 |
[dependency-groups]
|
| 19 |
dev = [
|
|
|
|
| 20 |
"ruff>=0.15.10",
|
| 21 |
]
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
"descript-audio-codec>=1.0.0",
|
| 10 |
"encodec>=0.1.1",
|
| 11 |
"gradio",
|
| 12 |
+
"librosa>=0.11.0",
|
| 13 |
+
"matplotlib>=3.10.8",
|
| 14 |
"numpy",
|
| 15 |
"snac>=1.2.1",
|
| 16 |
"torchaudio>=2.11.0",
|
| 17 |
+
"torchcodec>=0.11.0",
|
| 18 |
"transformers>=5.5.3",
|
| 19 |
]
|
| 20 |
|
| 21 |
[dependency-groups]
|
| 22 |
dev = [
|
| 23 |
+
"pytest>=9.0.3",
|
| 24 |
"ruff>=0.15.10",
|
| 25 |
]
|
| 26 |
+
|
| 27 |
+
[tool.pytest.ini_options]
|
| 28 |
+
pythonpath = ["."]
|
tests/test_snac_codec.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the SNAC codec wrapper."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pytest
|
| 9 |
+
import torch
|
| 10 |
+
import torchaudio
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@pytest.fixture()
|
| 14 |
+
def wav_file(tmp_path: Path) -> Path:
|
| 15 |
+
"""Create a short mono WAV at 24 kHz."""
|
| 16 |
+
sr = 24_000
|
| 17 |
+
samples = torch.randn(1, sr * 2) # 2 seconds
|
| 18 |
+
path = tmp_path / "test.wav"
|
| 19 |
+
torchaudio.save(str(path), samples, sr)
|
| 20 |
+
return path
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def test_snac_name():
|
| 24 |
+
from compare_codec.snac_codec import SNACCodec
|
| 25 |
+
|
| 26 |
+
codec = SNACCodec()
|
| 27 |
+
assert codec.name == "SNAC"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def test_snac_sample_rate():
|
| 31 |
+
from compare_codec.snac_codec import SNACCodec
|
| 32 |
+
|
| 33 |
+
codec = SNACCodec()
|
| 34 |
+
assert codec.sample_rate == 24_000
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_snac_configs_not_empty():
|
| 38 |
+
from compare_codec.snac_codec import SNACCodec
|
| 39 |
+
|
| 40 |
+
codec = SNACCodec()
|
| 41 |
+
configs = codec.configs()
|
| 42 |
+
assert len(configs) >= 3 # at least one per model variant
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def test_snac_configs_have_sample_rate():
|
| 46 |
+
from compare_codec.snac_codec import SNACCodec
|
| 47 |
+
|
| 48 |
+
codec = SNACCodec()
|
| 49 |
+
for cfg in codec.configs():
|
| 50 |
+
assert "sample_rate" in cfg.params
|
| 51 |
+
assert "model_id" in cfg.params
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def test_snac_encode_decode_returns_float32_array(wav_file: Path):
|
| 55 |
+
from compare_codec.snac_codec import SNACCodec
|
| 56 |
+
|
| 57 |
+
codec = SNACCodec()
|
| 58 |
+
cfg = [c for c in codec.configs() if c.params["sample_rate"] == 24_000][0]
|
| 59 |
+
result = codec.encode_decode(wav_file, cfg)
|
| 60 |
+
|
| 61 |
+
assert isinstance(result, np.ndarray)
|
| 62 |
+
assert result.dtype == np.float32
|
| 63 |
+
assert result.ndim == 1
|
| 64 |
+
assert len(result) > 0
|
uv.lock
CHANGED
|
@@ -350,14 +350,18 @@ dependencies = [
|
|
| 350 |
{ name = "descript-audio-codec" },
|
| 351 |
{ name = "encodec" },
|
| 352 |
{ name = "gradio" },
|
|
|
|
|
|
|
| 353 |
{ name = "numpy" },
|
| 354 |
{ name = "snac" },
|
| 355 |
{ name = "torchaudio" },
|
|
|
|
| 356 |
{ name = "transformers" },
|
| 357 |
]
|
| 358 |
|
| 359 |
[package.dev-dependencies]
|
| 360 |
dev = [
|
|
|
|
| 361 |
{ name = "ruff" },
|
| 362 |
]
|
| 363 |
|
|
@@ -366,14 +370,20 @@ requires-dist = [
|
|
| 366 |
{ name = "descript-audio-codec", specifier = ">=1.0.0" },
|
| 367 |
{ name = "encodec", specifier = ">=0.1.1" },
|
| 368 |
{ name = "gradio" },
|
|
|
|
|
|
|
| 369 |
{ name = "numpy" },
|
| 370 |
{ name = "snac", specifier = ">=1.2.1" },
|
| 371 |
{ name = "torchaudio", specifier = ">=2.11.0" },
|
|
|
|
| 372 |
{ name = "transformers", specifier = ">=5.5.3" },
|
| 373 |
]
|
| 374 |
|
| 375 |
[package.metadata.requires-dev]
|
| 376 |
-
dev = [
|
|
|
|
|
|
|
|
|
|
| 377 |
|
| 378 |
[[package]]
|
| 379 |
name = "contourpy"
|
|
@@ -951,6 +961,15 @@ wheels = [
|
|
| 951 |
{ url = "https://files.pythonhosted.org/packages/8a/db/55a262f3606bebcae07cc14095338471ad7c0bbcaa37707e6f0ee49725b7/importlib_resources-7.1.0-py3-none-any.whl", hash = "sha256:1bd7b48b4088eddb2cd16382150bb515af0bd2c70128194392725f82ad2c96a1", size = 37232, upload-time = "2026-04-12T16:36:08.219Z" },
|
| 952 |
]
|
| 953 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 954 |
[[package]]
|
| 955 |
name = "ipython"
|
| 956 |
version = "9.12.0"
|
|
@@ -1855,6 +1874,15 @@ wheels = [
|
|
| 1855 |
{ url = "https://files.pythonhosted.org/packages/75/a6/a0a304dc33b49145b21f4808d763822111e67d1c3a32b524a1baf947b6e1/platformdirs-4.9.6-py3-none-any.whl", hash = "sha256:e61adb1d5e5cb3441b4b7710bea7e4c12250ca49439228cc1021c00dcfac0917", size = 21348, upload-time = "2026-04-09T00:04:09.463Z" },
|
| 1856 |
]
|
| 1857 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1858 |
[[package]]
|
| 1859 |
name = "pooch"
|
| 1860 |
version = "1.9.0"
|
|
@@ -2056,6 +2084,22 @@ wheels = [
|
|
| 2056 |
{ url = "https://files.pythonhosted.org/packages/1a/22/60cd92bd3ec00948800984410f4cf5ded5bd8e9b715729f3642efe0edb3d/pystoi-0.4.1-py2.py3-none-any.whl", hash = "sha256:e277b671663d26d35a2416c9c8010a74084e6c3970354506398051a554896939", size = 8218, upload-time = "2023-12-29T16:48:03.748Z" },
|
| 2057 |
]
|
| 2058 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2059 |
[[package]]
|
| 2060 |
name = "python-dateutil"
|
| 2061 |
version = "2.9.0.post0"
|
|
@@ -2753,6 +2797,25 @@ wheels = [
|
|
| 2753 |
{ url = "https://files.pythonhosted.org/packages/93/f7/ee5da8c03f1a3c7662c6c6a119f24a4b3e646da94be56dce3201e3a6ee9b/torchaudio-2.11.0-cp314-cp314t-win_amd64.whl", hash = "sha256:88fb5e29f670a33d9bac6aabb1d2734460cf6e461bde5cdc352826035851b16d", size = 328661, upload-time = "2026-03-23T18:13:20.1Z" },
|
| 2754 |
]
|
| 2755 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2756 |
[[package]]
|
| 2757 |
name = "tqdm"
|
| 2758 |
version = "4.67.3"
|
|
|
|
| 350 |
{ name = "descript-audio-codec" },
|
| 351 |
{ name = "encodec" },
|
| 352 |
{ name = "gradio" },
|
| 353 |
+
{ name = "librosa" },
|
| 354 |
+
{ name = "matplotlib" },
|
| 355 |
{ name = "numpy" },
|
| 356 |
{ name = "snac" },
|
| 357 |
{ name = "torchaudio" },
|
| 358 |
+
{ name = "torchcodec" },
|
| 359 |
{ name = "transformers" },
|
| 360 |
]
|
| 361 |
|
| 362 |
[package.dev-dependencies]
|
| 363 |
dev = [
|
| 364 |
+
{ name = "pytest" },
|
| 365 |
{ name = "ruff" },
|
| 366 |
]
|
| 367 |
|
|
|
|
| 370 |
{ name = "descript-audio-codec", specifier = ">=1.0.0" },
|
| 371 |
{ name = "encodec", specifier = ">=0.1.1" },
|
| 372 |
{ name = "gradio" },
|
| 373 |
+
{ name = "librosa", specifier = ">=0.11.0" },
|
| 374 |
+
{ name = "matplotlib", specifier = ">=3.10.8" },
|
| 375 |
{ name = "numpy" },
|
| 376 |
{ name = "snac", specifier = ">=1.2.1" },
|
| 377 |
{ name = "torchaudio", specifier = ">=2.11.0" },
|
| 378 |
+
{ name = "torchcodec", specifier = ">=0.11.0" },
|
| 379 |
{ name = "transformers", specifier = ">=5.5.3" },
|
| 380 |
]
|
| 381 |
|
| 382 |
[package.metadata.requires-dev]
|
| 383 |
+
dev = [
|
| 384 |
+
{ name = "pytest", specifier = ">=9.0.3" },
|
| 385 |
+
{ name = "ruff", specifier = ">=0.15.10" },
|
| 386 |
+
]
|
| 387 |
|
| 388 |
[[package]]
|
| 389 |
name = "contourpy"
|
|
|
|
| 961 |
{ url = "https://files.pythonhosted.org/packages/8a/db/55a262f3606bebcae07cc14095338471ad7c0bbcaa37707e6f0ee49725b7/importlib_resources-7.1.0-py3-none-any.whl", hash = "sha256:1bd7b48b4088eddb2cd16382150bb515af0bd2c70128194392725f82ad2c96a1", size = 37232, upload-time = "2026-04-12T16:36:08.219Z" },
|
| 962 |
]
|
| 963 |
|
| 964 |
+
[[package]]
|
| 965 |
+
name = "iniconfig"
|
| 966 |
+
version = "2.3.0"
|
| 967 |
+
source = { registry = "https://pypi.org/simple" }
|
| 968 |
+
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
|
| 969 |
+
wheels = [
|
| 970 |
+
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
|
| 971 |
+
]
|
| 972 |
+
|
| 973 |
[[package]]
|
| 974 |
name = "ipython"
|
| 975 |
version = "9.12.0"
|
|
|
|
| 1874 |
{ url = "https://files.pythonhosted.org/packages/75/a6/a0a304dc33b49145b21f4808d763822111e67d1c3a32b524a1baf947b6e1/platformdirs-4.9.6-py3-none-any.whl", hash = "sha256:e61adb1d5e5cb3441b4b7710bea7e4c12250ca49439228cc1021c00dcfac0917", size = 21348, upload-time = "2026-04-09T00:04:09.463Z" },
|
| 1875 |
]
|
| 1876 |
|
| 1877 |
+
[[package]]
|
| 1878 |
+
name = "pluggy"
|
| 1879 |
+
version = "1.6.0"
|
| 1880 |
+
source = { registry = "https://pypi.org/simple" }
|
| 1881 |
+
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
|
| 1882 |
+
wheels = [
|
| 1883 |
+
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
| 1884 |
+
]
|
| 1885 |
+
|
| 1886 |
[[package]]
|
| 1887 |
name = "pooch"
|
| 1888 |
version = "1.9.0"
|
|
|
|
| 2084 |
{ url = "https://files.pythonhosted.org/packages/1a/22/60cd92bd3ec00948800984410f4cf5ded5bd8e9b715729f3642efe0edb3d/pystoi-0.4.1-py2.py3-none-any.whl", hash = "sha256:e277b671663d26d35a2416c9c8010a74084e6c3970354506398051a554896939", size = 8218, upload-time = "2023-12-29T16:48:03.748Z" },
|
| 2085 |
]
|
| 2086 |
|
| 2087 |
+
[[package]]
|
| 2088 |
+
name = "pytest"
|
| 2089 |
+
version = "9.0.3"
|
| 2090 |
+
source = { registry = "https://pypi.org/simple" }
|
| 2091 |
+
dependencies = [
|
| 2092 |
+
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
| 2093 |
+
{ name = "iniconfig" },
|
| 2094 |
+
{ name = "packaging" },
|
| 2095 |
+
{ name = "pluggy" },
|
| 2096 |
+
{ name = "pygments" },
|
| 2097 |
+
]
|
| 2098 |
+
sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" }
|
| 2099 |
+
wheels = [
|
| 2100 |
+
{ url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" },
|
| 2101 |
+
]
|
| 2102 |
+
|
| 2103 |
[[package]]
|
| 2104 |
name = "python-dateutil"
|
| 2105 |
version = "2.9.0.post0"
|
|
|
|
| 2797 |
{ url = "https://files.pythonhosted.org/packages/93/f7/ee5da8c03f1a3c7662c6c6a119f24a4b3e646da94be56dce3201e3a6ee9b/torchaudio-2.11.0-cp314-cp314t-win_amd64.whl", hash = "sha256:88fb5e29f670a33d9bac6aabb1d2734460cf6e461bde5cdc352826035851b16d", size = 328661, upload-time = "2026-03-23T18:13:20.1Z" },
|
| 2798 |
]
|
| 2799 |
|
| 2800 |
+
[[package]]
|
| 2801 |
+
name = "torchcodec"
|
| 2802 |
+
version = "0.11.0"
|
| 2803 |
+
source = { registry = "https://pypi.org/simple" }
|
| 2804 |
+
wheels = [
|
| 2805 |
+
{ url = "https://files.pythonhosted.org/packages/72/f2/85da3abfef5443b0fd7a70706dabf54e0fb5592ed6b03b3f8bfccff06af0/torchcodec-0.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3e5abd61ad9de69a7008545f5c08736b66298f4e895b1f9fad01ae41bce75252", size = 4368773, upload-time = "2026-03-24T15:56:22.952Z" },
|
| 2806 |
+
{ url = "https://files.pythonhosted.org/packages/73/4c/b466aa762abf2e771cfd865bac1c03259a66c482b83978898c00810cb97f/torchcodec-0.11.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:aca62724bf7d4b5b70db60183e8bee67ba77f4f0afdb052e6d1900528b97de6e", size = 2397523, upload-time = "2026-03-24T15:56:25.026Z" },
|
| 2807 |
+
{ url = "https://files.pythonhosted.org/packages/f3/01/2488f553e8014c911652c2bf29265b73ec3753ac13ac4816b9c831ad27e2/torchcodec-0.11.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:669895505b7f0cd17bd6a71cf8bfa85190a9de1dc87f77051682c73149d58023", size = 2545573, upload-time = "2026-03-24T15:56:26.765Z" },
|
| 2808 |
+
{ url = "https://files.pythonhosted.org/packages/21/a7/e12e7cc5d69dd55e3edef7f16f46bcca5978c6262c86e51ccd3913bff92b/torchcodec-0.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:8a1f023e26dddca77c1c81de83cf9201ecb363b818405557da70e4777b966697", size = 1921067, upload-time = "2026-03-24T15:56:28.417Z" },
|
| 2809 |
+
{ url = "https://files.pythonhosted.org/packages/4a/24/84c597fb17481eec5885d02cc3f95ba047f7e0411009f7c0e8087c3a52a2/torchcodec-0.11.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:9308e1e9c22a14f8e3b35a16ae8167eadc5939dd7d699b1cb6e38f57bfcc6563", size = 4289581, upload-time = "2026-03-24T15:56:30.112Z" },
|
| 2810 |
+
{ url = "https://files.pythonhosted.org/packages/ee/57/97f90c0e2abc8253412281bb1185375dc9b9aa9bd519d14ceeb52ecbf6ec/torchcodec-0.11.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:0120f91479e2845ed481ce84ad5248f70d71009fb2f9b5ea182260c5c901c804", size = 2399066, upload-time = "2026-03-24T15:56:32.016Z" },
|
| 2811 |
+
{ url = "https://files.pythonhosted.org/packages/f4/c5/e078ef510cd4ad15bb185c184b0f0fbb68c8ffe8afeed24c3d18e7acaadb/torchcodec-0.11.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0e0fd9a9045271f53d78e2d1153ac88d1369808fa0df8db135941d3464080b76", size = 2548454, upload-time = "2026-03-24T15:56:33.832Z" },
|
| 2812 |
+
{ url = "https://files.pythonhosted.org/packages/00/9f/f1ecdeb6e53e9ab5d78e6cb191bd1191683e3f0c02950022cf02c939c5fc/torchcodec-0.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:bda3a330a5578ccc753317f388b492ac19296d9e0e0bb50a43faff512981f995", size = 1920982, upload-time = "2026-03-24T15:56:37.022Z" },
|
| 2813 |
+
{ url = "https://files.pythonhosted.org/packages/6d/a8/2bb67ce8f849c52efdbb45a14dfed7ad378004df3b584a033b47c771b618/torchcodec-0.11.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:386dffbd76ae7fcf77f463755830a0489bfecaca55b354a4e794c06acc74540e", size = 4408063, upload-time = "2026-03-24T15:56:38.875Z" },
|
| 2814 |
+
{ url = "https://files.pythonhosted.org/packages/53/cf/3aad20b6b913b6800192114c9e89d342783cfba2cd234cf5032cd175d474/torchcodec-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:ba874466cb5eb7be30062ce1a580d795af42b609587929af631e068693b55233", size = 2402237, upload-time = "2026-03-24T15:56:40.684Z" },
|
| 2815 |
+
{ url = "https://files.pythonhosted.org/packages/79/e7/11ed146f043658bbc0f35cc9ef9064a55401a10d41ef1b2714e6bd2763c6/torchcodec-0.11.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:71b4991dd45759dd1836dd0b11ce8b509a6b74487ec47f2b0e679f87e2927cba", size = 2549919, upload-time = "2026-03-24T15:56:42.197Z" },
|
| 2816 |
+
{ url = "https://files.pythonhosted.org/packages/51/e9/e5fe67d84ee1ce2bf38c10b1424d87b3719b5279afbca3316fbd439010c5/torchcodec-0.11.0-cp314-cp314-win_amd64.whl", hash = "sha256:9269c559e21d678401febe3fae59e3ea585896e9bf7993fb859bebb4468b6a1d", size = 1925004, upload-time = "2026-03-24T15:56:44.411Z" },
|
| 2817 |
+
]
|
| 2818 |
+
|
| 2819 |
[[package]]
|
| 2820 |
name = "tqdm"
|
| 2821 |
version = "4.67.3"
|