Juliofc's picture
Update app.py
f0b566e verified
raw
history blame
No virus
1.97 kB
import torch
import torch.nn as nn
from torchvision import models, transforms
import gradio as gr
from PIL import Image
# Human-readable labels for the classifier's outputs. predict_image looks up
# the index of the highest-scoring output unit in this list, and load_model
# sizes the final layer from its length, so both the count and the order must
# match the checkpoint.
# NOTE(review): the order is alphabetical, which is what torchvision's
# ImageFolder would produce during training — confirm against the training run.
class_names = [
    "cordana",
    "healthy",
    "pestalotiopsis",
    "sigatoka",
]
def load_model(weights_path="model_alexnet.pth"):
    """Build the AlexNet classifier and load the fine-tuned weights into it.

    The architecture is recreated here and must match the one the checkpoint
    was saved from: a stock torchvision AlexNet whose last classifier layer
    (``classifier[6]``) is replaced by a linear layer with one output per
    entry in ``class_names``.

    Args:
        weights_path: Path to the saved ``state_dict``. Defaults to
            ``"model_alexnet.pth"``, resolved relative to the working directory.

    Returns:
        A ``(model, device)`` tuple. The model is in evaluation mode and has
        already been moved to ``device``. The device is ``cuda:0`` when CUDA is
        available, otherwise ``cpu``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # No ImageNet weights are needed here, because load_state_dict below
    # (strict by default) overwrites every parameter. Calling with no
    # arguments builds an untrained model on both old and new torchvision,
    # and it avoids the deprecated ``pretrained=`` keyword.
    model = models.alexnet()
    num_ftrs = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(num_ftrs, len(class_names))

    # Load onto the CPU first, so that a checkpoint saved from a GPU also
    # loads on a CPU-only host. Without map_location, torch.load tries to
    # restore CUDA tensors and fails when CUDA is unavailable.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Consider weights_only=True if the installed torch supports it.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)

    model.eval()  # disable dropout for inference
    model.to(device)
    return model, device
# Build the classifier once at import time; every request below reuses it.
model_alexnet, device = load_model()

# Preprocessing applied to each uploaded image before it reaches the model.
# NOTE(review): this presumably mirrors the evaluation transform used during
# fine-tuning — confirm against the training script.
transform = transforms.Compose(
    [
        # Scale the shorter side to 256 px, keeping the aspect ratio.
        transforms.Resize(256),
        # Keep the central 224 x 224 patch.
        transforms.CenterCrop(224),
        # PIL image -> float tensor (C x H x W) with values in [0, 1].
        transforms.ToTensor(),
        # Standardize each channel with the usual ImageNet mean / std
        # (the statistics torchvision's pretrained models expect).
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)
def predict_image(image):
    """Classify a single banana-leaf image and return its class name.

    Args:
        image: The picture supplied by the Gradio image input. It is expected
            to be a NumPy array (``H x W`` or ``H x W x C``), or ``None`` when
            the user submits without uploading anything.

    Returns:
        The entry of ``class_names`` whose model output score is highest.

    Raises:
        gr.Error: If no image was provided. The UI then shows a readable
            message instead of an ``AttributeError`` traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image of a banana leaf.")

    # Let PIL infer the mode from the array shape, then convert. Forcing mode
    # 'RGB' on a grayscale or RGBA array either raises or mis-reads the pixel
    # buffer, whereas .convert("RGB") handles both. A plain H x W x 3 uint8
    # array becomes the same RGB image as before.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    batch = transform(pil_image).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        outputs = model_alexnet(batch)

    # Index of the highest-scoring class for the single image in the batch
    # (argmax, like torch.max, returns the first index on ties).
    predicted = int(outputs.argmax(dim=1)[0])
    return class_names[predicted]
# Web UI: one image in, one predicted label out. The explicit components are
# the same ones the "image" / "label" shortcuts resolve to. type="numpy" is
# spelled out because predict_image calls .astype on its input.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(),
    # Three of the four classes are diseases; the fourth is "Healthy".
    description=(
        "This model is a fine-tuned version of AlexNet specifically designed "
        "to identify three types of disease in banana tree leaves, as well as "
        "healthy leaves. It can classify the leaves as Cordana, Healthy, "
        "Pestalotiopsis, or Sigatoka. Upload a photo of a banana leaf and the "
        "model will help you determine its health condition."
    ),
    # Sample images shipped next to this script, one per class.
    examples=[
        "cordana.jpeg",
        "healthy.jpeg",
        "pestalotiopsis.jpeg",
        "sigatoka.jpeg",
    ],
)

# Start the server only when this file is run as a script (e.g.
# `python app.py`). Importing the module then no longer blocks inside launch().
if __name__ == "__main__":
    iface.launch()