File size: 2,519 Bytes
62a4c73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from torchvision.transforms import InterpolationMode
from torchvision.models import efficientnet_b3

# Model setup
# Class order must match the label indices used at training time
# (presumably alphabetical, as e.g. torchvision's ImageFolder assigns — confirm).
class_names = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']
model = efficientnet_b3(weights=None)
# Swap the final Linear layer for one sized to our classes. Reading
# in_features from the existing layer (1536 for B3) avoids a hard-coded
# magic number tied to one specific backbone variant.
model.classifier[1] = torch.nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=len(class_names),
)

# The model is never moved off the CPU in this script, so load the weights
# straight onto the CPU rather than staging them on a GPU first and copying
# them back into the CPU parameters.
model.load_state_dict(torch.load(
    "Eff_net_b3_01_brain_tumor.pth",
    map_location="cpu",
))
# Inference mode for dropout / batch-norm layers.
model.eval()

# Image transform
# Preprocessing applied to every uploaded image before it reaches the model.
# Resize-320 / center-crop-300 with bicubic resampling mirrors torchvision's
# reference EfficientNet-B3 preprocessing; the Normalize stats are the usual
# ImageNet channel statistics (presumably also used during training — confirm).
img_transform = transforms.Compose(
    [
        # Scale the shorter side to 320 px using bicubic resampling.
        transforms.Resize(size=320, interpolation=InterpolationMode.BICUBIC),
        # Keep only the central 300x300 region.
        transforms.CenterCrop(size=300),
        # PIL image -> float tensor laid out as (C, H, W).
        transforms.ToTensor(),
        # Per-channel standardization; expects exactly three channels.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Prediction function
def predict(image):
    """Classify a single brain MRI image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user clicks Predict without uploading anything.

    Returns:
        A ``(class_label, confidence)`` tuple: the most probable entry of
        ``class_names`` and its softmax probability as a Python float.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of a generic failure.
    """
    if image is None:
        raise gr.Error("Please upload an MRI image before clicking Predict.")

    # Uploads may be grayscale ("L") or carry an alpha channel ("RGBA"); the
    # 3-channel Normalize step and the backbone need exactly three channels.
    rgb_image = image.convert("RGB")

    # Put the batch wherever the model's weights live, so this keeps working
    # if the model is ever moved to a GPU.
    device = next(model.parameters()).device
    batch = img_transform(rgb_image).unsqueeze(0).to(device)

    with torch.inference_mode():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        # Single image in the batch: take its top probability and index.
        confidence, label_idx = torch.max(probs[0], dim=0)
        class_label = class_names[label_idx.item()]
        score = confidence.item()
    return class_label, score

# Gradio Blocks UI
with gr.Blocks(title="🧠 Brain Tumor MRI Classifier") as demo:
    gr.Markdown("## 🧠 Brain Tumor Classifier (EfficientNet-B3)")
    gr.Markdown("""
    Upload an MRI scan of the brain, and this model will classify it as one of:
    - **Glioma Tumor**
    - **Meningioma Tumor**
    - **Pituitary Tumor**
    - **No Tumor**

    Uses EfficientNet-B3 trained on labeled brain MRI dataset.
    """)

    with gr.Row():
        # Left column: image input and action buttons.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload MRI Image")
            predict_button = gr.Button("🔍 Predict")
            clear_button = gr.Button("🧹 Clear")

        # Right column: prediction results.
        with gr.Column():
            output_label = gr.Label(label="Predicted Class")
            # Display-only: the slider shows the model's confidence, so users
            # should not be able to drag it to an arbitrary value.
            confidence_slider = gr.Slider(
                minimum=0,
                maximum=1,
                step=0.01,
                label="Confidence Score",
                interactive=False,
            )

    predict_button.click(fn=predict, inputs=image_input, outputs=[output_label, confidence_slider])
    # The handler must return one value per output component (three here).
    # Returning fewer values makes Gradio report "too few output values"
    # (an error or a warning, depending on the Gradio version).
    clear_button.click(
        lambda: (None, None, None),
        inputs=[],
        outputs=[image_input, output_label, confidence_slider],
    )

    gr.Markdown("---")
    gr.Markdown(
        "<center>👀 Developed by [Sagar Bisht](https://www.linkedin.com/in/sagarbisht123)</center>",
        elem_id="footer"
    )

demo.launch(share=True)