|
import gradio as gr |
|
import torch |
|
from torchvision import transforms |
|
from PIL import Image |
|
from torchvision.transforms import InterpolationMode |
|
from torchvision.models import efficientnet_b3 |
|
|
|
|
|
# Output labels. predict() indexes this list with the argmax of the model's
# logits, so the order must match the classifier head's output order.
class_names = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']

# EfficientNet-B3 architecture only; weights=None means no torchvision
# pretrained weights are downloaded -- everything comes from the checkpoint.
model = efficientnet_b3(weights=None)

# Swap the final linear layer for one with a single output per class.
# in_features is read from the layer being replaced (1536 for EfficientNet-B3)
# instead of being hard-coded.
model.classifier[1] = torch.nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=len(class_names),
)

# Load the fine-tuned weights directly onto the CPU. The model's parameters
# live on the CPU (it is never moved with .to()), and load_state_dict copies
# values into those existing parameters, so mapping the checkpoint to CUDA
# first would only occupy GPU memory without ever using the GPU.
# Note: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(
    torch.load(
        "Eff_net_b3_01_brain_tumor.pth",
        map_location=torch.device("cpu"),
    )
)

# Switch dropout / batch-norm layers to inference behavior.
model.eval()
|
|
|
|
|
# Preprocessing applied to every uploaded image before it reaches the model.
# NOTE(review): these values presumably mirror the preprocessing the
# checkpoint was trained with -- confirm against the training pipeline.
_RESIZE_SHORT_SIDE = 320  # Resize(int) matches the image's smaller edge to this
_CENTER_CROP_SIDE = 300   # side length of the square center crop
_CHANNEL_MEAN = [0.485, 0.456, 0.406]  # the usual ImageNet per-channel mean
_CHANNEL_STD = [0.229, 0.224, 0.225]   # the usual ImageNet per-channel std

img_transform = transforms.Compose(
    [
        transforms.Resize(_RESIZE_SHORT_SIDE, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(_CENTER_CROP_SIDE),
        # ToTensor yields a float CHW tensor scaled to [0, 1] before normalizing.
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
|
|
|
|
|
def predict(image):
    """Classify a brain MRI image with the module-level EfficientNet model.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``), or
            ``None`` when Predict is clicked before anything is uploaded.

    Returns:
        A ``(class_label, confidence)`` tuple: the entry of ``class_names``
        with the highest softmax probability, and that probability as a
        Python float in [0, 1].

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable
            message instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an MRI image first.")

    # img_transform normalizes with 3-entry mean/std, so the tensor must have
    # exactly 3 channels; grayscale ("L") or RGBA uploads would otherwise fail
    # inside transforms.Normalize.
    rgb_image = image.convert("RGB")

    # Build a batch of one and place it on whatever device the weights use.
    device = next(model.parameters()).device
    batch = img_transform(rgb_image).unsqueeze(0).to(device)

    with torch.inference_mode():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)

    # max over the class dimension gives both the top probability and its index.
    confidence, label_idx = probs.max(dim=1)
    return class_names[label_idx.item()], confidence.item()
|
|
|
|
|
# Gradio UI: an image upload with Predict / Clear buttons on the left, and the
# predicted class plus its confidence on the right.
with gr.Blocks(title="🧠 Brain Tumor MRI Classifier") as demo:
    gr.Markdown("## 🧠 Brain Tumor Classifier (EfficientNet-B3)")
    gr.Markdown("""
Upload an MRI scan of the brain, and this model will classify it as one of:
- **Glioma Tumor**
- **Meningioma Tumor**
- **Pituitary Tumor**
- **No Tumor**

Uses EfficientNet-B3 trained on labeled brain MRI dataset.
""")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload MRI Image")
            predict_button = gr.Button("🔍 Predict")
            clear_button = gr.Button("🧹 Clear")

        with gr.Column():
            output_label = gr.Label(label="Predicted Class")
            # Output-only display of the model's confidence; not user-editable.
            confidence_slider = gr.Slider(
                minimum=0,
                maximum=1,
                step=0.01,
                label="Confidence Score",
                interactive=False,
            )

    # predict returns (class_label, confidence) -> one value per output.
    predict_button.click(
        fn=predict,
        inputs=image_input,
        outputs=[output_label, confidence_slider],
    )

    # A handler must return exactly one value per output component: reset the
    # image, clear the label, and put the slider back at its minimum.
    clear_button.click(
        lambda: (None, None, 0),
        inputs=[],
        outputs=[image_input, output_label, confidence_slider],
    )

    gr.Markdown("---")
    gr.Markdown(
        "<center>🤖 Developed by [Sagar Bisht](https://www.linkedin.com/in/sagarbisht123)</center>",
        elem_id="footer",
    )
|
|
|
# Start the server only when run as a script, so importing this module (e.g.
# for tests or tooling) does not block on a running web server.
if __name__ == "__main__":
    # share=True also asks Gradio for a temporary public share link.
    demo.launch(share=True)
|
|