WavLM + LoRA for Speech Disfluency (Stuttering) Detection

Model Description

This model is a parameter-efficient speech disfluency detector fine-tuned on the SEP-28k dataset. It classifies a speech clip as one of four stuttering event types or as fluent.

By applying LoRA (Low-Rank Adaptation) to the microsoft/wavlm-base-plus backbone, this model trains only ~838K parameters (0.88% of the total model) while the backbone weights stay frozen. This substantially reduces training compute and memory and keeps the adapter small, while still reaching the results reported below. That makes the approach well suited to resource-constrained environments.

  • Base Model: microsoft/wavlm-base-plus
  • Task: Audio Classification (Speech Disfluency)
  • Language: English
  • Classes: 5 (Prolongation, Repetition, Block, Interjection, Fluent)
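LoRA freezes each targeted weight matrix W (d_out × d_in) and learns a low-rank update BA, where A is r × d_in and B is d_out × r. Each adapted matrix therefore adds only r·(d_in + d_out) trainable parameters. The sketch below does that arithmetic for a hypothetical configuration: rank 8 on the query and value projections of WavLM Base+'s 12 transformer layers, hidden size 768. This card does not state the actual rank or target modules, and the reported ~838K trainable parameters may also include the newly trained classification head. The sketch also back-computes the total model size implied by the quoted figures.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters added by one LoRA adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * d_in + d_out * rank

# Hypothetical configuration (NOT the published one): rank 8 on the query and
# value projections of all 12 transformer layers, hidden size 768.
hidden, layers, rank = 768, 12, 8
per_layer = 2 * lora_param_count(hidden, hidden, rank)  # query + value projections
lora_total = layers * per_layer
print(lora_total)  # 294912

# Total model size implied by "~838K trainable parameters = 0.88% of the model"
trainable, fraction = 838_000, 0.0088
print(round(trainable / fraction / 1e6, 1))  # 95.2 (million parameters, approximate)
```

The gap between this hypothetical count and ~838K is expected. A higher rank, more target modules, or the classification head would account for the difference.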

Intended Uses & Limitations

  • Intended Use: Detecting and analyzing stuttering patterns in podcast audio or speech recordings.
  • Limitations: The model performs worst on the Block class (silent pauses or blocks), a known challenge for audio-only disfluency detection. Because it was trained on SEP-28k (mostly English podcasts), it may not generalize well to other languages or to audio recorded under very different conditions, such as clean studio recordings without background noise.

Training Data

The model was fine-tuned on the SEP-28k dataset, which contains about 28,000 three-second clips extracted from public podcasts featuring people who stutter. Each clip is annotated for stuttering events or as fluent. Each clip's label was determined by majority vote among multiple annotators.
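Majority voting can be sketched as follows. The sketch makes two simplifying assumptions: each annotator assigns exactly one of the five labels to a clip, and clips without a strict majority are discarded. In practice, SEP-28k annotators can mark several event types per clip. The exact aggregation rule used for this model, including whether and how sound and word repetitions were merged into the single Repetition class, is not documented here. The function name `majority_label` is hypothetical.

```python
from collections import Counter
from typing import Optional

def majority_label(votes: list[str], num_annotators: int = 3) -> Optional[str]:
    """Hypothetical helper: return the label chosen by a strict majority of annotators, else None."""
    if not votes:
        return None
    label, count = Counter(votes).most_common(1)[0]
    # Strict majority: more than half of the annotators must agree
    return label if count > num_annotators / 2 else None

print(majority_label(["Block", "Block", "Fluent"]))        # Block
print(majority_label(["Block", "Fluent", "Repetition"]))  # None (no majority, clip dropped)
```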

Performance & Evaluation Results

The model was evaluated on a held-out test set from SEP-28k.

Overall Metrics:

  • Accuracy: 65.82%
  • Weighted F1-Score: 66.98%

Per-Class F1-Score Breakdown:

  • Interjection: 0.7569
  • Fluent: 0.7591
  • Repetition: 0.5858
  • Prolongation: 0.5433
  • Block: 0.3587
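Weighted F1 averages the per-class F1 scores, weighting each class by its number of true examples (support) in the test set. It therefore cannot be reproduced from the per-class scores above without the test-set class counts, which are not reported here. The following minimal pure-Python sketch shows the computation on a made-up four-example set. The labels are illustrative only, not model outputs. scikit-learn's `f1_score(y_true, y_pred, average="weighted")` computes the same quantity.

```python
from collections import Counter

def f1_per_class(y_true, y_pred, label):
    """F1 for one class, treating it as positive and all others as negative."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
    if tp == 0:
        return 0.0  # also avoids division by zero when the class is never predicted
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

def weighted_f1(y_true, y_pred):
    """Average per-class F1, weighted by each class's share of the true labels."""
    support = Counter(y_true)
    total = len(y_true)
    return sum(support[c] / total * f1_per_class(y_true, y_pred, c) for c in support)

# Made-up toy data for illustration only
y_true = ["Fluent", "Fluent", "Block", "Repetition"]
y_pred = ["Fluent", "Block", "Block", "Fluent"]
print(round(weighted_f1(y_true, y_pred), 4))  # 0.4167
```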

How to Use

Because this repository contains only a PEFT LoRA adapter, you need to load the base model first and then apply the adapter weights on top of it. Here is a quick-start example:

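import torch
import torchaudio
from peft import PeftModel
from transformers import AutoConfig, AutoFeatureExtractor, WavLMForSequenceClassification

# 1. Configuration
base_model_id = "microsoft/wavlm-base-plus"
peft_model_id = "pmootr/stuttering-detection-wavlm-lora"  # Hub ID of this adapter
num_classes = 5
target_sr = 16_000  # WavLM expects 16 kHz input

# 2. Load the feature extractor and the base model
processor = AutoFeatureExtractor.from_pretrained(base_model_id)
config = AutoConfig.from_pretrained(base_model_id, num_labels=num_classes)
base_model = WavLMForSequenceClassification.from_pretrained(base_model_id, config=config, ignore_mismatched_sizes=True)

# 3. Apply LoRA Weights
model = PeftModel.from_pretrained(base_model, peft_model_id)
model.eval()

# 4. Inference (Example)
waveform, sample_rate = torchaudio.load("path_to_audio.wav")  # shape: (channels, samples)
waveform = waveform.mean(dim=0)  # downmix to mono
if sample_rate != target_sr:
    waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=target_sr)
inputs = processor(waveform.numpy(), sampling_rate=target_sr, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
    predictions = torch.argmax(logits, dim=-1)
# The mapping from class ID to label name is not listed on this card
print(f"Predicted Class ID: {predictions.item()}")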
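
The quick-start code scores a whole file as a single input. For long recordings such as full podcast episodes, one option is to slide a fixed-length window over the audio and classify each window separately. The helper below is a hypothetical sketch. The name `window_bounds`, the 3-second window, and the 1.5-second hop are all assumptions, with the window length chosen to match SEP-28k's 3-second clips. It only computes sample indices, so it works with any 1-D waveform tensor or array.

```python
def window_bounds(num_samples: int, sample_rate: int = 16_000,
                  window_s: float = 3.0, hop_s: float = 1.5) -> list[tuple[int, int]]:
    """Hypothetical helper: (start, end) sample indices of overlapping fixed-length windows."""
    win = int(window_s * sample_rate)
    hop = int(hop_s * sample_rate)
    if num_samples <= win:
        return [(0, num_samples)]  # short clip: score it as a single window
    bounds = []
    start = 0
    while start + win <= num_samples:
        bounds.append((start, start + win))
        start += hop
    # Add a final window aligned to the end so trailing audio is not dropped
    if bounds[-1][1] < num_samples:
        bounds.append((num_samples - win, num_samples))
    return bounds

print(window_bounds(160_000))  # 10 s at 16 kHz -> 6 windows, the last ending at sample 160000
```

Each slice `waveform[..., start:end]` of a mono 16 kHz waveform can then be passed through `processor` and `model` as in the quick-start code, yielding one predicted class per window. Because the windows overlap, a single disfluency may be reported more than once.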

Adapter repository: pmootr/stuttering-detection-wavlm-lora (base model: microsoft/wavlm-base-plus)