# Page metadata captured with this file (Hugging Face file view, not code):
# author: Juliofc | commit 543f089 "Update app.py" (verified) | size 2.14 kB | scan: no virus
import torch
import torch.nn as nn
from torchvision import models, transforms
import gradio as gr
from PIL import Image
# Output labels of the classifier, indexed by logit position. load_model()
# sizes the final linear layer from len(class_names), and predict_image()
# maps the index of the highest score back through this list, so the order
# must match the one used when the checkpoint was trained (presumably the
# alphabetical class-folder order of the training set -- confirm).
class_names = "cordana healthy pestalotiopsis sigatoka".split()
def load_model(weights_path='model_alexnet.pth'):
    """Build the AlexNet classifier and load its fine-tuned weights.

    The architecture is torchvision's AlexNet with the last classifier layer
    replaced by a linear layer producing one logit per entry of
    ``class_names``. It must match the architecture the checkpoint was saved
    from; otherwise ``load_state_dict`` (strict by default) raises.

    Args:
        weights_path: Path to the saved ``state_dict``. Defaults to
            ``'model_alexnet.pth'``, the previously hard-coded path.

    Returns:
        A ``(model, device)`` tuple: the model in eval mode, moved to
        ``device``, which is ``cuda:0`` when CUDA is available, else CPU.
    """
    # pretrained=False: no ImageNet download; every weight comes from the
    # checkpoint loaded below.
    model = models.alexnet(pretrained=False)
    num_ftrs = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(num_ftrs, len(class_names))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Always remap stored tensors onto the selected device. This works whether
    # the checkpoint was saved on CPU or on any GPU index; without
    # map_location, a checkpoint saved on e.g. cuda:1 fails to load on a
    # single-GPU machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)

    model.eval()  # disable dropout for deterministic inference
    model.to(device)
    return model, device
# Load the classifier once at import time so every request reuses it.
model_alexnet, device = load_model()

# Preprocessing applied to each uploaded image before inference: resize the
# shorter side to 256 px, take the central 224x224 crop, convert to a float
# tensor, then normalise each channel with the standard ImageNet statistics
# (the values torchvision's ImageNet-trained models expect).
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
]
transform = transforms.Compose(_preprocessing_steps)
# Prediction function
def predict_image(image):
    """Classify a banana-leaf photo and return the predicted class name.

    Args:
        image: Image array as delivered by the Gradio ``"image"`` input
            (presumably ``(H, W, 3)`` RGB; grayscale ``(H, W)`` and RGBA
            ``(H, W, 4)`` arrays are also accepted and converted to RGB).

    Returns:
        The entry of ``class_names`` whose logit is highest.

    Raises:
        ValueError: If no image was provided (Gradio passes ``None`` when
            the input is empty).
    """
    if image is None:
        raise ValueError("No image provided; please upload a banana leaf photo.")
    # Let Pillow infer the mode from the array shape, then normalise to RGB.
    # Forcing mode 'RGB' on a grayscale or RGBA array misreads the pixel
    # buffer (an error or a scrambled image), and the mode argument of
    # Image.fromarray is deprecated in recent Pillow releases.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    batch = transform(pil_image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        outputs = model_alexnet(batch)
    predicted_index = outputs.argmax(dim=1).item()
    return class_names[predicted_index]
# Web UI: one image input, one label output, plus the bundled example
# photos (one file named after each class) offered as clickable samples.
_app_description = "This model is a fine-tuned version of AlexNet specifically designed to identify four types of diseases in banana tree leaves. It can classify the leaves as Cordana, Healthy, Pestalotiopsis, or Sigatoka. Upload a photo of a banana leaf and the model will help you determine its health condition."
_example_images = [
    'cordana.jpeg',
    'healthy.jpeg',
    'pestalotiopsis.jpeg',
    'sigatoka.jpeg',
]

iface = gr.Interface(
    fn=predict_image,
    inputs="image",
    outputs="label",
    description=_app_description,
    examples=_example_images,
)
iface.launch()