"""Gradio demo that classifies banana-leaf photos with a fine-tuned AlexNet.

Builds an AlexNet whose last layer has one output per entry in ``class_names``,
loads fine-tuned weights from ``model_alexnet.pth`` and serves a web UI that
maps an uploaded image to one of those class names.
"""
import torch
import torch.nn as nn
from torchvision import models, transforms
import gradio as gr
from PIL import Image

# Output index i of the classifier maps to class_names[i], so this order must
# match the label order used when the saved weights were trained.
# The architecture built in load_model() must match the saved model.
class_names = ["cordana", "healthy", "pestalotiopsis", "sigatoka"]


def load_model(weights_path="model_alexnet.pth"):
    """Build the AlexNet classifier, load its fine-tuned weights, pick a device.

    Args:
        weights_path: Path to the saved ``state_dict`` for the fine-tuned model.

    Returns:
        A ``(model, device)`` tuple: the model in evaluation mode, already moved
        to ``device`` (``cuda:0`` when CUDA is available, otherwise ``cpu``).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # No pretrained ImageNet weights are needed because the state dict loaded
    # below overwrites every parameter. Calling alexnet() with no arguments
    # means "no pretrained weights" in both old and new torchvision releases,
    # and avoids the deprecated ``pretrained=`` flag.
    model = models.alexnet()
    num_ftrs = model.classifier[6].in_features
    # Replace the final layer so it emits one logit per class.
    model.classifier[6] = nn.Linear(num_ftrs, len(class_names))

    # map_location remaps tensors that were saved on a GPU, so the weights
    # still load on a CPU-only machine.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # disable dropout for inference
    return model, device


model_alexnet, device = load_model()

# Preprocessing: resize, center-crop to 224x224, and normalize with the
# ImageNet channel means / standard deviations.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict_image(image):
    """Classify one image supplied by the Gradio image input.

    Args:
        image: Image as an array, presumably a numpy array (Gradio's default
            for image inputs), or None when the user submitted no image.

    Returns:
        The predicted entry of ``class_names``, or None when ``image`` is None.
    """
    if image is None:
        return None

    # convert("RGB") turns grayscale or RGBA arrays into the 3 channels the
    # network expects. It also avoids the deprecated ``mode`` argument of
    # Image.fromarray.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    batch = transform(pil_image).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        outputs = model_alexnet(batch)
    predicted = int(outputs.argmax(dim=1).item())
    return class_names[predicted]


iface = gr.Interface(
    fn=predict_image,
    inputs="image",
    outputs="label",
    description="This model is a fine-tuned version of AlexNet specifically designed to identify four types of diseases in banana tree leaves. It can classify the leaves as Cordana, Healthy, Pestalotiopsis, or Sigatoka. Upload a photo of a banana leaf and the model will help you determine its health condition.",
    examples=[
        'data/test/cordana/1.jpeg',
        'data/test/healthy/5.jpeg',
        'data/test/pestalotiopsis/5.jpeg',
        'data/test/sigatoka/1.jpeg',
    ],
)

iface.launch()