🧠 Hybrid CNN-ViT for Brain Tumor Classification
A novel deep learning framework for automated brain tumor detection and classification from MRI images. Combines a ResNet50 CNN backbone with a 6-layer Vision Transformer and learnable radiomics features via multimodal fusion.
Model Description
This model classifies brain MRI scans into 4 categories:
| Label | Description |
|---|---|
| glioma | Glioma tumor |
| meningioma | Meningioma tumor |
| no_tumor | Healthy brain (no tumor) |
| pituitary | Pituitary tumor |
Architecture
    Input MRI (224×224×3)
        │
        ├──► ResNet50 CNN ──► Feature Maps (7×7×2048)
        │                           │
        │                     Patch Embedding
        │                           │
        │                     ViT Encoder (6 blocks, 8 heads)
        │                           │
        │                     CLS Token (512-d)
        │
        ├──► Radiomics Branch ──► Texture + Shape Features (128-d)
        │
        └──► CNN Global Pool ──► CNN Features (2048-d)
                       │
        ┌──────────────┼──────────────┐
        │              │              │
   CNN (2048)      ViT (512)   Radiomics (128)
        │              │              │
        └─────── Concat Fusion ───────┘
                       │
                MLP Classifier
                       │
                4 Class Logits
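To make the dimensions in the diagram concrete, the sketch below does the shape bookkeeping in plain Python. It assumes that each of the 7×7 spatial positions of the ResNet50 feature map becomes one ViT token and that a single CLS token is prepended. The card does not spell out the patch-embedding details, so the token counts are an assumption. The function name `hybrid_shapes` is illustrative and not part of the released code.

```python
def hybrid_shapes(grid=7, cnn_channels=2048, vit_dim=512, vit_heads=8, radiomics_dim=128):
    """Derive the tensor sizes implied by the architecture diagram.

    Assumption (not stated in the card): one token per feature-map position,
    plus one CLS token.
    """
    if vit_dim % vit_heads != 0:
        raise ValueError("vit_dim must be divisible by vit_heads")
    num_patches = grid * grid
    return {
        "patch_tokens": num_patches,              # 7 * 7 positions
        "sequence_length": num_patches + 1,       # patches + CLS token
        "head_dim": vit_dim // vit_heads,         # per-head attention width
        "fused_dim": cnn_channels + vit_dim + radiomics_dim,  # concat fusion input
    }


shapes = hybrid_shapes()
for name, value in shapes.items():
    print(f"{name}: {value}")
```

Under these assumptions the fused vector entering the MLP classifier has 2048 + 512 + 128 = 2688 dimensions.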
Key Innovations
- Hybrid CNN + ViT: CNN captures local texture/shape; ViT captures global context and long-range dependencies
- Learnable Radiomics: Dual-branch CNN (texture + shape) providing hand-crafted-style features in a differentiable way
- Feature Fusion: Concatenation-based fusion with LayerNorm and GELU for stable multimodal learning
- Self-Supervised Pre-Training: Masked Autoencoder (MAE) pre-training for better generalization
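The concatenation-based fusion described in the list above can be sketched roughly as follows. This is a minimal illustration, not the repository's implementation: the class name `ConcatFusionHead`, the hidden width of 512, and the exact ordering of LayerNorm, GELU, and dropout are assumptions chosen to match the dimensions in the architecture diagram.

```python
import torch
from torch import nn


class ConcatFusionHead(nn.Module):
    """Hypothetical fusion head: concat -> LayerNorm -> Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, cnn_dim=2048, vit_dim=512, radiomics_dim=128,
                 hidden_dim=512, num_classes=4, dropout=0.3):
        super().__init__()
        fused_dim = cnn_dim + vit_dim + radiomics_dim  # 2048 + 512 + 128 = 2688
        self.norm = nn.LayerNorm(fused_dim)
        self.mlp = nn.Sequential(
            nn.Linear(fused_dim, hidden_dim),  # hidden_dim is an assumed value
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, cnn_feat, vit_cls, radiomics_feat):
        # Concatenate the three feature vectors along the feature axis
        fused = torch.cat([cnn_feat, vit_cls, radiomics_feat], dim=-1)
        return self.mlp(self.norm(fused))


head = ConcatFusionHead().eval()
batch = 2
logits = head(torch.randn(batch, 2048), torch.randn(batch, 512), torch.randn(batch, 128))
print(logits.shape)  # torch.Size([2, 4])
```

Normalizing after concatenation puts the three branches on a comparable scale before the classifier sees them, which is one plausible reading of "stable multimodal learning" in the list above.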
Performance
| Model Variant | Accuracy | F1-Score | AUC |
|---|---|---|---|
| ResNet50 (baseline) | 93% | 0.92 | 0.97 |
| Hybrid CNN-ViT | 96% | 0.95 | 0.99 |
| + Self-Supervised Pre-Training | 97% | 0.96 | 0.99 |
| + Radiomics (Full Model) | 98% | 0.97 | 0.99 |
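For readers who want to reproduce numbers of this kind on their own split, the sketch below shows one common way to compute accuracy and a macro-averaged F1 from a 4×4 confusion matrix. The card does not state which F1 averaging or which evaluation split was used, so macro averaging here is an assumption, and the counts are toy values rather than the model's results.

```python
def accuracy_and_macro_f1(confusion):
    """confusion[i][j] = number of samples with true class i predicted as class j."""
    num_classes = len(confusion)
    total = sum(sum(row) for row in confusion)
    correct = sum(confusion[i][i] for i in range(num_classes))
    f1_scores = []
    for k in range(num_classes):
        tp = confusion[k][k]
        fp = sum(confusion[i][k] for i in range(num_classes)) - tp  # column minus diagonal
        fn = sum(confusion[k]) - tp                                  # row minus diagonal
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return correct / total, sum(f1_scores) / num_classes


# Toy counts for illustration only; these are NOT the model's results.
# Row/column order: glioma, meningioma, no_tumor, pituitary
toy = [
    [9, 1, 0, 0],
    [1, 8, 0, 1],
    [0, 0, 10, 0],
    [0, 1, 0, 9],
]
acc, macro_f1 = accuracy_and_macro_f1(toy)
print(f"accuracy={acc:.3f}, macro_f1={macro_f1:.3f}")
```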
Usage
Quick Inference
import torch
from PIL import Image
from torchvision import transforms
from model import HybridCNNViT  # local module that defines the architecture

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HybridCNNViT(
    num_classes=4,
    cnn_backbone="resnet50",
    cnn_pretrained=False,  # weights come from the checkpoint below
    vit_embed_dim=512,
    vit_depth=6,
    vit_num_heads=8,
    use_radiomics=True,
    radiomics_dim=128,
    dropout=0.3,
)
checkpoint = torch.load("best_model.pth", map_location=device)
# The checkpoint may be a training dict or a bare state dict
state_dict = checkpoint.get("model_state_dict", checkpoint)
# strict=False tolerates key mismatches, so report anything that did not load
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Missing keys: {load_result.missing_keys}")
    print(f"Unexpected keys: {load_result.unexpected_keys}")
model.eval().to(device)

# Preprocess (ImageNet mean/std normalization)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
image = Image.open("brain_mri.jpg").convert("RGB")
input_tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension

# Predict
with torch.no_grad():
    output = model(input_tensor)
    probs = torch.softmax(output["logits"], dim=-1)
    pred_class = probs.argmax(dim=-1).item()

class_names = ["glioma", "meningioma", "no_tumor", "pituitary"]
print(f"Prediction: {class_names[pred_class]} ({probs[0][pred_class]:.1%})")
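The last step of the snippet (softmax, argmax, label lookup) reports only the top class. If a ranked list of all four classes is more useful, something like the following plain-Python sketch could be applied to the logits of a single image, for example `output["logits"][0].tolist()`. The helper name `rank_predictions` and the example logits are made up for illustration.

```python
import math

CLASS_NAMES = ["glioma", "meningioma", "no_tumor", "pituitary"]


def rank_predictions(logits, class_names=CLASS_NAMES):
    """Softmax over raw logits, then sort classes by descending probability."""
    if len(logits) != len(class_names):
        raise ValueError("expected one logit per class")
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)


# Made-up logits for one image; real values come from the model output
ranked = rank_predictions([2.0, 0.5, -1.0, 0.1])
for name, prob in ranked:
    print(f"{name}: {prob:.1%}")
```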
Training Details
- Dataset: Brain Tumor MRI Dataset (~7,000 MRI images)
- Optimizer: AdamW (lr=1e-4, weight_decay=0.01)
- Scheduler: Cosine annealing with 5-epoch warmup
- Augmentation: Random rotation (±15°), horizontal flip, elastic deformation, MixUp (α=0.2)
- Regularization: Label smoothing (0.1), gradient clipping (1.0), dropout (0.3)
- Hardware: NVIDIA GPU with mixed precision (FP16) training
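Two of the settings above can be illustrated numerically without any framework. The sketch below assumes a linear warmup to the base learning rate over the first 5 epochs followed by cosine decay to zero, and the label-smoothing convention (1 − ε)·one_hot + ε/K. The total epoch count, the warmup shape, and the minimum learning rate are not given in the card, so `TOTAL_EPOCHS = 50` and `min_lr = 0.0` are hypothetical values.

```python
import math


def lr_at_epoch(epoch, total_epochs, base_lr=1e-4, warmup_epochs=5, min_lr=0.0):
    """Linear warmup followed by cosine annealing (epochs are 0-indexed)."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def smoothed_target(true_class, num_classes=4, smoothing=0.1):
    """Label-smoothed target distribution: (1 - eps) * one_hot + eps / K."""
    off_value = smoothing / num_classes
    target = [off_value] * num_classes
    target[true_class] += 1.0 - smoothing
    return target


TOTAL_EPOCHS = 50  # hypothetical; the card does not state the epoch count
schedule = [lr_at_epoch(e, TOTAL_EPOCHS) for e in range(TOTAL_EPOCHS)]
print([f"{lr:.2e}" for lr in schedule[:6]])  # warmup epochs plus the first cosine epoch
target = smoothed_target(true_class=2)       # index 2 is "no_tumor"
print(target)
```

With four classes and smoothing 0.1, the true class receives a target of 0.925 and each other class 0.025 under this convention.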
Limitations & Ethical Considerations
⚠️ This model is for research and educational purposes only.
- Not FDA-approved for clinical diagnosis
- Trained on a single publicly available dataset; may not generalize to all MRI scanners/protocols
- Should be used as a decision-support tool, not a replacement for radiologist evaluation
- Performance may vary on MRI sequences not seen during training (e.g., contrast-enhanced)
Citation
@misc{vishnuk2024braintumor,
title={Hybrid CNN-ViT Framework for Brain Tumor Classification with Radiomics Integration},
author={Vishnu K},
year={2024},
publisher={Hugging Face},
url={https://huggingface.co/ZorroJurro/brain-tumor-cnn-vit}
}
Author
Vishnu K (Hugging Face · GitHub)