TTS + Comb Artifact Detectors

Three lightweight CNN-based binary classifiers that detect waveform-domain artifacts in TTS-generated audio. These models identify vocoder artifacts (metallic resonance, buzzing) and comb-filtering artifacts that emerge from neural vocoders such as BigVGAN/HiFi-GAN.

All models operate on 16 kHz mono waveforms and output an artifact score in [0, 1]: 0 = clean, 1 = artifact detected.

Models

| Model | Architecture | Parameters | Val Accuracy | Val F1 | File |
| --- | --- | --- | --- | --- | --- |
| STFT Classifier | Multi-resolution STFT + 2D CNN | 1,688,833 | 99.88% | 99.87% | best_stft_classifier.pt |
| Mel CNN | Mel spectrogram + 2D CNN | 1,012,417 | 99.75% | 99.75% | best_mel_classifier.pt |
| Waveform 1D | Raw waveform 1D CNN | 1,898,753 | 97.44% | 97.39% | best_waveform_1d.pt |

Per-Source Validation Accuracy

Each model was validated on three source types:

| Model | Predicted (vocoder artifacts) | Real (clean audio) | Comb (synthetic comb filter) |
| --- | --- | --- | --- |
| STFT Classifier | 100% | 100% | 99.0% |
| Mel CNN | 100% | 100% | 98.0% |
| Waveform 1D | 99.8% | 99.4% | 82.5% |

Quick Start

Installation

pip install torch torchaudio soundfile numpy

Python API

from artifact_detector import load_all_models, score_file, score_file_all

# Load all 3 models
models = load_all_models(".", device="cuda")

# Score a single file with all models
scores = score_file_all(models, "my_tts_output.wav")
for name, score in scores.items():
    print(f"  {name}: {score:.4f}")
# → stft_classifier: 0.9823
# → waveform_1d: 0.8901
# → mel_classifier: 0.0012

# Ensemble average
import numpy as np
ensemble = np.mean(list(scores.values()))
print(f"  ensemble: {ensemble:.4f}")

# Single model
from artifact_detector import load_model, score_file
model = load_model("best_stft_classifier.pt", device="cuda")
score = score_file(model, "my_tts_output.wav")

Command Line

# Score a file with all models
python artifact_detector.py --model-dir . --input audio.wav

# Score a directory of WAV files
python artifact_detector.py --model-dir . --input /path/to/wavs/ --ext wav

# Score with a specific model only
python artifact_detector.py --checkpoint best_stft_classifier.pt --input audio.wav

# Custom threshold (default 0.5)
python artifact_detector.py --model-dir . --input audio.wav --threshold 0.3
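The threshold turns a continuous score into a pass/fail decision. The sketch below shows that step for per-model scores and for the ensemble mean. It assumes a score at or above the threshold counts as an artifact; the exact comparison used by artifact_detector.py is not stated here, and `flag_artifacts` is a hypothetical helper, not part of the module.

```python
from statistics import mean

def flag_artifacts(scores, threshold=0.5):
    """Map per-model scores in [0, 1] to booleans, plus an ensemble decision.

    Assumption: a score >= threshold is treated as "artifact detected".
    """
    flags = {name: score >= threshold for name, score in scores.items()}
    ensemble = mean(list(scores.values()))
    flags["ensemble"] = ensemble >= threshold
    return flags, ensemble

# Example scores copied from the Python API section
scores = {"stft_classifier": 0.9823, "waveform_1d": 0.8901, "mel_classifier": 0.0012}
flags, ensemble = flag_artifacts(scores, threshold=0.5)
print(flags)
print(f"ensemble: {ensemble:.4f}")
```

A lower threshold such as 0.3 flags more files, which may suit data filtering where a missed artifact costs more than a discarded clean clip.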

Scoring Raw Tensors

import torch
from artifact_detector import load_model, score_waveform

model = load_model("best_stft_classifier.pt")

# waveform should be a 1-D float32 tensor at 16 kHz
waveform = torch.randn(160000)  # 10 seconds
score = score_waveform(model, waveform)

Using as a Differentiable Loss

All models are fully differentiable and can be used as auxiliary training losses:

model = load_model("best_stft_classifier.pt", device="cuda")

# During training (gradients flow through)
waveform = vocoder(mel)  # [B, 1, T] @ 16kHz
artifact_score = model.artifact_score(waveform)  # [B, 1]
artifact_loss = artifact_score.mean()

total_loss = reconstruction_loss + 0.1 * artifact_loss
total_loss.backward()
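When a detector serves as a loss, it is usually sensible to freeze its weights and keep it in eval mode so that only the vocoder is updated. The sketch below illustrates that pattern with small stand-in modules. `TinyDetector` and the one-layer `vocoder` are hypothetical placeholders, not the released models; only the `artifact_score` method name is taken from the snippet above.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins so the sketch runs on its own; in practice these
# would be the real vocoder and a detector returned by load_model().
vocoder = nn.Sequential(nn.Conv1d(80, 1, kernel_size=3, padding=1), nn.Tanh())

class TinyDetector(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv1d(1, 4, 9, padding=4), nn.AdaptiveAvgPool1d(1),
                                 nn.Flatten(), nn.Linear(4, 1))

    def artifact_score(self, waveform):
        return torch.sigmoid(self.net(waveform))  # [B, 1], values in (0, 1)

detector = TinyDetector().eval()   # eval mode: BatchNorm in a real detector would use running stats
detector.requires_grad_(False)     # freeze detector weights; gradients still reach its input

mel = torch.randn(2, 80, 100)      # [B, n_mels, frames]
waveform = vocoder(mel)            # [B, 1, T]
artifact_loss = detector.artifact_score(waveform).mean()
artifact_loss.backward()

vocoder_has_grad = all(p.grad is not None for p in vocoder.parameters())
detector_has_grad = any(p.grad is not None for p in detector.parameters())
print(vocoder_has_grad, detector_has_grad)
```

Freezing the detector also avoids the vocoder and detector drifting together, which would make the score less meaningful over time.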

Architecture Details

1. Multi-Resolution STFT Classifier (recommended)

The best-performing model. Computes STFT at 4 resolutions (256/512/1024/2048 FFT sizes), processes each through a 5-layer 2D CNN on log-magnitude spectrograms, concatenates the 128-dim feature vectors from each resolution (512 total), and classifies through an MLP head.

  • Input: [B, 1, T] mono waveform at 16 kHz
  • STFT computed internally (fully differentiable)
  • 4 resolution blocks with adaptive average pooling
  • MLP head: 512 → 256 → 1
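The multi-resolution front end can be sketched as follows. Only the four FFT sizes and the use of log-magnitude spectrograms come from the description above; the Hann window, the hop of n_fft // 4, and the epsilon inside the log are assumptions made for illustration.

```python
import torch

FFT_SIZES = (256, 512, 1024, 2048)  # resolutions named in the description above

def multi_res_log_stft(waveform: torch.Tensor, eps: float = 1e-5):
    """Return one log-magnitude spectrogram [B, 1, freq, frames] per FFT size.

    Assumptions: Hann window, hop = n_fft // 4, log(|X| + eps) compression.
    """
    x = waveform.squeeze(1)  # [B, 1, T] -> [B, T]
    specs = []
    for n_fft in FFT_SIZES:
        window = torch.hann_window(n_fft, device=x.device)
        stft = torch.stft(x, n_fft=n_fft, hop_length=n_fft // 4,
                          window=window, return_complex=True)
        specs.append(torch.log(stft.abs() + eps).unsqueeze(1))
    return specs

waveform = torch.randn(2, 1, 16_000, requires_grad=True)  # 1 second at 16 kHz
specs = multi_res_log_stft(waveform)
for n_fft, spec in zip(FFT_SIZES, specs):
    print(n_fft, tuple(spec.shape))
```

Each spectrogram has n_fft // 2 + 1 frequency bins, so the four branches see the same signal at different time/frequency trade-offs. Every operation here is differentiable, which is what allows the classifier to double as a training loss.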

2. Waveform 1D CNN

Direct 1D convolution on raw waveform samples. 6 conv blocks with channels [1→64→128→256→256→512→512], kernel sizes [15, 11, 7, 5, 3, 3], BatchNorm, GELU activation, and MaxPool downsampling. Global average pooling feeds into a 2-layer MLP head.

  • Input: [B, 1, T] mono waveform at 16 kHz
  • No spectral transform needed
  • Largest model (1.9M params) but lowest accuracy
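A rough reconstruction of this layout is sketched below. The channel widths, kernel sizes, BatchNorm, GELU, MaxPool, and global average pooling follow the description above. The "same" padding, the pool size of 4, the GELU inside the head, and the head's hidden width of 128 are assumptions, so this should not be read as the released implementation.

```python
import torch
import torch.nn as nn

CHANNELS = [1, 64, 128, 256, 256, 512, 512]
KERNELS = [15, 11, 7, 5, 3, 3]

class Waveform1DSketch(nn.Module):
    """Illustrative reconstruction, not the released implementation."""

    def __init__(self, hidden: int = 128, pool: int = 4):  # both values are assumptions
        super().__init__()
        blocks = []
        for c_in, c_out, k in zip(CHANNELS[:-1], CHANNELS[1:], KERNELS):
            blocks += [nn.Conv1d(c_in, c_out, k, padding=k // 2),
                       nn.BatchNorm1d(c_out), nn.GELU(), nn.MaxPool1d(pool)]
        self.features = nn.Sequential(*blocks)
        self.head = nn.Sequential(nn.Linear(512, hidden), nn.GELU(), nn.Linear(hidden, 1))

    def forward(self, waveform):                  # [B, 1, T]
        h = self.features(waveform).mean(dim=-1)  # global average pooling -> [B, 512]
        return self.head(h)                       # raw logits, [B, 1]

model = Waveform1DSketch().eval()
n_params = sum(p.numel() for p in model.parameters())
logits = model(torch.randn(2, 1, 16_000))
print(n_params, tuple(logits.shape))
```

With a hidden width of 128, the parameter count of this sketch comes to 1,898,753, the same figure listed in the Models table. That is suggestive of the head size but does not confirm the other assumed details, since padding and pooling do not affect the count.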

3. Mel CNN Classifier

Computes an 80-band mel spectrogram from the raw waveform (differentiable), then runs through a 5-layer 2D CNN with BatchNorm, GELU, and MaxPool2d. Adaptive average pooling feeds into a 2-layer MLP head.

  • Input: [B, 1, T] mono waveform at 16 kHz
  • Mel spectrogram computed internally (n_fft=1024, hop=256, 80 mels)
  • Smallest model (1.0M params), very high accuracy

Training Data

Models were trained on a balanced binary classification task:

  • Clean (label=0): 2,000 real podcast audio clips (oversampled 4× = 8,000)
  • Artifact (label=1): 8,000 total
    • 6,000 TTS-predicted audio decoded through a BigVGAN-based vocoder at 3 noise levels (σ = 0.15, 0.275, 0.4)
    • 2,000 synthetic comb-filtered versions of the clean audio (applied on-the-fly with random delay 0–6ms and wet mix 0.3–0.95)
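For readers unfamiliar with comb filtering, the on-the-fly augmentation could resemble the sketch below. The delay and wet-mix ranges come from the list above. The feedforward form y[n] = x[n] + wet * x[n - D] and uniform sampling of both parameters are assumptions, since the exact filter used in training is not specified.

```python
import numpy as np

def comb_filter(x: np.ndarray, delay_ms: float, wet: float, sr: int = 16_000) -> np.ndarray:
    """Feedforward comb: y[n] = x[n] + wet * x[n - D], D = round(delay_ms * sr / 1000)."""
    d = int(round(delay_ms * sr / 1000.0))
    delayed = np.zeros_like(x)
    delayed[d:] = x[: len(x) - d]   # shift right by d samples (d = 0 leaves x unchanged)
    return x + wet * delayed

rng = np.random.default_rng(0)
clean = rng.standard_normal(16_000).astype(np.float32)  # stand-in for a clean clip
delay_ms = rng.uniform(0.0, 6.0)   # random delay in the stated 0-6 ms range
wet = rng.uniform(0.3, 0.95)       # random wet mix in the stated range
combed = comb_filter(clean, delay_ms, wet)
print(combed.shape)
```

Adding a delayed copy of the signal produces evenly spaced notches in the spectrum, spaced 1 / delay apart in Hz, which is the pattern the detectors learn to recognize.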

Total: 16,000 samples per epoch. 90/10 train/val split, stratified by source type.
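As a consistency check, the overall validation accuracies can be recovered as a weighted average of the per-source accuracies. The sketch assumes the validation split mirrors the 6,000 / 8,000 / 2,000 per-epoch mix (with the clean clips counted after oversampling), which is what a stratified split would suggest.

```python
# Samples per source in one epoch, from the training-data description
counts = {"predicted": 6_000, "real": 8_000, "comb": 2_000}

# Per-source validation accuracy, from the per-source table
per_source = {
    "STFT Classifier": {"predicted": 1.000, "real": 1.000, "comb": 0.990},
    "Mel CNN":         {"predicted": 1.000, "real": 1.000, "comb": 0.980},
    "Waveform 1D":     {"predicted": 0.998, "real": 0.994, "comb": 0.825},
}

total = sum(counts.values())
overall = {
    model: sum(acc[src] * counts[src] for src in counts) / total
    for model, acc in per_source.items()
}
for model, acc in overall.items():
    print(f"{model}: {acc:.4%}")
```

The results (99.8750%, 99.7500%, 97.4375%) agree with the reported 99.88%, 99.75%, and 97.44% after rounding. The check also makes clear that the Waveform 1D shortfall is driven almost entirely by the comb-filter source.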

Training Configuration

  • Optimizer: AdamW (lr=3e-4, weight_decay=1e-4)
  • Scheduler: CosineAnnealingLR
  • Loss: BCEWithLogitsLoss
  • Batch size: 32
  • Max audio length: 10 seconds (160,000 samples)
  • Gradient clipping: max_norm=1.0
  • 30 epochs, best checkpoint selected by validation accuracy
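Put together, the configuration corresponds to a setup along these lines. The small `model` is a hypothetical stand-in, and `T_max=EPOCHS` with one scheduler step per epoch is an assumption; the source names the scheduler but not its arguments.

```python
import torch
import torch.nn as nn

EPOCHS = 30
model = nn.Sequential(nn.Conv1d(1, 8, 15, padding=7), nn.AdaptiveAvgPool1d(1),
                      nn.Flatten(), nn.Linear(8, 1))  # hypothetical stand-in classifier

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)  # T_max assumed
criterion = nn.BCEWithLogitsLoss()

waveforms = torch.randn(32, 1, 16_000)          # batch of 32 (1 s clips keep the demo light)
labels = torch.randint(0, 2, (32, 1)).float()   # 0 = clean, 1 = artifact

optimizer.zero_grad()
loss = criterion(model(waveforms), labels)      # the model outputs raw logits
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()                                # assumed to run once per epoch
print(f"loss={loss.item():.4f}  lr={scheduler.get_last_lr()[0]:.6f}")
```

Since the loss is `BCEWithLogitsLoss`, the network is trained on logits, and the [0, 1] artifact score is presumably the sigmoid of that logit.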

Checkpoint Format

Each .pt file contains:

{
    "model_state_dict": ...,
    "architecture": "stft_classifier",  # or "waveform_1d" or "mel_classifier"
    "epoch": 25,
    "val_acc": 0.9988,
    "val_f1": 0.9987,
    "n_params": 1688833,
    "input_sr": 16000,
}
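The metadata fields make it easy to inspect a checkpoint before building a model. A small sketch, where `summarize_checkpoint` is a hypothetical helper and a literal dict with the documented fields stands in for a loaded file:

```python
def summarize_checkpoint(ckpt: dict) -> str:
    """Build a one-line summary from the metadata keys listed above."""
    return (f"{ckpt['architecture']}: epoch {ckpt['epoch']}, "
            f"val_acc={ckpt['val_acc']:.2%}, val_f1={ckpt['val_f1']:.2%}, "
            f"{ckpt['n_params']:,} params, {ckpt['input_sr']} Hz input")

# In practice the dict would come from something like
#   ckpt = torch.load("best_stft_classifier.pt", map_location="cpu")
# A literal with the documented fields stands in for it here.
ckpt = {"model_state_dict": {}, "architecture": "stft_classifier", "epoch": 25,
        "val_acc": 0.9988, "val_f1": 0.9987, "n_params": 1688833, "input_sr": 16000}
print(summarize_checkpoint(ckpt))
```

Checking `input_sr` against the sample rate of incoming audio is a cheap guard against scoring files at the wrong rate.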

Intended Use

  • Quality filtering of TTS training data
  • Automated QA for text-to-speech pipelines
  • Differentiable auxiliary loss for vocoder fine-tuning
  • Detection of comb-filter artifacts in audio processing chains

Limitations

  • Trained on a specific vocoder architecture (BigVGAN-based). May not generalize to all TTS systems without fine-tuning.
  • Models disagree on some samples: the STFT model is the most reliable overall, while Waveform 1D tends to have a higher false positive rate.
  • Not trained to detect other audio quality issues (clipping, noise, bandwidth limitation).
  • 10-second maximum context window; longer files are truncated.
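One way to work around the 10-second limit is to split a long clip into windows and score each one separately. The sketch below only computes the window boundaries; `window_bounds` is a hypothetical helper, and how per-window scores are combined (maximum, mean, or something else) is left as a choice for the caller.

```python
def window_bounds(n_samples: int, window: int = 160_000, hop: int = 160_000):
    """Return (start, end) sample indices covering a clip in windows of at most `window` samples."""
    if n_samples <= 0:
        return []
    bounds = []
    start = 0
    while start < n_samples:
        bounds.append((start, min(start + window, n_samples)))
        start += hop
    return bounds

# A 25-second clip at 16 kHz splits into two full windows and a 5-second tail
print(window_bounds(25 * 16_000))
```

Each slice `waveform[start:end]` can then be passed to `score_waveform`. Taking the maximum over windows is a reasonable choice when a single bad segment should flag the whole file, though very short tail windows may give less reliable scores.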

License

Apache 2.0
