LDW-CNet — Learnable Discrete Wavelet CNN for Brain Tumor MRI Classification
LDW-CNet is a convolutional neural network that combines a learnable Discrete Wavelet Transform (DWT) front-end with a residual + SE (squeeze-and-excitation) backbone for 4-class brain MRI classification (glioma, meningioma, no tumour, pituitary). It is trained on the [masoudnickparvar/brain-tumor-mri-dataset](https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset) and is intended as a research artefact demonstrating the value of integrating classical signal-processing priors into deep networks.
Model description
The model has three stages:
- Learnable DWT layer. Decomposes the input into LL/LH/HL/HH subbands using depthwise separable convolutions whose filters are initialised from `db4` but are learnable. A wavelet-constraint loss penalises deviations from `g = QMF(h)` and `‖h‖² = ‖g‖² = 1` during training, so the low-pass filter `h` and high-pass filter `g` stay close to a valid orthogonal wavelet pair.
- SE recalibration of the wavelet subbands, so the network can re-weight subband importance per input.
- Residual CNN backbone with 4 stages, channel attention (SE), and stochastic-depth regularisation whose drop rate increases linearly from 0 to 0.2. The final classifier is a two-layer MLP with LayerNorm and SiLU.
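The filter relationship behind the constraint loss can be sketched in a few lines of NumPy. This is a minimal illustration, not the repo's implementation: it assumes the QMF is the alternating flip `g[n] = (-1)^n h[L-1-n]` and a plain squared-error form for the constraint loss (the card gives neither formula), and `qmf`, `wavelet_constraint_loss` and `dwt2_level` are hypothetical names. A real layer would apply the four kernels per channel as a stride-2 depthwise convolution on tensors rather than with Python loops.

```python
import numpy as np

# db4 analysis low-pass filter (8 taps); the same values PyWavelets returns for
# pywt.Wavelet("db4").dec_lo, hard-coded here to avoid the dependency.
DB4_LO = np.array([
    -0.010597401784997278, 0.032883011666982945, 0.030841381835986965,
    -0.18703481171888114, -0.02798376941698385, 0.6308807679295904,
    0.7148465705525415, 0.23037781330885523,
])


def qmf(h: np.ndarray) -> np.ndarray:
    """High-pass partner via the alternating flip g[n] = (-1)^n h[L-1-n].
    Assumed convention; other references flip the overall sign."""
    signs = (-1.0) ** np.arange(len(h))
    return signs * h[::-1]


def wavelet_constraint_loss(h: np.ndarray, g: np.ndarray) -> float:
    """Hypothetical L_wc: penalise g != QMF(h) and ||h||^2, ||g||^2 != 1."""
    qmf_term = np.sum((g - qmf(h)) ** 2)
    norm_term = (np.sum(h ** 2) - 1.0) ** 2 + (np.sum(g ** 2) - 1.0) ** 2
    return float(qmf_term + norm_term)


def dwt2_level(x: np.ndarray, h: np.ndarray, g: np.ndarray) -> dict:
    """One separable analysis level on a 2D array: 'valid' correlation, stride 2.
    Subband naming (first letter = vertical filter) is one of several conventions."""
    kernels = {"LL": np.outer(h, h), "LH": np.outer(h, g),
               "HL": np.outer(g, h), "HH": np.outer(g, g)}
    k = len(h)
    rows = (x.shape[0] - k) // 2 + 1
    cols = (x.shape[1] - k) // 2 + 1
    out = {}
    for name, w in kernels.items():
        y = np.empty((rows, cols))
        for i in range(rows):
            for j in range(cols):
                y[i, j] = np.sum(x[2 * i:2 * i + k, 2 * j:2 * j + k] * w)
        out[name] = y
    return out
```

At the `db4` initialisation the loss is essentially zero, and it grows as soon as `h` drifts without `g` following; on a constant image only the LL subband responds, because the high-pass filter sums to zero.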
Variants in this repo
| Model | Description |
|---|---|
| PlainCNN | Backbone alone, no wavelet front-end |
| FixedDWT_CNN | Backbone with a fixed db4 wavelet front-end |
| LearnDWT_NoLwc | Learnable DWT, no constraint loss, no SE on subbands |
| LearnDWT_SE_NoLwc | Learnable DWT + SE, no constraint loss |
| LDWCNet_Full | Full model: learnable DWT + SE + wavelet-constraint loss + stochastic depth |
Intended use
- Primary: a research demonstration of wavelet–CNN hybrids for medical imaging.
- Out of scope: clinical decision making. This model is not a medical device and has not been validated for diagnostic use. It must not be used to inform real patient care.
How to load
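```python
import torch
from huggingface_hub import hf_hub_download
# Requires: pip install torch huggingface_hub pywavelets
# Download the checkpoint from the Hub, then load it on CPU.
# weights_only=False unpickles arbitrary objects, so only use it on trusted checkpoints.
ckpt_path = hf_hub_download(
    repo_id='Shanmuk4622/ldwcnet-brain-mri-5-model-model',
    filename='checkpoints/LDWCNet_Full_best.pt',
    repo_type='model',
)
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
# Build the architecture: LDWCNet must be defined first
# (copy the class from this repo's notebooks), e.g.
# from notebook_02_architecture import LDWCNet
model = LDWCNet(in_channels=3, num_classes=4, wavelet_init='db4',
                stochastic_depth_max=0.2)
model.load_state_dict(ckpt['state_dict'])
model.eval()  # inference mode; stochastic depth is normally active only in training
```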
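The headline metrics use 8-way test-time augmentation (TTA), but the card does not list the eight views. A common choice is the dihedral group of the square (four 90° rotations, each with and without a horizontal flip). The NumPy sketch below assumes that choice, averages softmax probabilities over the views (averaging logits is an equally plausible alternative), and uses a hypothetical `predict_logits` callable in place of the model.

```python
from typing import Callable

import numpy as np


def dihedral_views(img: np.ndarray) -> list:
    """The 8 dihedral transforms of an (H, W, C) image: 4 rotations x optional flip.
    Assumed to be what '8-way' TTA means; the card does not list the views."""
    views = []
    for k in range(4):
        rotated = np.rot90(img, k, axes=(0, 1))
        views.append(rotated)
        views.append(rotated[:, ::-1])  # horizontal flip of the rotated view
    return views


def tta_predict(img: np.ndarray,
                predict_logits: Callable[[np.ndarray], np.ndarray]) -> np.ndarray:
    """Average softmax probabilities over the 8 views.
    `predict_logits` is a hypothetical wrapper: (N, H, W, C) batch -> (N, classes) logits."""
    batch = np.stack([np.ascontiguousarray(v) for v in dihedral_views(img)])
    logits = predict_logits(batch)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerically stable softmax
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    return probs.mean(axis=0)
```

With the PyTorch model loaded above, `predict_logits` would convert the `(8, 224, 224, 3)` batch to an `(8, 3, 224, 224)` float tensor with whatever normalisation was used in training (not stated in this card) and return `model(x)` as a NumPy array under `torch.no_grad()`. `np.rot90` preserves the shape only for square inputs, which holds at 224×224.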
Performance — held-out test split
Headline metrics (Test-Time Augmentation, 8-way)
| Metric | Value |
|---|---|
| Accuracy | 0.8950 |
| Balanced Accuracy | 0.8950 |
| Macro F1 | 0.8925 |
| Weighted F1 | 0.8925 |
| MCC | 0.8627 |
| Cohen's κ | 0.8600 |
| Macro ROC-AUC | 0.9749 |
| Macro AP | 0.9506 |
Bootstrap 95% CI on Macro-F1 (n=2000 resamples): [0.8761, 0.9071]
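A percentile bootstrap over test cases is one standard way to obtain an interval like this. The sketch below assumes that method (the card states only the 2000 resamples and the 95% level); `macro_f1` and `bootstrap_ci` are hypothetical helper names, and labels are integer class indices in NumPy arrays.

```python
import numpy as np


def macro_f1(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int) -> float:
    """Unweighted mean of per-class F1 = 2TP / (2TP + FP + FN)."""
    f1s = []
    for c in range(n_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)  # class absent from resample
    return float(np.mean(f1s))


def bootstrap_ci(y_true, y_pred, n_classes=4, n_resamples=2000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for macro-F1, resampling test cases with replacement."""
    rng = np.random.default_rng(seed)
    n = len(y_true)
    scores = np.empty(n_resamples)
    for b in range(n_resamples):
        idx = rng.integers(0, n, size=n)
        scores[b] = macro_f1(y_true[idx], y_pred[idx], n_classes)
    lo, hi = np.quantile(scores, [alpha / 2, 1 - alpha / 2])
    return float(lo), float(hi)
```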
Per-class breakdown — LDWCNet_Full
| Class | Precision | Recall | F1 | Support |
|---|---|---|---|---|
| glioma | 0.9589 | 0.7575 | 0.8464 | 400 |
| meningioma | 0.8760 | 0.8475 | 0.8615 | 400 |
| notumor | 0.8955 | 0.9850 | 0.9381 | 400 |
| pituitary | 0.8665 | 0.9900 | 0.9242 | 400 |
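Because every class has the same support (400 images), several headline metrics can be cross-checked from this table alone. The plain-Python arithmetic below recovers true positives as recall × support and the number of images predicted as each class as TP / precision (so it inherits the table's rounding), then recomputes accuracy, balanced accuracy, macro F1 and Cohen's κ.

```python
# Figures copied from the per-class table above (support = 400 per class).
precision = [0.9589, 0.8760, 0.8955, 0.8665]
recall = [0.7575, 0.8475, 0.9850, 0.9900]
support = 400

tp = [r * support for r in recall]                  # correct predictions per class
predicted = [t / p for t, p in zip(tp, precision)]  # images predicted as each class
n = support * len(recall)                           # 1600 test images

accuracy = sum(tp) / n
balanced_accuracy = sum(recall) / len(recall)       # mean per-class recall
macro_f1 = sum(2 * p * r / (p + r) for p, r in zip(precision, recall)) / len(recall)

# Cohen's kappa: observed agreement vs chance agreement from the
# true-class totals (rows) and predicted-class totals (columns).
p_e = sum(support * c for c in predicted) / n ** 2
kappa = (accuracy - p_e) / (1 - p_e)
```

The results agree with the headline table to within rounding; with balanced classes, accuracy and balanced accuracy necessarily coincide, and chance agreement is 0.25.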
All variants
| Model | Accuracy | Macro F1 | MCC | Macro ROC-AUC |
|---|---|---|---|---|
| LDWCNet_Full | 0.8950 | 0.8925 | 0.8627 | 0.9749 |
| FixedDWT_CNN | 0.8625 | 0.8589 | 0.8204 | 0.9605 |
| LearnDWT_SE_NoLwc | 0.8500 | 0.8460 | 0.8032 | 0.9559 |
| LearnDWT_NoLwc | 0.8469 | 0.8431 | 0.7996 | 0.9521 |
| PlainCNN | 0.8444 | 0.8388 | 0.7971 | 0.9569 |
Efficiency
| Metric | Value |
|---|---|
| Parameters | 5.08M |
| GFLOPs | 3.19 |
| Latency (batch=1) | 2.89 ± 0.10 ms |
| Throughput (batch=32) | 528 img/s |
Measured on the training device (see reports/efficiency.csv).
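Latency and throughput figures of this kind are commonly produced with warm-up calls followed by repeated wall-clock timing. The framework-agnostic sketch below assumes that protocol; the card's actual settings and device are only recorded in reports/efficiency.csv, and `measure_latency` and `throughput` are hypothetical names. On a GPU, the timed callable should end with `torch.cuda.synchronize()`, because CUDA kernels run asynchronously.

```python
import statistics
import time
from typing import Callable


def measure_latency(fn: Callable[[], object], warmup: int = 10, iters: int = 100):
    """Return (mean_ms, std_ms) wall-clock latency of fn() after warm-up calls."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.mean(samples), statistics.stdev(samples)


def throughput(fn_batch: Callable[[], object], batch_size: int,
               warmup: int = 3, iters: int = 50) -> float:
    """Images per second for repeated calls that each process one batch."""
    for _ in range(warmup):
        fn_batch()
    start = time.perf_counter()
    for _ in range(iters):
        fn_batch()
    return batch_size * iters / (time.perf_counter() - start)
```

For the table above, `fn` would run one 224×224 image through the model in eval mode under `torch.no_grad()`, and `fn_batch` a batch of 32.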
Training procedure
- Optimizer: AdamW, base LR 3e-4, weight decay 1e-4 for non-DWT params (0 for DWT)
- Schedule: OneCycleLR, cosine anneal, 10% warm-up
- Loss: class-weighted label-smoothing cross-entropy (smoothing=0.05) + cosine-scheduled wavelet-constraint loss (λ_wc_max=0.5, 10-epoch warm-up, decays to 0)
- Regularisation: Mixup + CutMix at p=0.5 (50/50 split), stochastic depth 0→0.2, EMA (decay 0.9998 with timm-style warm-up ramp), TTA at evaluation
- Augmentation: hflip, vflip, rotate ±20°, RandomBrightnessContrast, ShiftScaleRotate, ElasticTransform, GridDistortion, CLAHE, GaussNoise, CoarseDropout
- Warm-start for LDWCNet_Full: backbone + SE + classifier initialised from FixedDWT_CNN_best.pt; DWT filters frozen for the first 30 epochs, then unfrozen with λ_wc=0.5 to keep them as valid wavelets
- Epochs: 120 | Batch size: 32 per GPU | Image size: 224×224
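Several of the schedules above are given only by their endpoints. The pure-Python sketch below shows one way to realise them. `onecycle_cos_lr` follows the two-phase cosine shape of PyTorch's `OneCycleLR` with its default `div_factor=25` and `final_div_factor=1e4` (the card does not state these). The linear-then-cosine reading of the λ_wc schedule, the number of residual blocks, and the `(1 + t) / (10 + t)` EMA ramp (which may differ from timm's exact formula) are assumptions, and every function name is hypothetical.

```python
import math


def onecycle_cos_lr(step, total_steps, max_lr=3e-4, pct_start=0.1,
                    div_factor=25.0, final_div_factor=1e4):
    """LR at `step`: cosine rise to max_lr over pct_start of training, then cosine decay."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warm_end = float(pct_start * total_steps) - 1

    def cos_anneal(start, end, pct):
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

    if step <= warm_end:
        return cos_anneal(initial_lr, max_lr, step / warm_end)
    return cos_anneal(max_lr, min_lr, (step - warm_end) / (total_steps - 1 - warm_end))


def lambda_wc(epoch, total_epochs=120, lam_max=0.5, warmup_epochs=10):
    """Assumed reading: linear ramp to lam_max, then cosine decay to 0 at the last epoch."""
    if epoch < warmup_epochs:
        return lam_max * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - 1 - warmup_epochs)
    return 0.5 * lam_max * (1 + math.cos(math.pi * progress))


def drop_path_rates(n_blocks, max_rate=0.2):
    """Per-block stochastic-depth drop probabilities rising linearly from 0 to max_rate."""
    if n_blocks == 1:
        return [0.0]
    return [max_rate * i / (n_blocks - 1) for i in range(n_blocks)]


def ema_decay(step, target=0.9998):
    """EMA decay with a warm-up ramp, so early averages track the live weights closely."""
    return min(target, (1 + step) / (10 + step))
```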
Limitations and biases
- Trained on a single public dataset; performance on data from other scanners, populations, or protocols is unknown and likely worse.
- Patient demographics in the dataset are not balanced and not fully documented; the model's behaviour across demographic subgroups has not been audited.
- The model assumes a clean axial T1/T2 brain MRI slice as input. Behaviour on non-brain or non-MRI inputs is undefined.
- Calibration: expected calibration error (ECE) is 0.0665 before temperature scaling; consider dividing the logits by T=0.96 before the softmax if the probabilities are to be used as confidence scores.
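Temperature scaling divides the logits by a single scalar T before the softmax: it changes the confidences but never the predicted class, and with T = 0.96 < 1 the probabilities become slightly sharper. The NumPy sketch below pairs it with a standard equal-width-bin ECE; the 15 bins are an assumption (the card does not say how its ECE was binned), and the function names are hypothetical.

```python
import numpy as np


def softmax(logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Row-wise softmax of logits / temperature."""
    z = logits / temperature
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def expected_calibration_error(probs: np.ndarray, labels: np.ndarray,
                               n_bins: int = 15) -> float:
    """ECE = sum over bins of (|bin| / N) * |accuracy(bin) - mean confidence(bin)|."""
    conf = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return float(ece)
```

Usage: `probs = softmax(logits, temperature=0.96)` on the model's raw logits, then compare `expected_calibration_error(probs, labels)` against the value obtained with `temperature=1.0`.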
Citation
```bibtex
@misc{ldwcnet,
  title        = {LDW-CNet: Learnable Discrete Wavelet CNN for Brain Tumor MRI Classification},
  author       = {{LDW-CNet authors}},
  year         = {2026},
  howpublished = {Hugging Face, \url{https://huggingface.co/Shanmuk4622/ldwcnet-brain-mri-5-model-model}}
}
```
License
MIT