# Uploaded by S4G4R-Byte
# Commit message: Initial commit with app.py, .gitignore, and requirements.txt
# Commit: 62a4c73
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from torchvision.transforms import InterpolationMode
from torchvision.models import efficientnet_b3
# Model setup
# Labels in the order of the classifier's output indices; predict() indexes
# this list with the argmax of the model output.
class_names = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']

# Build the architecture without pretrained weights: every parameter is
# supplied by the checkpoint loaded below.
model = efficientnet_b3(weights=None)

# Replace the stock classification layer with one sized to our label set.
# The input width is read from the layer being replaced rather than
# hard-coded, so it always matches the backbone.
model.classifier[1] = torch.nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=len(class_names),
)

# The model is never moved off the CPU in this file, so map the checkpoint
# directly to CPU instead of staging it on a GPU first and copying it back.
# NOTE(review): torch.load unpickles the file; if the installed torch version
# supports it, consider passing weights_only=True for checkpoints that do not
# come from a trusted source.
model.load_state_dict(torch.load(
    "Eff_net_b3_01_brain_tumor.pth",
    map_location=torch.device("cpu"),
))

# Inference only: switch dropout / batch-norm layers to evaluation behaviour.
model.eval()
# Image transform
# Per-channel statistics used by the final Normalize step.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing steps, applied in order: bicubic resize to 320, centre crop
# to 300, conversion to a tensor, then per-channel normalisation.
_preprocess_steps = [
    transforms.Resize(size=320, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(size=300),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
img_transform = transforms.Compose(_preprocess_steps)
# Prediction function
def predict(image):
    """Classify a single brain MRI image.

    Args:
        image: PIL image delivered by the Gradio image component, or ``None``
            when the user has not uploaded anything.

    Returns:
        A ``(class_label, confidence)`` tuple: the predicted entry of
        ``class_names`` and its softmax probability as a Python float.

    Raises:
        gr.Error: If ``image`` is ``None``; Gradio shows the message to the
            user instead of an opaque traceback.
    """
    if image is None:
        raise gr.Error("Please upload an MRI image before predicting.")

    # Normalize is configured for exactly three channels, so force RGB in
    # case a grayscale or RGBA image reaches this function.
    batch = img_transform(image.convert("RGB")).unsqueeze(0)  # add batch dim

    with torch.inference_mode():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)

    label_idx = torch.argmax(probs, dim=1).item()
    class_label = class_names[label_idx]
    confidence = probs[0, label_idx].item()
    return class_label, confidence
# Gradio Blocks UI
# Intro text shown under the page heading. The blank line after the bullet
# list keeps the closing sentence from being rendered as part of the last
# bullet.
_DESCRIPTION_MD = """
Upload an MRI scan of the brain, and this model will classify it as one of:
- **Glioma Tumor**
- **Meningioma Tumor**
- **Pituitary Tumor**
- **No Tumor**

Uses EfficientNet-B3 trained on labeled brain MRI dataset.
"""


def _reset_ui():
    """Return one reset value per cleared component: image, label, slider."""
    return None, None, 0


with gr.Blocks(title="🧠 Brain Tumor MRI Classifier") as demo:
    gr.Markdown("## 🧠 Brain Tumor Classifier (EfficientNet-B3)")
    gr.Markdown(_DESCRIPTION_MD)

    with gr.Row():
        # Left column: image upload and action buttons.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload MRI Image")
            predict_button = gr.Button("🔍 Predict")
            clear_button = gr.Button("🧹 Clear")
        # Right column: prediction results.
        with gr.Column():
            output_label = gr.Label(label="Predicted Class")
            confidence_slider = gr.Slider(minimum=0, maximum=1, step=0.01, label="Confidence Score")

    predict_button.click(fn=predict, inputs=image_input, outputs=[output_label, confidence_slider])
    # The handler must return exactly one value per output component (three
    # here); returning fewer makes Gradio raise when the button is clicked.
    clear_button.click(fn=_reset_ui, inputs=[], outputs=[image_input, output_label, confidence_slider])

    gr.Markdown("---")
    gr.Markdown(
        "<center>👀 Developed by [Sagar Bisht](https://www.linkedin.com/in/sagarbisht123)</center>",
        elem_id="footer"
    )

demo.launch(share=True)