Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from torchvision import transforms | |
from PIL import Image | |
from torchvision.transforms import InterpolationMode | |
from torchvision.models import efficientnet_b3 | |
# Model setup
# Output index i of the classifier head is reported as class_names[i] by
# predict(); presumably this is the label order used at training time --
# verify against the training script.
class_names = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']

# Build an EfficientNet-B3 without pretrained weights and replace its final
# linear layer with one that has a single output per class. The input width
# is read from the existing head instead of being hard-coded.
model = efficientnet_b3(weights=None)
head_in_features = model.classifier[1].in_features
model.classifier[1] = torch.nn.Linear(
    in_features=head_in_features,
    out_features=len(class_names),
)

# The model is never moved off the CPU and predict() feeds it CPU tensors, so
# deserialize the checkpoint directly onto the CPU. Mapping it to CUDA first
# would only allocate GPU memory that load_state_dict() copies straight back.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source (consider weights_only=True where the installed PyTorch
# supports it).
state_dict = torch.load("Eff_net_b3_01_brain_tumor.pth", map_location="cpu")
model.load_state_dict(state_dict)

# Inference only: switch dropout / batch-norm layers to evaluation behaviour.
model.eval()
# Image transform
# Per-channel (3-channel) normalization statistics applied after ToTensor.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing steps, applied in order to each input PIL image:
#   1. resize with bicubic interpolation (size 320),
#   2. crop the central 300x300 region,
#   3. convert to a tensor,
#   4. normalize each channel with the statistics above.
_preprocess_steps = [
    transforms.Resize(320, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(300),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
img_transform = transforms.Compose(_preprocess_steps)
# Prediction function
def predict(image):
    """Classify one MRI image and report the winning class with its probability.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None`` when
            the user has not uploaded anything.

    Returns:
        A ``(class_label, confidence)`` tuple: the entry of ``class_names``
        with the highest softmax probability, and that probability as a
        Python float.

    Raises:
        gr.Error: If ``image`` is ``None``. Gradio shows the message to the
            user instead of a generic failure.
    """
    if image is None:
        raise gr.Error("Please upload an MRI image before clicking Predict.")

    # The normalization step uses three-channel statistics, so grayscale or
    # RGBA inputs must be converted to RGB before the transform runs.
    rgb_image = image.convert("RGB")

    # Add a leading batch dimension: the model consumes batches of images.
    batch = img_transform(rgb_image).unsqueeze(0)

    with torch.inference_mode():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)

    label_idx = torch.argmax(probs, dim=1).item()
    confidence = probs[0, label_idx].item()
    return class_names[label_idx], confidence
# Gradio Blocks UI
# NOTE(review): the emoji in the source were mis-encoded. The title (brain) and
# Clear button (broom) glyphs were recoverable; the Predict-button and footer
# glyphs were not, so the magnifier and robot below are stand-ins -- confirm.

# Introductory text shown under the heading. Built line by line so the
# rendered Markdown does not depend on source indentation; the blank lines
# around the list keep the closing sentence out of the last bullet.
_INTRO_MARKDOWN = "\n".join([
    "Upload an MRI scan of the brain, and this model will classify it as one of:",
    "",
    "- **Glioma Tumor**",
    "- **Meningioma Tumor**",
    "- **Pituitary Tumor**",
    "- **No Tumor**",
    "",
    "Uses EfficientNet-B3 trained on labeled brain MRI dataset.",
])


def _reset_ui():
    """Return cleared values for the image input, label and confidence slider.

    One value is returned per output component wired to the Clear button;
    returning fewer values than outputs makes Gradio reject the event.
    """
    return None, None, 0


with gr.Blocks(title="🧠 Brain Tumor MRI Classifier") as demo:
    gr.Markdown("## 🧠 Brain Tumor Classifier (EfficientNet-B3)")
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: image upload and action buttons.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload MRI Image")
            predict_button = gr.Button("🔍 Predict")
            clear_button = gr.Button("🧹 Clear")
        # Right column: prediction results.
        with gr.Column():
            output_label = gr.Label(label="Predicted Class")
            confidence_slider = gr.Slider(minimum=0, maximum=1, step=0.01, label="Confidence Score")

    # predict() returns (class_label, confidence), matching the two outputs.
    predict_button.click(fn=predict, inputs=image_input, outputs=[output_label, confidence_slider])
    clear_button.click(fn=_reset_ui, inputs=[], outputs=[image_input, output_label, confidence_slider])

    gr.Markdown("---")
    gr.Markdown(
        "<center>🤖 Developed by [Sagar Bisht](https://www.linkedin.com/in/sagarbisht123)</center>",
        elem_id="footer"
    )

# Start the web server; share=True additionally requests a public share link.
demo.launch(share=True)