File size: 2,142 Bytes
6895a87
 
 
 
 
 
 
 
 
 
 
 
 
543f089
6895a87
543f089
 
 
 
 
 
6895a87
 
 
543f089
6895a87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0b566e
 
 
 
6895a87
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn
from torchvision import models, transforms
import gradio as gr
from PIL import Image

# Class labels for the classifier. load_model() sizes the final layer to
# len(class_names), and predict_image() returns class_names[i] for the
# highest-scoring output index i, so the order here matters.
# NOTE(review): presumably this order matches the label order used when the
# checkpoint was trained (the architecture must match the saved model) — confirm.
class_names = ["cordana", "healthy", "pestalotiopsis", "sigatoka"]
def load_model():
    """Build the AlexNet classifier and load the fine-tuned weights.

    The last classifier layer is replaced by a linear layer with
    ``len(class_names)`` outputs, then the state dict stored in
    ``model_alexnet.pth`` is loaded into the network.

    Returns:
        A ``(model, device)`` tuple: the network, switched to evaluation
        mode and moved to ``device``, and the ``torch.device`` that was
        selected (``cuda:0`` when CUDA is available, otherwise ``cpu``).
    """
    # NOTE(review): newer torchvision releases prefer ``weights=None`` over
    # ``pretrained=False``; left unchanged so older installs keep working.
    model = models.alexnet(pretrained=False)
    num_ftrs = model.classifier[6].in_features
    # Resize the output layer to the number of classes in this task.
    model.classifier[6] = nn.Linear(num_ftrs, len(class_names))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Always remap checkpoint tensors onto the selected device. This covers
    # the CPU-only case and also checkpoints saved on a GPU index that does
    # not exist on this machine, so no per-device branching is needed.
    state_dict = torch.load('model_alexnet.pth', map_location=device)
    model.load_state_dict(state_dict)

    model.eval()  # Inference only: disable dropout etc.
    model.to(device)
    return model, device


# Build the network once at import time and keep the device it was placed on.
model_alexnet, device = load_model()

# Preprocessing applied to every input image before inference:
# resize -> center crop -> tensor conversion -> per-channel normalization.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Prediction function
def predict_image(image):
    """Classify one image and return the name of the predicted class.

    Args:
        image: Image data exposing ``astype`` (the array handed over by the
            Gradio image input). Grayscale, RGB and RGBA layouts are
            accepted; everything is converted to RGB before inference.

    Returns:
        The entry of ``class_names`` at the index of the highest model
        output.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError("No image provided")

    # Let PIL infer the mode from the array shape and convert afterwards,
    # instead of forcing 'RGB' onto data that may have 1 or 4 channels.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')

    # Add the batch dimension and move the tensor next to the model.
    batch = transform(pil_image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model_alexnet(batch)

    # Index of the largest score for the single item in the batch.
    predicted_index = outputs.argmax(dim=1).item()
    return class_names[predicted_index]

    
# Gradio UI: a single image input routed through predict_image, a label
# output, and four bundled sample images the user can click to try.
iface = gr.Interface(
    fn=predict_image,
    inputs="image",
    outputs="label",
    description="This model is a fine-tuned version of AlexNet specifically designed to identify four types of diseases in banana tree leaves. It can classify the leaves as Cordana, Healthy, Pestalotiopsis, or Sigatoka. Upload a photo of a banana leaf and the model will help you determine its health condition.",
    examples=['cordana.jpeg', 'healthy.jpeg', 'pestalotiopsis.jpeg', 'sigatoka.jpeg'],
)

# Start serving the interface.
iface.launch()