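# Scraped file-viewer header, kept as comments so the file is valid Python:
# Author: Fawazzx | Commit: 02e983e "Update app.py" (verified)
# Viewer actions: raw, history, blame, contribute, delete | Scan: No virus | Size: 1.89 kB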
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
import gradio as gr
# Preprocessing applied to every input image. It must mirror the pipeline used
# at training time. The steps are:
#   1. resize to 224x224;
#   2. convert the PIL image to a float tensor;
#   3. normalise each channel with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Rebuild the network with the same shape it had during training. This is a
# ResNet-50 backbone created without downloading any pretrained weights. Its
# final fully connected layer is replaced by a 4-output linear layer, one
# output per entry in class_labels below.
model = models.resnet50(weights=None)
model.fc = nn.Linear(
    in_features=model.fc.in_features,
    out_features=4,
)

# Load the trained parameters. map_location remaps tensors that were saved
# from a GPU onto the CPU, so the app also runs on machines without CUDA.
# NOTE(review): torch.load unpickles the checkpoint, so only load trusted
# files. Newer PyTorch versions also accept weights_only=True for plain
# state dicts.
model.load_state_dict(
    torch.load("alzheimer_model_resnet50.pth", map_location="cpu")
)
# Switch to inference behaviour (e.g. batch-norm layers use running statistics).
model.eval()

# Maps output index to the label shown to the user. The order must match the
# class indices used when the model was trained.
class_labels = [
    "Mild_Demented 0",
    "Moderate_Demented 1",
    "Non_Demented 2",
    "Very_Mild_Demented 3",
]
# Define the prediction function
def predict(image):
    """Classify a single MRI image and return its predicted class label.

    Args:
        image: The image to classify. Accepted forms:
            - a NumPy array, which is what ``gr.Image(type="numpy")`` supplies;
            - an already-open ``PIL.Image.Image``;
            - anything ``PIL.Image.open`` accepts, such as a file path or a
              binary file object.
            Every input is converted to 3-channel RGB before preprocessing.

    Returns:
        The entry of ``class_labels`` at the index of the largest model output.

    Raises:
        gr.Error: If ``image`` is None, which is what Gradio passes when the
            user submits without uploading. The user then sees a readable
            message instead of a crash.
    """
    if image is None:
        raise gr.Error("Please upload an MRI image.")

    if isinstance(image, np.ndarray):
        array = image.astype(np.uint8, copy=False)
        # A trailing singleton channel (H, W, 1) is not a layout Pillow can
        # map to a mode, so drop it and treat the array as grayscale.
        if array.ndim == 3 and array.shape[-1] == 1:
            array = array[..., 0]
        # Let Pillow infer the mode from the array shape, then convert to RGB.
        # The old fromarray(..., 'RGB') call forced an RGB layout onto
        # grayscale or RGBA arrays whose buffers do not match it.
        pil_image = Image.fromarray(array).convert("RGB")
    elif isinstance(image, Image.Image):
        pil_image = image.convert("RGB")
    else:
        pil_image = Image.open(image).convert("RGB")

    batch = transform(pil_image).unsqueeze(0)  # add a batch dimension: (1, C, H, W)
    with torch.no_grad():
        logits = model(batch)
    predicted_index = logits.argmax(dim=1).item()
    return class_labels[predicted_index]
# Example MRI images offered in the UI. Each inner list holds the value for
# the interface's single input component.
examples = [
    [example_path]
    for example_path in (
        "image.jpg",
        "image (1).jpg",
        "image (2).jpg",
        "image (3).jpg",
    )
]

# Web UI flow:
#   1. the uploaded image arrives as a NumPy array;
#   2. it is passed to predict();
#   3. the returned label is shown in a text box.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Upload an MRI Image", type="numpy"),
    outputs=gr.Textbox(label="Prediction"),
    examples=examples,
    title="Alzheimer MRI Classification",
)

# Start the server only when this file is run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()