TTS + Comb Artifact Detectors
Three lightweight CNN-based binary classifiers that detect waveform-domain artifacts in TTS-generated audio. These models identify vocoder artifacts (metallic resonance, buzzing) and comb-filtering artifacts that emerge from neural vocoders such as BigVGAN/HiFi-GAN.
All models operate on 16 kHz mono waveforms and output an artifact score in [0, 1]: 0 = clean, 1 = artifact detected.
Models
| Model | Architecture | Parameters | Val Accuracy | Val F1 | File |
|---|---|---|---|---|---|
| STFT Classifier | Multi-resolution STFT + 2D CNN | 1,688,833 | 99.88% | 99.87% | best_stft_classifier.pt |
| Mel CNN | Mel spectrogram + 2D CNN | 1,012,417 | 99.75% | 99.75% | best_mel_classifier.pt |
| Waveform 1D | Raw waveform 1D CNN | 1,898,753 | 97.44% | 97.39% | best_waveform_1d.pt |
Per-Source Validation Accuracy
Each model was validated on three source types:
| Model | TTS-predicted (vocoder artifacts) | Real (clean audio) | Comb (synthetic comb filter) |
|---|---|---|---|
| STFT Classifier | 100% | 100% | 99.0% |
| Mel CNN | 100% | 100% | 98.0% |
| Waveform 1D | 99.8% | 99.4% | 82.5% |
Quick Start
Installation
```bash
pip install torch torchaudio soundfile numpy
```
Python API
```python
import numpy as np
from artifact_detector import load_all_models, load_model, score_file, score_file_all

# Load all 3 models from the current directory
models = load_all_models(".", device="cuda")

# Score a single file with all models
scores = score_file_all(models, "my_tts_output.wav")
for name, score in scores.items():
    print(f" {name}: {score:.4f}")
# → stft_classifier: 0.9823
# → waveform_1d: 0.8901
# → mel_classifier: 0.0012

# Ensemble average of the per-model scores
ensemble = np.mean(list(scores.values()))
print(f" ensemble: {ensemble:.4f}")

# Single model
model = load_model("best_stft_classifier.pt", device="cuda")
score = score_file(model, "my_tts_output.wav")
```
Command Line
```bash
# Score a file with all models
python artifact_detector.py --model-dir . --input audio.wav

# Score a directory of WAV files
python artifact_detector.py --model-dir . --input /path/to/wavs/ --ext wav

# Score with a specific model only
python artifact_detector.py --checkpoint best_stft_classifier.pt --input audio.wav

# Custom threshold (default 0.5)
python artifact_detector.py --model-dir . --input audio.wav --threshold 0.3
```
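The `--threshold` option turns a score into a binary decision. The sketch below applies the same idea in Python to per-model scores in the format returned by `score_file_all`, and adds two simple ensemble rules (mean score and majority vote). The helper `flag_artifacts` is hypothetical, and it assumes a score greater than or equal to the threshold counts as an artifact; the CLI's exact boundary behaviour is not documented here.

```python
from statistics import mean


def flag_artifacts(scores: dict, threshold: float = 0.5) -> dict:
    """Hypothetical helper: threshold per-model scores in [0, 1] (0 = clean, 1 = artifact).

    Assumes score >= threshold means "artifact detected".
    """
    per_model = {name: s >= threshold for name, s in scores.items()}
    ensemble_score = mean(list(scores.values()))
    votes = sum(per_model.values())  # number of models that flag the file
    return {
        "per_model": per_model,
        "ensemble_score": ensemble_score,
        "ensemble_flag": ensemble_score >= threshold,
        "majority_flag": votes > len(scores) / 2,
    }


# Example scores in the format returned by score_file_all
scores = {"stft_classifier": 0.9823, "waveform_1d": 0.8901, "mel_classifier": 0.0012}
result = flag_artifacts(scores, threshold=0.5)
print(result["per_model"])
print(f"ensemble: {result['ensemble_score']:.4f} -> {result['ensemble_flag']}")
```

With these example scores, two of the three models flag the file, and the mean (about 0.62) also exceeds 0.5. A majority vote is less sensitive than the mean to one model scoring far from the others.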
Scoring Raw Tensors
```python
import torch
from artifact_detector import load_model, score_waveform

model = load_model("best_stft_classifier.pt")

# waveform should be a 1-D float32 tensor at 16 kHz
waveform = torch.randn(160000)  # 10 seconds
score = score_waveform(model, waveform)
```
Using as a Differentiable Loss
All models are fully differentiable and can be used as auxiliary training losses:
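```python
from artifact_detector import load_model

model = load_model("best_stft_classifier.pt", device="cuda")
model.eval()                 # keep BatchNorm statistics fixed
model.requires_grad_(False)  # freeze the detector; gradients still flow to its input

# During training (gradients flow through the detector into the vocoder).
# `vocoder`, `mel` and `reconstruction_loss` come from your own training loop.
waveform = vocoder(mel)  # [B, 1, T] @ 16 kHz; resample first if the vocoder runs at another rate
artifact_score = model.artifact_score(waveform)  # [B, 1], in [0, 1]
artifact_loss = artifact_score.mean()
total_loss = reconstruction_loss + 0.1 * artifact_loss
total_loss.backward()
```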
Architecture Details
1. Multi-Resolution STFT Classifier (recommended)
The best-performing model. Computes STFT at 4 resolutions (256/512/1024/2048 FFT sizes), processes each through a 5-layer 2D CNN on log-magnitude spectrograms, concatenates the 128-dim feature vectors from each resolution (512 total), and classifies through an MLP head.
- Input: [B, 1, T] mono waveform at 16 kHz
- STFT computed internally (fully differentiable)
- 4 resolution blocks with adaptive average pooling
- MLP head: 512 → 256 → 1
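The multi-resolution front end can be sketched in NumPy to show what each of the four branches receives. The hop size (n_fft // 4), the Hann window, the absence of center padding, and log(|X| + eps) are assumptions for illustration; the model computes its STFTs internally in PyTorch, and its exact settings are not stated in this card.

```python
from typing import Optional

import numpy as np

FFT_SIZES = (256, 512, 1024, 2048)  # the four resolutions described above


def log_mag_stft(x: np.ndarray, n_fft: int, hop: Optional[int] = None, eps: float = 1e-5) -> np.ndarray:
    """Log-magnitude STFT of a 1-D signal, shaped [n_fft // 2 + 1, n_frames].

    Assumed settings: Hann window, hop = n_fft // 4, no center padding.
    """
    hop = hop or n_fft // 4
    n_frames = 1 + (len(x) - n_fft) // hop
    # Index matrix [n_frames, n_fft]: row i selects samples i*hop .. i*hop + n_fft - 1
    idx = hop * np.arange(n_frames)[:, None] + np.arange(n_fft)[None, :]
    frames = x[idx] * np.hanning(n_fft)
    spec = np.fft.rfft(frames, axis=-1)
    return np.log(np.abs(spec) + eps).T


def multi_resolution_features(x: np.ndarray) -> dict:
    """One log-magnitude spectrogram per FFT size; each would feed its own CNN branch."""
    return {n_fft: log_mag_stft(x, n_fft) for n_fft in FFT_SIZES}


x = np.random.default_rng(0).standard_normal(16000).astype(np.float32)  # 1 s at 16 kHz
for n_fft, spec in multi_resolution_features(x).items():
    print(n_fft, spec.shape)
```

The four spectrograms have different shapes (fine frequency resolution at 2048, fine time resolution at 256). Adaptive average pooling at the end of each branch reduces every branch to a fixed 128-dim vector, so the 4 × 128 = 512-dim concatenation works for any input length.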
2. Waveform 1D CNN
Direct 1D convolution on raw waveform samples. 6 conv blocks with channels [1→64→128→256→256→512→512], kernel sizes [15, 11, 7, 5, 3, 3], BatchNorm, GELU activation, and MaxPool downsampling. Global average pooling feeds into a 2-layer MLP head.
- Input: [B, 1, T] mono waveform at 16 kHz
- No spectral transform needed
- Largest model (1.9M params) but lowest accuracy
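As a sanity check on the description above, the Waveform 1D parameter count in the Models table can be reproduced by hand. This assumes each conv layer has a bias, each BatchNorm has a learnable scale and shift, and the MLP head is 512 → 128 → 1. The hidden width of 128 is not stated in this card; it is inferred because it is the value that makes the total match 1,898,753.

```python
CHANNELS = [1, 64, 128, 256, 256, 512, 512]
KERNELS = [15, 11, 7, 5, 3, 3]
HEAD_HIDDEN = 128  # inferred from the published parameter count, not documented


def conv1d_params(c_in: int, c_out: int, k: int) -> int:
    return c_in * c_out * k + c_out  # weights + bias


def batchnorm_params(c: int) -> int:
    return 2 * c  # scale + shift (running statistics are buffers, not parameters)


def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out  # weights + bias


conv = sum(conv1d_params(ci, co, k) for ci, co, k in zip(CHANNELS, CHANNELS[1:], KERNELS))
bn = sum(batchnorm_params(c) for c in CHANNELS[1:])
head = linear_params(CHANNELS[-1], HEAD_HIDDEN) + linear_params(HEAD_HIDDEN, 1)
total = conv + bn + head
print(f"conv={conv:,} bn={bn:,} head={head:,} total={total:,}")
```

Under these assumptions the last two conv layers (256→512 and 512→512, kernel 3) hold about 1.18M of the 1.9M parameters, which is why this model is the largest even though it has no spectral front end.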
3. Mel CNN Classifier
Computes an 80-band mel spectrogram from the raw waveform (differentiable), then passes it through a 5-layer 2D CNN with BatchNorm, GELU, and MaxPool2d. Adaptive average pooling feeds into a 2-layer MLP head.
- Input: [B, 1, T] mono waveform at 16 kHz
- Mel spectrogram computed internally (n_fft=1024, hop=256, 80 mels)
- Smallest model (1.0M params), very high accuracy
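A NumPy sketch of the mel front end with the stated settings (n_fft=1024, hop=256, 80 mels at 16 kHz). The HTK mel formula, the Hann window, reflect padding of n_fft // 2 on each side, a power spectrogram, and log(mel + 1e-5) are assumptions chosen to resemble common torchaudio defaults; the model's exact choices may differ.

```python
import numpy as np

SR, N_FFT, HOP, N_MELS = 16000, 1024, 256, 80


def hz_to_mel(f):
    return 2595.0 * np.log10(1.0 + np.asarray(f) / 700.0)  # HTK mel scale


def mel_to_hz(m):
    return 700.0 * (10.0 ** (np.asarray(m) / 2595.0) - 1.0)


def mel_filterbank(sr: int = SR, n_fft: int = N_FFT, n_mels: int = N_MELS) -> np.ndarray:
    """Triangular filters shaped [n_mels, n_fft // 2 + 1], spanning 0 Hz to sr / 2."""
    fft_freqs = np.linspace(0.0, sr / 2, n_fft // 2 + 1)
    hz_pts = mel_to_hz(np.linspace(hz_to_mel(0.0), hz_to_mel(sr / 2), n_mels + 2))
    lo, mid, hi = hz_pts[:-2, None], hz_pts[1:-1, None], hz_pts[2:, None]
    rising = (fft_freqs - lo) / (mid - lo)
    falling = (hi - fft_freqs) / (hi - mid)
    return np.maximum(0.0, np.minimum(rising, falling))


def log_mel(x: np.ndarray, sr: int = SR, n_fft: int = N_FFT, hop: int = HOP, eps: float = 1e-5) -> np.ndarray:
    """Log-mel spectrogram of a 1-D signal, shaped [n_mels, 1 + len(x) // hop]."""
    padded = np.pad(x, n_fft // 2, mode="reflect")  # center each frame on t * hop
    n_frames = 1 + (len(padded) - n_fft) // hop
    idx = hop * np.arange(n_frames)[:, None] + np.arange(n_fft)[None, :]
    power = np.abs(np.fft.rfft(padded[idx] * np.hanning(n_fft), axis=-1)) ** 2
    return np.log(mel_filterbank(sr, n_fft) @ power.T + eps)


x = np.random.default_rng(0).standard_normal(16000)  # 1 s at 16 kHz
print(log_mel(x).shape)  # (80, 63)
```

For a full 10-second input (160,000 samples) this gives 1 + 160000 // 256 = 626 frames, which the adaptive pooling then collapses to a fixed size.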
Training Data
Models were trained on a balanced binary classification task:
- Clean (label=0): 2,000 real podcast audio clips (oversampled 4× = 8,000)
- Artifact (label=1): 8,000 total
  - 6,000 TTS-predicted clips decoded through a BigVGAN-based vocoder at 3 noise levels (σ = 0.15, 0.275, 0.4)
  - 2,000 synthetic comb-filtered versions of the clean audio (applied on the fly with a random delay of 0 to 6 ms and a wet mix of 0.3 to 0.95)
Total: 16,000 samples per epoch. 90/10 train/val split, stratified by source type.
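The synthetic comb artifacts can be sketched as a feedforward comb filter that mixes the signal with a delayed copy of itself. The exact formula used in training is not given in this card; the mix `(1 - wet) * x[n] + wet * x[n - d]` and the helper names are assumptions, while the delay and wet-mix ranges come from the list above.

```python
import numpy as np

SR = 16000


def comb_filter(x: np.ndarray, delay_ms: float, wet: float, sr: int = SR) -> np.ndarray:
    """Feedforward comb (assumed form): y[n] = (1 - wet) * x[n] + wet * x[n - d]."""
    d = min(int(round(delay_ms * sr / 1000.0)), len(x))  # delay in samples
    delayed = np.zeros_like(x)
    delayed[d:] = x[: len(x) - d]  # shift right by d samples (d = 0 gives x itself)
    return (1.0 - wet) * x + wet * delayed


def random_comb(x: np.ndarray, rng: np.random.Generator, sr: int = SR) -> np.ndarray:
    """Hypothetical on-the-fly augmentation: delay 0 to 6 ms, wet mix 0.3 to 0.95."""
    return comb_filter(x, delay_ms=rng.uniform(0.0, 6.0), wet=rng.uniform(0.3, 0.95), sr=sr)
```

A delay of d samples produces spectral dips at odd multiples of sr / (2d), which become full nulls when wet = 0.5. At 3 ms (48 samples) the first dip is near 167 Hz and the dips repeat every 333 Hz.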
Training Configuration
- Optimizer: AdamW (lr=3e-4, weight_decay=1e-4)
- Scheduler: CosineAnnealingLR
- Loss: BCEWithLogitsLoss
- Batch size: 32
- Max audio length: 10 seconds (160,000 samples)
- Gradient clipping: max_norm=1.0
- 30 epochs, best checkpoint selected by validation accuracy
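The configuration above maps onto standard PyTorch components. The sketch wires them together for a single step; the tiny `ToyDetector`, the random batch, stepping the scheduler once per epoch, and `T_max=30` (the epoch count) are assumptions for illustration, not the released training script.

```python
import torch
from torch import nn

MAX_SAMPLES = 160_000  # 10 s at 16 kHz
EPOCHS = 30


class ToyDetector(nn.Module):
    """Hypothetical stand-in for one of the three detectors: [B, 1, T] -> logits [B, 1]."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=15, stride=4), nn.BatchNorm1d(16), nn.GELU(),
            nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(16, 1),
        )

    def forward(self, x):
        return self.net(x)


model = ToyDetector()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
loss_fn = nn.BCEWithLogitsLoss()  # takes raw logits; the sigmoid is applied inside the loss


def train_step(waveforms: torch.Tensor, labels: torch.Tensor) -> float:
    """One optimization step on [B, 1, T] waveforms with float labels [B, 1]."""
    model.train()
    waveforms = waveforms[..., :MAX_SAMPLES]  # enforce the 10-second limit
    optimizer.zero_grad()
    loss = loss_fn(model(waveforms), labels)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item()


# One random batch (batch size 32, 1 s clips), then the per-epoch learning-rate step
batch = torch.randn(32, 1, 16_000)
labels = torch.randint(0, 2, (32, 1)).float()
loss = train_step(batch, labels)
scheduler.step()
print(f"loss={loss:.4f} lr={scheduler.get_last_lr()[0]:.2e}")
```

Because BCEWithLogitsLoss works on logits, the score in [0, 1] reported at inference would correspond to a sigmoid applied to the model output.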
Checkpoint Format
Each .pt file contains:
```python
{
    "model_state_dict": ...,
    "architecture": "stft_classifier",  # or "waveform_1d" or "mel_classifier"
    "epoch": 25,
    "val_acc": 0.9988,
    "val_f1": 0.9987,
    "n_params": 1688833,
    "input_sr": 16000,
}
```
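A small helper can validate checkpoint metadata before the weights are used. The keys and architecture names come from the format above; the helper name `summarize_checkpoint` is hypothetical, and in practice the dictionary would be read with `torch.load(path, map_location="cpu")`.

```python
KNOWN_ARCHITECTURES = {"stft_classifier", "waveform_1d", "mel_classifier"}
REQUIRED_KEYS = {"model_state_dict", "architecture", "epoch", "val_acc", "val_f1", "n_params", "input_sr"}


def summarize_checkpoint(ckpt: dict) -> str:
    """Hypothetical helper: check the metadata fields and return a one-line summary."""
    missing = REQUIRED_KEYS - set(ckpt)
    if missing:
        raise KeyError(f"checkpoint is missing keys: {sorted(missing)}")
    if ckpt["architecture"] not in KNOWN_ARCHITECTURES:
        raise ValueError(f"unknown architecture: {ckpt['architecture']!r}")
    if ckpt["input_sr"] != 16000:
        raise ValueError(f"expected 16 kHz input, got {ckpt['input_sr']} Hz")
    return (
        f"{ckpt['architecture']}: epoch {ckpt['epoch']}, "
        f"val_acc={ckpt['val_acc']:.2%}, val_f1={ckpt['val_f1']:.2%}, "
        f"{ckpt['n_params']:,} params"
    )


# Metadata mirroring the example above (state dict left empty for illustration)
example = {
    "model_state_dict": {}, "architecture": "stft_classifier", "epoch": 25,
    "val_acc": 0.9988, "val_f1": 0.9987, "n_params": 1688833, "input_sr": 16000,
}
print(summarize_checkpoint(example))
```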
Intended Use
- Quality filtering of TTS training data
- Automated QA for text-to-speech pipelines
- Differentiable auxiliary loss for vocoder fine-tuning
- Detection of comb-filter artifacts in audio processing chains
Limitations
- Trained on a specific vocoder architecture (BigVGAN-based). May not generalize to all TTS systems without fine-tuning.
- Models disagree on some samples: the STFT model is the most reliable overall, while Waveform 1D tends to have a higher false-positive rate.
- Not trained to detect other audio quality issues (clipping, noise, bandwidth limitation).
- 10-second maximum context window; longer files are truncated.
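Because inputs are truncated to 10 seconds, artifacts later in a long file are never seen. One workaround, not part of the released tooling, is to score overlapping 10-second windows and keep the maximum so that a single bad window flags the file. In the sketch, `score_fn` stands for any function mapping a 1-D 16 kHz window to a score in [0, 1] (for example a wrapper around `score_waveform`); the 50% overlap and max aggregation are assumptions.

```python
from typing import Callable, List

import numpy as np

SR = 16000
WINDOW = 10 * SR  # the models' 10-second context


def window_starts(n_samples: int, window: int = WINDOW, hop: int = WINDOW // 2) -> List[int]:
    """Start offsets of overlapping windows that cover the whole signal."""
    if n_samples <= window:
        return [0]
    starts = list(range(0, n_samples - window + 1, hop))
    if starts[-1] + window < n_samples:
        starts.append(n_samples - window)  # final window aligned to the end of the file
    return starts


def score_long(waveform: np.ndarray, score_fn: Callable[[np.ndarray], float]) -> float:
    """Hypothetical helper: max artifact score over 10-second windows."""
    return max(score_fn(waveform[s : s + WINDOW]) for s in window_starts(len(waveform)))


# Toy demonstration: a "detector" that returns the peak amplitude of its window
peak = lambda w: float(np.abs(w).max())
sig = np.zeros(25 * SR)
sig[390_000] = 1.0  # an "artifact" about 24.4 s into the file
print(score_long(sig, peak), peak(sig[:WINDOW]))  # windowed vs. truncated
```

Truncation alone (`sig[:WINDOW]`) misses the event at 24.4 s, while the last window (starting at 15 s) catches it.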
License
Apache 2.0