File size: 2,519 Bytes
62a4c73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from torchvision.transforms import InterpolationMode
from torchvision.models import efficientnet_b3
# Model setup
# Output labels, in the index order of the classifier head built below.
class_names = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']

# Architecture only (weights=None): every parameter comes from the checkpoint below.
model = efficientnet_b3(weights=None)
# Replace the final Linear layer with one sized to our classes. Reading
# in_features from the existing layer avoids hard-coding the backbone width.
model.classifier[1] = torch.nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=len(class_names),
)
# The model is never moved off the CPU here, so deserialize the checkpoint
# directly to CPU. Mapping it to CUDA first only allocated GPU memory for tensors
# that load_state_dict then copied back into the CPU parameters.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints
# (consider weights_only=True when the installed torch supports it).
model.load_state_dict(torch.load(
    "Eff_net_b3_01_brain_tumor.pth",
    map_location=torch.device("cpu"),
))
# Inference mode for layers such as dropout and batch norm.
model.eval()
# Image transform
# Preprocessing applied to every uploaded image before it reaches the model.
# Presumably this mirrors the preprocessing used when the checkpoint was trained -
# confirm before changing any of these numbers.
_RESIZE_SHORT_SIDE = 320  # an int size makes Resize scale the shorter edge to this
_CENTER_CROP_SIZE = 300
_CHANNEL_MEAN = [0.485, 0.456, 0.406]  # standard ImageNet per-channel mean (RGB)
_CHANNEL_STD = [0.229, 0.224, 0.225]   # standard ImageNet per-channel std (RGB)

img_transform = transforms.Compose(
    [
        transforms.Resize(_RESIZE_SHORT_SIDE, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(_CENTER_CROP_SIZE),
        transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
# Prediction function
def predict(image):
    """Classify a brain MRI image with the EfficientNet-B3 model.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when Predict is clicked before anything has been uploaded.

    Returns:
        Tuple ``(class_label, confidence)``: the predicted entry of
        ``class_names`` and its softmax probability as a Python float.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of a traceback from the transform pipeline.
    """
    if image is None:
        raise gr.Error("Please upload an MRI image before clicking Predict.")

    # Normalize in img_transform uses 3-channel statistics; grayscale ("L") or
    # RGBA images would otherwise fail with a channel/shape mismatch.
    rgb_image = image.convert("RGB")

    # Place the input on whatever device holds the model's weights, so this keeps
    # working if the model is ever moved to a GPU.
    device = next(model.parameters()).device
    batch = img_transform(rgb_image).unsqueeze(0).to(device)  # add batch dim

    with torch.inference_mode():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)

    label_idx = torch.argmax(probs, dim=1).item()
    class_label = class_names[label_idx]
    confidence = probs[0, label_idx].item()
    return class_label, confidence
# Gradio Blocks UI
# NOTE(review): the emoji in the title/button/footer literals below look
# mis-encoded in this copy (UTF-8 bytes decoded with a legacy code page). They
# are kept byte-identical here; restore the intended glyphs from the original
# source.
with gr.Blocks(title="π§ Brain Tumor MRI Classifier") as demo:
    gr.Markdown("## π§ Brain Tumor Classifier (EfficientNet-B3)")
    # Blank lines around the list matter: without them the closing sentence is
    # rendered as a lazy continuation of the last bullet.
    gr.Markdown(
        "Upload an MRI scan of the brain, and this model will classify it as one of:\n"
        "\n"
        "- **Glioma Tumor**\n"
        "- **Meningioma Tumor**\n"
        "- **Pituitary Tumor**\n"
        "- **No Tumor**\n"
        "\n"
        "Uses EfficientNet-B3 trained on a labeled brain MRI dataset.\n"
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload MRI Image")
            predict_button = gr.Button("π Predict")
            clear_button = gr.Button("π§Ή Clear")
        with gr.Column():
            output_label = gr.Label(label="Predicted Class")
            confidence_slider = gr.Slider(
                minimum=0, maximum=1, step=0.01, label="Confidence Score"
            )

    predict_button.click(
        fn=predict,
        inputs=image_input,
        outputs=[output_label, confidence_slider],
    )
    # Gradio requires one returned value per output component. The previous
    # 2-tuple for 3 outputs made Clear fail. Reset the slider to its minimum (0).
    clear_button.click(
        lambda: (None, None, 0),
        inputs=[],
        outputs=[image_input, output_label, confidence_slider],
    )

    gr.Markdown("---")
    gr.Markdown(
        "<center>π€ Developed by [Sagar Bisht](https://www.linkedin.com/in/sagarbisht123)</center>",
        elem_id="footer",
    )

demo.launch(share=True)
|