HT-Demucs FT: Drums Specialist, ONNX

The first working, parity-verified ONNX export of HT-Demucs FT on the Hugging Face Hub. Runs in onnxruntime on CPU out of the box, and on CoreML / CUDA / DirectML with a one-line provider change. No PyTorch required at inference.

This repo packages sub-model 0 (the drums specialist) of the htdemucs_ft 4-bag ensemble as a single .onnx file plus a 200-line numpy reference inference script.

Why this exists: every previous attempt at "demucs onnx" on GitHub stalled on one of four blockers (complex64 STFT outputs, fractions.Fraction in the segment math, random.randrange in the transformer's positional embedding, or PyTorch's fused _native_multi_head_attention C++ kernel). This export defeats all four. See How it was built below for the patches.


TL;DR

pip install onnxruntime numpy soundfile
python infer.py your-song.mp3 ./out/
# writes ./out/drums.wav at 44.1 kHz stereo

That's it. The repo contains a 316 MB htdemucs_ft_drums.onnx (opset 17), a self-contained infer.py, and this card.


Quality

The drums output of this sub-model is numerically equivalent to the official PyTorch htdemucs_ft ensemble for the drums stem (the bag's weight matrix is one-hot per stem, so the bag's drums output IS this sub-model's drums output). Verified parity on synthetic and real audio: max abs diff = 1.63 × 10⁻⁴ between this ONNX model and PyTorch fp32.
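
The one-hot argument can be checked on toy data. The sketch below assumes the bag combines sub-model outputs as a per-stem weighted average, with a 4 × 4 weight matrix whose row m holds sub-model m's weight for each stem. The array names and sizes are illustrative and are not taken from the Demucs source.

```python
import numpy as np

rng = np.random.default_rng(0)
n_models, n_stems, channels, samples = 4, 4, 2, 1000

# Hypothetical sub-model outputs: outputs[m, s] is stem s as predicted by sub-model m.
outputs = rng.standard_normal((n_models, n_stems, channels, samples)).astype(np.float32)

# One-hot weights: sub-model m contributes only to stem m.
weights = np.eye(n_models, n_stems, dtype=np.float32)

# Weighted average over sub-models, normalised per stem.
bag = np.einsum("ms,mscn->scn", weights, outputs) / weights.sum(axis=0)[:, None, None]

# The bag's drums (stem 0) should coincide with sub-model 0's drums row.
drums_from_bag = bag[0]
drums_from_specialist = outputs[0, 0]
max_diff = float(np.abs(drums_from_bag - drums_from_specialist).max())
print("max abs diff between bag drums and specialist drums:", max_diff)
```

With any weighting that is not one-hot, the other three sub-models would leak into the drums row and the two arrays would differ.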

| Metric (MUSDB18-HQ test, 50 songs) | Value | Source |
|---|---|---|
| Median drums SDR | 10.11 dB | StemSplitio/stem-separation-benchmark-2026 |
| Rank among 4 popular open-source separators | #2 (mdx_extra_q is #1 at 11.49) | same |
| ONNX vs PyTorch max abs diff | 0.000163 | included verify_onnx_parity.py |
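
The parity figure is a maximum absolute difference taken over every output sample. The sketch below shows that metric on two synthetic arrays. It is not the contents of verify_onnx_parity.py, and the helper name `max_abs_diff` is made up for illustration.

```python
import numpy as np

def max_abs_diff(reference: np.ndarray, candidate: np.ndarray) -> float:
    """Largest element-wise absolute difference between two same-shaped arrays."""
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    # Compare in float64 so the metric itself adds no extra rounding error.
    delta = reference.astype(np.float64) - candidate.astype(np.float64)
    return float(np.max(np.abs(delta)))

# Synthetic stand-ins for the PyTorch and ONNX outputs (time axis shortened).
rng = np.random.default_rng(0)
reference = rng.uniform(-1.0, 1.0, size=(1, 4, 2, 1024)).astype(np.float32)
candidate = reference.copy()
candidate[0, 0, 1, 10] += 1e-4  # inject one known deviation

diff = max_abs_diff(reference, candidate)
print(f"max abs diff: {diff:.2e}")
```

In a real check, `reference` would be the PyTorch fp32 output and `candidate` the onnxruntime output for the same input segment.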

Performance

| Runtime | Hardware | Per 7.8-s segment | Per 3-min song | Notes |
|---|---|---|---|---|
| onnxruntime CPU EP | Apple M4 Pro (CPU only) | 1.59 s | ~22 s | Measured. 1.31× faster than PyTorch CPU. |
| PyTorch CPU | Apple M4 Pro (CPU only) | 2.09 s | ~29 s | Baseline. |
| PyTorch MPS (full bag) | Apple M4 Pro | 1.0 s | ~12 s | But the full bag is 640 MB and runs all 4 sub-models. |
| onnxruntime CoreML EP | Apple M-series | TBD | TBD | First-time graph compile is slow (~5+ min on M4 Pro); subsequent loads are cached. Profiling is planned for Day 3 of the StemSplit ONNX project. |
| onnxruntime CUDA EP | NVIDIA L4 | ~0.4 s | ~5 s | Extrapolated from PyTorch CUDA numbers. |
| onnxruntime DirectML EP | RTX 4090 | ~0.2 s | ~2 s | Extrapolated. |

Real-time factor on M4 Pro CPU: 0.20.
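
The real-time factor is processing time divided by audio duration. A short check using only numbers quoted in this card (343980 samples per segment at 44.1 kHz, 1.59 s measured per segment on the M4 Pro CPU):

```python
SAMPLE_RATE = 44_100             # Hz
SEGMENT_SAMPLES = 343_980        # samples per model input segment
SECONDS_PER_SEGMENT_CPU = 1.59   # measured onnxruntime CPU time per segment

segment_seconds = SEGMENT_SAMPLES / SAMPLE_RATE
rtf = SECONDS_PER_SEGMENT_CPU / segment_seconds

print(f"segment length: {segment_seconds:.2f} s")   # 7.80 s
print(f"real-time factor: {rtf:.2f}")               # 0.20
```

A value below 1.0 means the segment is processed faster than it plays back.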


Quick start

Python: minimal

import infer
drums = infer.separate_drums("your-song.mp3")
# drums: numpy array (2, samples) at 44.1 kHz

Python: full control

import soundfile as sf
import onnxruntime as ort
import infer

# Optional: pick an execution provider.
# CPU is portable and the default. On macOS swap in CoreML; on Linux/Windows
# with an NVIDIA GPU, swap in CUDA. On Windows with any DX12 GPU, DirectML.
providers = ["CPUExecutionProvider"]
# providers = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
# providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
# providers = ["DmlExecutionProvider", "CPUExecutionProvider"]

audio, sr = sf.read("your-song.mp3", dtype="float32", always_2d=True)
stems = infer.separate(audio.T, sr, providers=providers)
sf.write("drums.wav", stems[0].T, sr)  # stems[0] is the drums row
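
When the same script has to run on several machines, one option is to filter a preference list against what the installed onnxruntime build actually offers. The sketch below relies on `onnxruntime.get_available_providers()`. The helper name `pick_providers` and the preference order are assumptions for illustration and are not part of infer.py.

```python
from typing import Iterable, List, Optional

# Assumed preference order: accelerators first, CPU as the portable fallback.
PREFERRED = [
    "CUDAExecutionProvider",
    "CoreMLExecutionProvider",
    "DmlExecutionProvider",
    "CPUExecutionProvider",
]

def pick_providers(available: Optional[Iterable[str]] = None) -> List[str]:
    """Return the preferred providers that are available, always ending with CPU."""
    if available is None:
        # Imported lazily so the helper can be exercised without onnxruntime installed.
        import onnxruntime as ort
        available = ort.get_available_providers()
    available = set(available)
    chosen = [name for name in PREFERRED if name in available]
    if "CPUExecutionProvider" not in chosen:
        chosen.append("CPUExecutionProvider")
    return chosen

# Example with an explicit list, as a CoreML-enabled build might report it.
print(pick_providers(["CoreMLExecutionProvider", "CPUExecutionProvider"]))
```

The returned list can be passed as the `providers` argument in the snippet above.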

CLI

python infer.py your-song.mp3 ./out/
python infer.py your-song.mp3 ./out/ --providers coreml      # macOS
python infer.py your-song.mp3 ./out/ --providers cuda        # NVIDIA
python infer.py your-song.mp3 ./out/ --providers dml         # Windows / Xbox / any DX12
python infer.py your-song.mp3 ./out/ --write-all-stems       # also writes by-products

Mobile (iOS / Swift)

import onnxruntime_objc

let modelPath = Bundle.main.path(forResource: "htdemucs_ft_drums", ofType: "onnx")!
let env = try ORTEnv(loggingLevel: .warning)
let options = try ORTSessionOptions()
try options.appendCoreMLExecutionProvider(with: ORTCoreMLExecutionProviderOptions())
let session = try ORTSession(env: env, modelPath: modelPath, sessionOptions: options)

// audio: 1 x 2 x 343980 Float32 buffer, then session.run([...]).

Mobile (Android / Kotlin)

import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession

val env = OrtEnvironment.getEnvironment()
val opts = OrtSession.SessionOptions().apply { addNnapi() }
val session = env.createSession(modelPath, opts)
// audio: FloatBuffer of shape (1, 2, 343980), then session.run(...).

Web (onnxruntime-web)

import * as ort from "onnxruntime-web";
const session = await ort.InferenceSession.create("htdemucs_ft_drums.onnx", {
  executionProviders: ["wasm"],
  graphOptimizationLevel: "all",
});
const tensor = new ort.Tensor("float32", audioBuffer, [1, 2, 343980]);
const out = await session.run({ mix: tensor });
// out.stems.data is a Float32Array of shape (1, 4, 2, 343980); take row 0 for drums.

Input / output spec

| Tensor | Name | Shape | Dtype | Notes |
|---|---|---|---|---|
| Input | mix | (1, 2, 343980) | float32 | Stereo audio, 44.1 kHz, 7.8 s segment. Values in [-1, 1]. |
| Output | stems | (1, 4, 2, 343980) | float32 | Stems in order [drums, bass, other, vocals]. Only stems[0, 0] is meaningfully predicted; the other 3 rows are weakly-predicted by-products of the drum specialist. |
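
To make the output layout concrete, the sketch below slices the drums stem out of an array with the documented shape. It also computes the equivalent offset into a flat buffer, which is how some runtimes expose the output (for example the Float32Array returned by onnxruntime-web). The array is a placeholder filled with consecutive integers, not real model output, and the helper name `stem_from_flat` is hypothetical.

```python
import numpy as np

STEMS = ["drums", "bass", "other", "vocals"]
CHANNELS, SAMPLES = 2, 343_980

# Placeholder standing in for the "stems" output; real values come from the model.
flat = np.arange(1 * len(STEMS) * CHANNELS * SAMPLES, dtype=np.float32)
stems = flat.reshape(1, len(STEMS), CHANNELS, SAMPLES)

# Array view: batch 0, stem 0 gives drums with shape (2, 343980).
drums = stems[0, STEMS.index("drums")]

def stem_from_flat(buffer: np.ndarray, stem: str) -> np.ndarray:
    """Slice one stem out of a flat row-major buffer of shape (1, 4, 2, SAMPLES)."""
    start = STEMS.index(stem) * CHANNELS * SAMPLES
    return buffer[start:start + CHANNELS * SAMPLES].reshape(CHANNELS, SAMPLES)

print(drums.shape)                          # (2, 343980)
print(stem_from_flat(flat, "drums").shape)  # (2, 343980)
```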

For longer audio, chunk with overlap-add; see infer.py::separate for a working 60-line implementation.
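
The chunking idea can be sketched as follows. This is not the code in infer.py: the function name `overlap_add`, the 25% overlap, and the triangular cross-fade weights are assumptions chosen for illustration. `model_fn` stands in for a call to the ONNX session that maps one (channels, segment) chunk to a drums estimate of the same shape.

```python
import numpy as np

SEGMENT = 343_980  # samples per model call

def overlap_add(mix, model_fn, segment=SEGMENT, overlap=0.25):
    """Run model_fn over overlapping chunks of mix (channels, samples) and cross-fade."""
    channels, length = mix.shape
    stride = int(segment * (1.0 - overlap))
    if stride <= 0:
        raise ValueError("overlap must be smaller than 1.0")
    # Triangular weights, strictly positive so every sample gets non-zero total weight.
    weight = np.minimum(np.arange(1, segment + 1), np.arange(segment, 0, -1)).astype(np.float32)
    out = np.zeros((channels, length), dtype=np.float32)
    norm = np.zeros(length, dtype=np.float32)
    for start in range(0, length, stride):
        chunk = mix[:, start:start + segment]
        valid = chunk.shape[1]
        # Zero-pad the last chunk up to the fixed length the model expects.
        padded = np.pad(chunk, ((0, 0), (0, segment - valid)))
        estimate = model_fn(padded)[:, :valid]
        out[:, start:start + valid] += estimate * weight[:valid]
        norm[start:start + valid] += weight[:valid]
    # Divide by the accumulated weights so overlapping regions average correctly.
    return out / norm

# Sanity check with an identity "model": the input should come back unchanged.
rng = np.random.default_rng(0)
demo_mix = rng.standard_normal((2, 357)).astype(np.float32)
demo_out = overlap_add(demo_mix, lambda chunk: chunk, segment=100, overlap=0.25)
print(np.allclose(demo_out, demo_mix, atol=1e-5))
```

Normalising by the summed weights keeps the result unbiased at the start and end of the track, where fewer chunks overlap.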


How it was built

The export pipeline lives at github.com/musicai/musicai/tree/main/scripts/coreml-conversion (in our StemSplit monorepo). The patches that made it work:

  1. torch.stft complex output → replaced with a Conv1d using sin/cos kernels that emits 2 real channels directly. Verified to 5 × 10⁻⁶ max diff against torch.stft.
  2. torch.istft → mirror replacement with ConvTranspose1d plus an overlap-add window-squared envelope.
  3. view_as_real / view_as_complex in _magnitude and _mask → rewritten to thread real-channel tensors throughout.
  4. fractions.Fraction in model.segment → coerced to float.
  5. random.randrange in CrossTransformerEncoder._get_pos_embedding → replaced the method with one that hardcodes shift=0 (mathematically identical at inference since sin_random_shift=0).
  6. aten::_native_multi_head_attention (PyTorch's fused C++ MHA kernel has no ONNX symbolic) → wrote a drop-in nn.MultiheadAttention.forward replacement using plain Linear/bmm/softmax. Verified parity unchanged.
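
Patch 1 works because a short-time Fourier transform is a bank of strided correlations with cosine and sine kernels, so its real and imaginary parts can be emitted as two real channels. The numpy sketch below illustrates that identity on a toy signal and compares it with `np.fft.rfft`. The frame size, hop, Hann window, and function name are arbitrary choices for the demonstration; the export itself applies the idea with a Conv1d inside the PyTorch graph, not with this code.

```python
import numpy as np

N_FFT, HOP = 64, 16

def stft_real(signal, n_fft=N_FFT, hop=HOP):
    """STFT as real-valued filtering: returns (real, imag), each (frames, n_fft // 2 + 1)."""
    window = np.hanning(n_fft)
    k = np.arange(n_fft // 2 + 1)[:, None]   # frequency bins
    n = np.arange(n_fft)[None, :]            # sample index within a frame
    angle = 2.0 * np.pi * k * n / n_fft
    cos_kernels = np.cos(angle) * window     # (bins, n_fft)
    sin_kernels = -np.sin(angle) * window    # minus sign matches exp(-j * angle)
    # Strided frames play the role of the convolution's sliding window.
    starts = range(0, len(signal) - n_fft + 1, hop)
    frames = np.stack([signal[s:s + n_fft] for s in starts])
    return frames @ cos_kernels.T, frames @ sin_kernels.T

rng = np.random.default_rng(0)
x = rng.standard_normal(1024)
real, imag = stft_real(x)

# Reference: windowed complex FFT of the same frames.
ref = np.stack([
    np.fft.rfft(x[s:s + N_FFT] * np.hanning(N_FFT))
    for s in range(0, len(x) - N_FFT + 1, HOP)
])
max_diff = max(np.abs(real - ref.real).max(), np.abs(imag - ref.imag).max())
print(f"max abs diff vs np.fft.rfft: {max_diff:.1e}")
```

Because the kernels are fixed real arrays, the whole transform reduces to operations that ONNX can represent without complex dtypes.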

Net result: end-to-end parity vs PyTorch fp32 at 1.6 × 10⁻⁴ max abs diff across the entire output, and the model runs in onnxruntime CPU out of the box at 1.31× the speed of PyTorch CPU.
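
Patch 6 in the list above swaps a fused kernel for primitive operations. As an illustration of what such a decomposition computes, here is multi-head self-attention written in numpy with only matrix products and a softmax. The weight layout (a stacked (3E, E) input projection and an (E, E) output projection) follows the parameters of nn.MultiheadAttention. The function name and the random weights are placeholders, and this is not the replacement forward used in the export.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mha_plain(x, w_in, b_in, w_out, b_out, num_heads):
    """Self-attention over x (batch, time, embed) using matmul and softmax only."""
    batch, time, embed = x.shape
    head_dim = embed // num_heads
    # One stacked projection, then split into queries, keys, and values.
    q, k, v = np.split(x @ w_in.T + b_in, 3, axis=-1)

    def split_heads(z):
        # (batch, time, embed) -> (batch, heads, time, head_dim)
        return z.reshape(batch, time, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)   # (batch, heads, time, time)
    context = softmax(scores) @ v                              # (batch, heads, time, head_dim)
    merged = context.transpose(0, 2, 1, 3).reshape(batch, time, embed)
    return merged @ w_out.T + b_out

rng = np.random.default_rng(0)
E, H = 16, 4
x = rng.standard_normal((2, 10, E))
w_in, b_in = rng.standard_normal((3 * E, E)) * 0.1, np.zeros(3 * E)
w_out, b_out = rng.standard_normal((E, E)) * 0.1, np.zeros(E)
y = mha_plain(x, w_in, b_in, w_out, b_out, H)
print(y.shape)  # (2, 10, 16)
```

Every step here maps to a standard ONNX operator (MatMul, Reshape, Transpose, Softmax), which is why a decomposition of this kind exports where the fused kernel does not.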


Tooling: demucs-onnx Python package

This model can also be run (and re-exported) via the open-source demucs-onnx Python package on PyPI. It auto-downloads from this repo on first use, so you don't have to clone or wrangle file paths.

pip install demucs-onnx

# Single specialist (this repo)
demucs-onnx separate song.mp3 stems/ --stem drums

# Or via the Python API
python -c "from demucs_onnx import separate_stem; \
  drums = separate_stem('song.mp3', 'drums')"

demucs-onnx is also the canonical tool for exporting htdemucs to ONNX yourself. It bundles all four blocker fixes (complex STFT, fractions.Fraction, random.randrange, aten::_native_multi_head_attention) so vanilla torch.onnx.export works on your own checkpoints.

pip install "demucs-onnx[export]"
demucs-onnx export htdemucs_ft drums.onnx --stem drums

Related work

The full StemSplit ONNX project:

| Day | Deliverable | Status |
|---|---|---|
| 1 | htdemucs-ft-drums-onnx (this repo) | ✅ live |
| 2 | htdemucs-ft-{bass,other,vocals}-onnx + htdemucs-ft-onnx bag | ✅ live |
| 3 | CoreML EP profiling, INT8 quantization, mobile demos | 📋 planned |

Sibling ONNX repos from the same export:

| Repo | Stem | Use when |
|---|---|---|
| htdemucs-ft-onnx | all 4 | Full 4-stem separation in one repo |
| htdemucs-ft-drums-onnx (this) | drums | Drum extraction, beat transcription |
| htdemucs-ft-bass-onnx | bass | Bassline transcription, mix rebalancing |
| htdemucs-ft-other-onnx | other | Karaoke instrumentals, sample-flipping |
| htdemucs-ft-vocals-onnx | vocals | #1 open-source vocal SDR: karaoke, acapella, vocal removal |

PyTorch versions for HF Inference Endpoints: htdemucs-ft-pytorch and its 4 sibling specialist repos.

Full benchmark across every popular open-source separator: StemSplitio/stem-separation-benchmark-2026.


Skip the infrastructure: use the StemSplit API

Don't want to ship a 316 MB model in your app, manage a GPU pool, or write overlap-add chunking? Use the StemSplit API instead: same model under the hood, hosted for you, with credits and a dashboard.

Or try the no-code tools that use this same model family:


Files in this repo

| File | Size | Purpose |
|---|---|---|
| htdemucs_ft_drums.onnx | 316 MB | The exported model. Opset 17. Passes onnx.checker.check_model. |
| infer.py | 6 KB | Pure numpy + onnxruntime reference implementation. ~200 lines, no torch. |
| requirements.txt | <1 KB | onnxruntime, numpy, soundfile. |
| README.md | | This file. |

License & attribution

This repo is MIT-licensed, matching the original HT-Demucs.

Please cite the original authors if you use this model in research:

@inproceedings{rouard2023hybrid,
  title     = {Hybrid Transformers for Music Source Separation},
  author    = {Rouard, Simon and Massa, Francisco and D{\'e}fossez, Alexandre},
  booktitle = {ICASSP},
  year      = {2023}
}

And if you use the benchmark or this ONNX packaging:

@misc{stemsplit_benchmark_2026,
  title  = {StemSplit Stem-Separation Benchmark 2026},
  author = {StemSplit},
  year   = {2026},
  url    = {https://huggingface.co/datasets/StemSplitio/stem-separation-benchmark-2026}
}
  • Original PyTorch model: facebookresearch/demucs
  • ONNX export, parity verification, and packaging by StemSplit
  • Search keywords: htdemucs onnx, demucs onnx, demucs ios, demucs android, music source separation onnx, stem separation mobile, drum extraction onnx