DreamVAE

A distilled decoder for the ACE-Step 1.5 Oobleck VAE. It is a drop-in replacement: same [B, 64, T] latent input, same [B, 2, 1920*T] stereo output at 48 kHz. The model has 51.7M parameters (61% of the teacher) and loses 0.24 dB SNR relative to the teacher across 20 diverse music tracks. As a TensorRT FP16 engine on an RTX 5090 it decodes 60 seconds of audio in 37.2 ms (8.66x faster than the teacher in PyTorch, 1.45x faster than the teacher's own TRT engine).

Usage

import json
import os
import sys

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

repo = "daydreamlive/DreamVAE"
weights = hf_hub_download(repo, "model.safetensors")
config = hf_hub_download(repo, "config.json")
mod_py = hf_hub_download(repo, "modeling.py")

# modeling.py is a plain file in the repo, so put its directory on the import path.
sys.path.insert(0, os.path.dirname(mod_py))
from modeling import FastOobleckDecoder

with open(config) as f:
    cfg = json.load(f)

model = FastOobleckDecoder(
    channels=cfg["channels"],
    input_channels=cfg["input_channels"],
    audio_channels=cfg["audio_channels"],
    upsampling_ratios=cfg["upsampling_ratios"],
    channel_multiples=cfg["channel_multiples"],
).eval().to("cuda")
model.load_state_dict(load_file(weights))

# Placeholder latents for illustration only (25 frames = 1 s at 48 kHz);
# in practice they come from the ACE-Step pipeline.
latents = torch.randn(1, 64, 25, device="cuda")

# latents: [B, 64, T]  ->  audio: [B, 2, 1920 * T]
with torch.no_grad():
    audio = model(latents)

Pair with the stock ACE-Step encoder; only the decoder is distilled.

ONNX export is at onnx/model.onnx. Build a TRT engine with scripts/export_trt.py.
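To audition the decoded output, the tensor has to be written to disk as 48 kHz stereo audio. The following is a minimal sketch rather than part of the release: `write_wav` is a hypothetical helper built on NumPy and the standard-library `wave` module, and it assumes the decoder output is float audio roughly in [-1, 1] (anything outside that range is clipped). A synthetic tone stands in for `audio[0].cpu().numpy()`.

```python
import os
import tempfile
import wave

import numpy as np


def write_wav(path, audio, sample_rate=48_000):
    """Write one [channels, samples] float array in [-1, 1] as 16-bit PCM WAV."""
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim != 2:
        raise ValueError("expected a [channels, samples] array")
    # Clip to the valid range, scale to little-endian int16, interleave channels.
    pcm = (np.clip(audio, -1.0, 1.0) * 32767.0).astype("<i2")
    interleaved = np.ascontiguousarray(pcm.T)  # [samples, channels]
    with wave.open(path, "wb") as f:
        f.setnchannels(audio.shape[0])
        f.setsampwidth(2)  # bytes per sample
        f.setframerate(sample_rate)
        f.writeframes(interleaved.tobytes())


# Stand-in for one decoded item: one second of a stereo 440 Hz tone.
t = np.arange(48_000) / 48_000
demo = 0.5 * np.stack([np.sin(2 * np.pi * 440 * t)] * 2)

out_path = os.path.join(tempfile.gettempdir(), "dreamvae_demo.wav")
write_wav(out_path, demo)
```

With the real model, each batch item would be written separately, for example `write_wav("out.wav", audio[0].cpu().numpy())`.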

Results

Quality (20 diverse FMA tracks, 10 s clips):

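| vs. original audio     | SNR (dB)      | STFT dist | Mel dist |
|------------------------|---------------|-----------|----------|
| Teacher OobleckDecoder | 5.0 (std 2.5) | 0.698     | 0.733    |
| DreamVAE               | 4.7 (std 2.4) | 0.695     | 0.742    |
| Degradation            | 0.24 dB       | better    | +0.008   |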

The VAE's own reconstruction error (~5 dB SNR against the original audio) dominates; DreamVAE gives up only a further 0.24 dB on top of it. STFT distance is marginally better for the student; mel distance is 1.1% worse. In informal listening the difference was inaudible.
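The card does not spell out how the metrics were computed. Assuming the conventional definition of SNR (signal power over error power, in dB), the sketch below shows how such a number could be obtained and what a 0.24 dB drop means in terms of error power. `snr_db` is a hypothetical helper and the signals are synthetic, not the evaluation data.

```python
import numpy as np


def snr_db(reference, estimate):
    """SNR in dB: 10 * log10(signal power / error power)."""
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    error = reference - estimate
    return 10.0 * np.log10(np.sum(reference ** 2) / np.sum(error ** 2))


rng = np.random.default_rng(0)
clean = rng.standard_normal(48_000)
# Noise with about one tenth of the signal power, so the SNR is near 10 dB.
noisy = clean + np.sqrt(0.1) * rng.standard_normal(48_000)
measured = snr_db(clean, noisy)
print(f"{measured:.1f} dB")

# A drop of 0.24 dB corresponds to this multiplicative increase in error power.
extra_error_power = 10 ** (0.24 / 10)
print(f"{(extra_error_power - 1) * 100:.1f}% more error power")
```

Under that definition, a 0.24 dB loss is an increase in error power of a few percent relative to the teacher's own reconstruction error.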

Speed (RTX 5090, 60 s of stereo audio at 48 kHz; TRT rows use FP16 engines):

| Backend           | Time (60 s) | vs. teacher PyTorch |
|-------------------|-------------|---------------------|
| Teacher PyTorch   | 321.9 ms    | 1.00x               |
| Teacher TRT FP16  | 54.0 ms     | 5.97x               |
| DreamVAE PyTorch  | 241.1 ms    | 1.34x               |
| DreamVAE TRT FP16 | 37.2 ms     | 8.66x               |

The gain from distillation alone is modest: 1.34x in PyTorch and 1.45x between the two TRT engines. The compelling number is the composition with TRT: a deployment moving from the stock PyTorch VAE to DreamVAE on TRT sees 8.66x.
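The ratios can be recomputed from the reported times, and the same numbers give a real-time factor (seconds of audio decoded per second of compute). This is plain arithmetic on the speed figures reported above, not a benchmark. Because the reported times are rounded, a recomputed ratio may differ from the reported one in the last digit.

```python
# Decode times in milliseconds for 60 s of audio, copied from the speed results.
times_ms = {
    "Teacher PyTorch": 321.9,
    "Teacher TRT FP16": 54.0,
    "DreamVAE PyTorch": 241.1,
    "DreamVAE TRT FP16": 37.2,
}
AUDIO_SECONDS = 60.0

baseline = times_ms["Teacher PyTorch"]
for name, ms in times_ms.items():
    speedup = baseline / ms
    realtime = AUDIO_SECONDS * 1000.0 / ms  # seconds of audio per second of compute
    print(f"{name:18s} {speedup:5.2f}x vs teacher PyTorch, {realtime:6.0f}x real time")

# Distillation gain with both models on TRT.
trt_gain = times_ms["Teacher TRT FP16"] / times_ms["DreamVAE TRT FP16"]
print(f"DreamVAE TRT vs teacher TRT: {trt_gain:.2f}x")
```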

Architecture

|                   | Teacher          | DreamVAE        |
|-------------------|------------------|-----------------|
| Parameters        | 84.4M            | 51.7M (61%)     |
| Channel multiples | [1, 2, 4, 8, 16] | [1, 2, 4, 8, 8] |
| Residuals/block   | 3 (dil 1, 3, 9)  | 2 (dil 1, 3)    |
| Upsampling ratios | [10, 6, 4, 4, 2] | [10, 6, 4, 4, 2] |
| Activation        | Snake1d          | Snake1d         |
| Hop length        | 1920             | 1920            |
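The hop length in the comparison above is the product of the upsampling ratios, which also fixes the latent frame rate at 48 kHz. A small sketch of that arithmetic follows; `latent_frames` is a hypothetical helper for sizing the T dimension of a [B, 64, T] latent tensor.

```python
import math

upsampling_ratios = [10, 6, 4, 4, 2]  # same for teacher and DreamVAE
SAMPLE_RATE = 48_000

hop_length = math.prod(upsampling_ratios)     # audio samples per latent frame
frames_per_second = SAMPLE_RATE / hop_length  # latent frame rate


def latent_frames(seconds):
    """Latent frames T needed to decode `seconds` of audio (rounded up)."""
    return math.ceil(seconds * SAMPLE_RATE / hop_length)


print(hop_length, frames_per_second, latent_frames(60))  # 1920 25.0 1500
```

So the 60 s clips used in the speed benchmark correspond to T = 1500 latent frames.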

Trained for 650K steps on 8,000 FMA Small tracks in two phases: 500K steps of reconstruction training followed by 150K steps of adversarial training. The released checkpoint is step 635K, selected as the HF-energy-match peak among six checkpoints evaluated over the adversarial phase.

Files

.
├── README.md
├── config.json            architecture + training hyperparameters
├── model.safetensors      FP32 weights (207 MB, 51.7M params)
├── modeling.py            self-contained FastOobleckDecoder class
├── onnx/model.onnx        ONNX export
└── scripts/               training, eval, speed bench, TRT export/verify
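The listed weight size follows directly from the parameter count at 4 bytes per FP32 value. The check below is a sketch of that arithmetic; the FP16 line only illustrates what a half-precision copy would weigh, and no such file is listed in this repo.

```python
PARAMS = 51.7e6  # parameter count from the model card
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2}

sizes_mb = {dtype: PARAMS * nbytes / 1e6 for dtype, nbytes in BYTES_PER_PARAM.items()}
for dtype, size_mb in sizes_mb.items():
    print(f"{dtype}: ~{size_mb:.0f} MB")
```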

Limitations

  • Signal-level metrics only (no MUSHRA / PESQ / preference tests).
  • Evaluated on VAE round-trip, not end-to-end through the full text-to-DiT-to-VAE pipeline. FAD on generated samples is future work.
  • One point on the quality-speed frontier ([1,2,4,8,8], 2 residual units). Width variants not evaluated.
  • Trained in FP32. INT8 / FP8 quantization would likely stack another 1.5-2x on top of FP16 and is not explored here.

Citation

Technical note (methodology, training recipe with source-paper attribution, adversarial-phase trajectory, and negative results on encoder distillation) is forthcoming on arXiv. Citation block will be added here on release.

License

Apache 2.0, matching the upstream ACE-Step 1.5 license.
