import gradio as gr import torch from torchvision import transforms from PIL import Image from torchvision.transforms import InterpolationMode from torchvision.models import efficientnet_b3 # Model setup class_names = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor'] model = efficientnet_b3(weights=None) model.classifier[1] = torch.nn.Linear(in_features=1536, out_features=len(class_names)) model.load_state_dict(torch.load( "Eff_net_b3_01_brain_tumor.pth", map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu") )) model.eval() # Image transform img_transform = transforms.Compose([ transforms.Resize(320, interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(300), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Prediction function def predict(image): transformed_image = img_transform(image).unsqueeze(0) with torch.inference_mode(): preds = model(transformed_image) probs = torch.softmax(preds, dim=1) label_idx = torch.argmax(probs, dim=1).item() class_label = class_names[label_idx] confidence = probs[0, label_idx].item() return class_label, confidence # Gradio Blocks UI with gr.Blocks(title="๐ง Brain Tumor MRI Classifier") as demo: gr.Markdown("## ๐ง Brain Tumor Classifier (EfficientNet-B3)") gr.Markdown(""" Upload an MRI scan of the brain, and this model will classify it as one of: - **Glioma Tumor** - **Meningioma Tumor** - **Pituitary Tumor** - **No Tumor** Uses EfficientNet-B3 trained on labeled brain MRI dataset. """) with gr.Row(): with gr.Column(): image_input = gr.Image(type="pil", label="Upload MRI Image") predict_button = gr.Button("๐ Predict") clear_button = gr.Button("๐งน Clear") with gr.Column(): output_label = gr.Label(label="Predicted Class") confidence_slider = gr.Slider(minimum=0, maximum=1, step=0.01, label="Confidence Score") predict_button.click(fn=predict, inputs=image_input, outputs=[output_label, confidence_slider]) clear_button.click(lambda: (None, None), inputs=[], outputs=[image_input, output_label, confidence_slider]) gr.Markdown("---") gr.Markdown( "