omniASR-CTC-300M-pruned-21L

A depth-pruned variant of facebook/omniASR-CTC-300M. The top 3 of 24 Transformer encoder layers are removed, leaving a 21-layer encoder. The final LayerNorm and the CTC projection head are kept unchanged.

| | Original | Pruned (this model) |
|---|---|---|
| Encoder layers | 24 | 21 |
| Parameters | 325.5 M | 287.7 M (−37.8 M, −11.6 %) |
| Encoder compute | 100 % | −12.5 % |

No fine-tuning was applied — the pruned model reuses the original weights directly.
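
The size and compute figures can be cross-checked with a few lines of arithmetic. This is only a sanity check: it assumes that encoder compute scales linearly with the number of layers and, for the per-layer estimate, that all removed parameters sit in the three dropped layers.

```python
# Sanity-check the size and compute figures quoted in the comparison above.
full_params_m = 325.5      # millions, original model
pruned_params_m = 287.7    # millions, this model
full_layers, pruned_layers = 24, 21

removed_m = full_params_m - pruned_params_m
removed_pct = 100 * removed_m / full_params_m

# Assumption: encoder compute scales linearly with the number of layers.
compute_saved_pct = 100 * (full_layers - pruned_layers) / full_layers

# Assumption: every removed parameter belongs to one of the dropped layers.
per_layer_m = removed_m / (full_layers - pruned_layers)

print(f"removed: {removed_m:.1f} M ({removed_pct:.1f} %)")
print(f"encoder compute saved: {compute_saved_pct:.1f} %")
print(f"approx. parameters per encoder layer: {per_layer_m:.1f} M")
```

The parameter saving (11.6 %) is smaller than the layer share (12.5 %) because components outside the encoder stack, such as the CTC head, are untouched by the pruning.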

Why this works

The cut point was chosen empirically with a logit-lens analysis: the CTC head was applied to the output of every intermediate encoder layer, and the resulting transcription quality was measured per exit layer. Layer indices are 0-based (L0–L23). The CTC prediction forms almost entirely in layers L16–L18 and saturates from ~L18 onward; the top layers (L21–L23) contribute essentially nothing to the transcription. Removing them is therefore near-lossless.
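
The logit-lens idea can be sketched independently of any framework. The function below is a minimal illustration, not the code used for this model: `logit_lens` and its argument names are hypothetical, and applying the final LayerNorm before the head at every exit is an assumption.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def logit_lens(
    layers: Sequence[Callable[[T], T]],
    final_norm: Callable[[T], T],
    ctc_head: Callable[[T], T],
    features: T,
) -> List[T]:
    """Return the CTC head's output after every encoder layer.

    Entry i holds the logits obtained by exiting after layer i (0-based).
    """
    per_layer_logits = []
    hidden = features
    for layer in layers:
        hidden = layer(hidden)
        # Assumption: the final LayerNorm is applied before the head at every
        # exit, mirroring what the full model does after its last layer.
        per_layer_logits.append(ctc_head(final_norm(hidden)))
    return per_layer_logits


# Toy stand-ins so the sketch runs without a model: each "layer" adds 1,
# the "norm" is the identity, and the "head" doubles its input.
toy_layers = [lambda h: h + 1] * 3
toy_logits = logit_lens(toy_layers, lambda h: h, lambda h: 2 * h, 0)
print(toy_logits)  # [2, 4, 6]
```

With a real model, the callables would be the encoder layers, the final LayerNorm and the CTC projection; decoding each entry and scoring it against a reference gives the per-layer error curve.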

Validation

Single utterance (German, vs. whisper-large-v3 reference, normalized: lowercased, punctuation removed):

| Exit layer | L17 | L18 | L20 | L23 (full) |
|---|---|---|---|---|
| WER | 45 % | 20 % | 20 % | 20 % |

The 21-layer model (exit layer L20) reproduces the full model's WER exactly (20.5 %, listed as 20 % in the table above).
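
For reference, a normalized WER of this kind can be computed as below. This is a minimal sketch, not the evaluation script behind the numbers above: the function names are hypothetical, and stripping only ASCII punctuation is an assumption about what "punctuation removed" means.

```python
import string
from typing import List, Sequence


def normalize(text: str) -> List[str]:
    """Lowercase, strip ASCII punctuation, and split on whitespace."""
    table = str.maketrans("", "", string.punctuation)
    return text.lower().translate(table).split()


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,           # deletion
                            curr[j - 1] + 1,       # insertion
                            prev[j - 1] + cost))   # substitution or match
        prev = curr
    return prev[-1]


def wer(reference: str, hypothesis: str) -> float:
    """Word error rate of `hypothesis` against `reference`, after normalization."""
    ref, hyp = normalize(reference), normalize(hypothesis)
    if not ref:
        raise ValueError("reference is empty after normalization")
    return edit_distance(ref, hyp) / len(ref)


print(wer("Das ist ein Test.", "das is ein test"))  # 0.25
```

Calling `edit_distance` on the character sequences instead of the word lists, and dividing by the reference length in characters, gives the CER used for the cross-lingual results.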

Cross-lingual (FLEURS, 6 languages, character error rate vs. gold transcription): mean CER per exit layer saturates at ~L18–L20 for every language tested (DE, EN, ES, FR, RU, ZH). Exiting at L20, i.e. dropping L21–L23 as this model does, costs ≈ +1 CER point on average. The saturation layer is language-independent and also robust to additive noise (SNR 20/10/0 dB).
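
The noise conditions can be reproduced by scaling a noise signal to a target SNR before adding it to the clean waveform. The sketch below is an illustration under assumptions: the helper names are hypothetical, and white Gaussian noise on a synthetic tone stands in for whatever noise and speech were actually used.

```python
import math
import random
from typing import List, Sequence


def power(samples: Sequence[float]) -> float:
    """Mean squared amplitude."""
    return sum(s * s for s in samples) / len(samples)


def mix_at_snr(signal: Sequence[float], noise: Sequence[float],
               snr_db: float) -> List[float]:
    """Add `noise` to `signal`, scaled so the mixture has the requested SNR."""
    if len(signal) != len(noise):
        raise ValueError("signal and noise must have the same length")
    noise_power = power(noise)
    if noise_power == 0.0:
        raise ValueError("noise is silent")
    target_noise_power = power(signal) / (10 ** (snr_db / 10))
    scale = math.sqrt(target_noise_power / noise_power)
    return [s + scale * n for s, n in zip(signal, noise)]


def measured_snr_db(signal: Sequence[float], noisy: Sequence[float]) -> float:
    """SNR of `noisy` relative to the clean `signal`, in dB."""
    residual = [y - s for s, y in zip(signal, noisy)]
    return 10 * math.log10(power(signal) / power(residual))


# Assumption: one second of a 440 Hz tone at 16 kHz plus white Gaussian noise.
rng = random.Random(0)
sr = 16000
signal = [math.sin(2 * math.pi * 440 * t / sr) for t in range(sr)]
noise = [rng.gauss(0.0, 1.0) for _ in range(sr)]

for snr_db in (20, 10, 0):
    noisy = mix_at_snr(signal, noise, snr_db)
    print(snr_db, round(measured_snr_db(signal, noisy), 3))
```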

Analysis

Per-exit-layer WER (single utterance, clean + additive noise). The CTC prediction forms in L16–L18 and saturates from ~L18; on this utterance, layers L19–L23 add nothing measurable. The saturation point is the same under noise (SNR 20/10/0 dB).

[Figure: WER per exit layer]
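
Given a per-layer error curve, the saturation layer can be read off programmatically. The helper below is a hypothetical sketch of one possible criterion (the earliest exit from which every deeper measured layer stays within a tolerance of the full model), applied to the single-utterance WER values as printed in the Validation section.

```python
from typing import Mapping


def earliest_saturated_layer(error_by_layer: Mapping[int, float],
                             tolerance: float = 0.0) -> int:
    """Earliest exit layer from which the error stays within `tolerance`
    (in absolute points) of the deepest measured layer."""
    layers = sorted(error_by_layer)
    full_error = error_by_layer[layers[-1]]
    saturated = layers[-1]
    # Walk from the top down and stop at the first layer that is too far off.
    for layer in reversed(layers):
        if error_by_layer[layer] - full_error > tolerance:
            break
        saturated = layer
    return saturated


# Single-utterance WER (%) per exit layer, as printed in the Validation section.
wer_by_layer = {17: 45.0, 18: 20.0, 20: 20.0, 23: 20.0}
print(earliest_saturated_layer(wer_by_layer))  # 18
```

Only measured layers are considered, so the result is as coarse as the set of exits that were evaluated.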

Cross-lingual validation on FLEURS (6 languages, CER vs. gold). Every language saturates at the same depth (~L18–L20); only the absolute level differs. This makes the prune point language-independent.

[Figure: CER per exit layer on FLEURS]

Logit-lens decoding per layer. Greedy CTC transcription read off each encoder layer, colored by entropy, with per-layer WER vs. a whisper-large-v3 reference (normalized). Gibberish until L15, readable German from L18 on.

[Figure: Logit-lens decoded transcription per layer]
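
The entropy used for coloring can be derived from the per-frame CTC logits. The sketch below assumes the Shannon entropy (in bits) of the softmax distribution of a single frame; the exact definition and any aggregation used for the figure are not stated here, so treat the details as assumptions.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one frame's logits."""
    peak = max(logits)  # subtract the maximum to avoid overflow in exp
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy_bits(logits: Sequence[float]) -> float:
    """Shannon entropy, in bits, of the softmax distribution of one frame."""
    return -sum(p * math.log2(p) for p in softmax(logits) if p > 0.0)


uniform_frame = [0.0, 0.0, 0.0, 0.0]    # maximally uncertain over 4 symbols
confident_frame = [20.0, 0.0, 0.0, 0.0]  # one symbol dominates

print(entropy_bits(uniform_frame))          # 2.0
print(entropy_bits(confident_frame) < 0.01)  # True
```

Low entropy means the layer is confident about a frame; a layer that still outputs gibberish typically shows high entropy on most frames.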

Usage

Requires fairseq2 and omnilingual-asr.

import dataclasses, torch, torchaudio
from huggingface_hub import hf_hub_download
from fairseq2 import init_fairseq2
from fairseq2.nn import BatchLayout
from fairseq2.models.wav2vec2.asr import create_wav2vec2_asr_model
from fairseq2.data.tokenizers.sentencepiece import load_sentencepiece_model
from omnilingual_asr.models.wav2vec2_asr.config import get_config, Wav2Vec2AsrConfig

ctx = init_fairseq2()

# Build a 21-layer encoder config from the base "300m" arch
cfg = get_config(ctx, Wav2Vec2AsrConfig, "300m")
cfg = dataclasses.replace(cfg, encoder_config=dataclasses.replace(
    cfg.encoder_config, num_encoder_layers=21, layer_drop_p=0.0))
model = create_wav2vec2_asr_model(cfg).eval()

# Load the pruned weights
ckpt = hf_hub_download("ChipCracker/omniASR-CTC-300M-pruned-21L",
                       "omniASR-CTC-300M-pruned21L.pt")
model.load_state_dict(torch.load(ckpt, map_location="cpu")["model"])

# Char tokenizer (CTC blank index = 0)
tok_path = hf_hub_download("ChipCracker/omniASR-CTC-300M-pruned-21L",
                           "omniASR_tokenizer.model")
sp = load_sentencepiece_model(tok_path)

# Transcribe a wav file (downmixed to mono and resampled to 16 kHz if needed)
wav, sr = torchaudio.load("audio.wav")
audio = wav.mean(0) if wav.dim() > 1 else wav.squeeze(0)
if sr != 16000:
    audio = torchaudio.transforms.Resample(sr, 16000)(audio)

x = audio.view(1, -1)
with torch.inference_mode():
    logits, _ = model(x, BatchLayout.of(x))
pred = logits[0].argmax(-1)
keep = torch.ones_like(pred, dtype=torch.bool)
keep[1:] = pred[1:] != pred[:-1]          # CTC collapse
ids = pred[keep]
ids = ids[ids != 0]                        # drop blank
print("".join(sp.index_to_token(int(i)) for i in ids))

How it was produced

# 1. load the full 24-layer model, drop encoder.layers.{21,22,23}
# 2. build a 21-layer config, load the remaining weights strict
# 3. verify per-layer logit-lens WER/CER (single + FLEURS multilingual)
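
Step 1 amounts to filtering the checkpoint's state dict by layer index. The sketch below illustrates this on a toy dictionary; it is not the contents of prune_model.py. The `encoder.layers.N.` prefix follows the step list above, while the remaining key names and the helper `prune_state_dict` are hypothetical.

```python
import re
from typing import Any, Dict

# Matches keys such as "encoder.layers.21.<...>" and captures the layer index.
LAYER_KEY = re.compile(r"(?:^|\.)encoder\.layers\.(\d+)\.")


def prune_state_dict(state: Dict[str, Any], num_keep: int) -> Dict[str, Any]:
    """Drop all entries that belong to encoder layers with index >= num_keep."""
    pruned = {}
    for key, value in state.items():
        match = LAYER_KEY.search(key)
        if match and int(match.group(1)) >= num_keep:
            continue  # belongs to a dropped top layer
        pruned[key] = value
    return pruned


# Toy checkpoint with placeholder values instead of tensors.
# Key names other than the "encoder.layers.N." prefix are made up.
full = {f"encoder.layers.{i}.ffn.weight": i for i in range(24)}
full["encoder.layer_norm.weight"] = "final-norm"
full["final_proj.weight"] = "ctc-head"

pruned = prune_state_dict(full, num_keep=21)
print(len(full), "->", len(pruned))  # 26 -> 23
```

Because the kept layers retain their indices 0 to 20, the filtered dictionary can be loaded into a 21-layer model with `load_state_dict(..., strict=True)`, which would fail if any key were missing or left over.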

The full analysis & pruning code lives in the omni-viz experiment (pruning_analysis.py, pruning_fleurs.py, prune_model.py).

Limitations

  • Derived, not retrained. A small early-exit head fine-tuned at layer 20 would likely match or beat this; here the original head is reused as-is.
  • Validation scale. Saturation was confirmed on one utterance + a 6-language × 5-sentence FLEURS subset (CER). Evaluate on your target languages/domain before production use.
  • Output style. Character-level CTC output keeps the base model's quirks (e.g. "das is" for "das ist"); CER is higher for scripts without word boundaries (e.g. Chinese).
  • Inherits all behavior and biases of the base model.

License & attribution

Apache-2.0, inherited from the base model facebook/omniASR-CTC-300M (Omnilingual ASR, Meta AI). Please cite the original work:

@misc{omnilingual_asr_2025,
  title  = {Omnilingual ASR},
  author = {Meta AI},
  year   = {2025},
  url    = {https://huggingface.co/facebook/omniASR-CTC-300M}
}