File size: 4,255 Bytes
e8a6c24
fa48c72
 
 
02654db
fa48c72
c152ea8
e8a6c24
fa48c72
 
 
 
 
139ad76
fda2cd7
 
 
 
 
 
fa48c72
83b0bb3
02654db
 
b73a2e9
 
59b0a85
b73a2e9
 
59b0a85
 
b73a2e9
59b0a85
b73a2e9
 
 
 
 
02654db
 
 
fa48c72
 
 
 
 
 
 
 
 
 
 
83b0bb3
 
 
bb8f6a5
 
 
4f172d5
 
 
bb8f6a5
83b0bb3
 
 
 
 
bb8f6a5
c152ea8
 
bb8f6a5
fa48c72
 
bb8f6a5
fa48c72
 
5bbab1c
 
bb8f6a5
5bbab1c
bb8f6a5
883f925
2c24524
139ad76
 
 
 
 
 
bb8f6a5
fda2cd7
fa48c72
83b0bb3
fa48c72
fda2cd7
fa48c72
fda2cd7
fa48c72
83b0bb3
 
 
 
 
 
 
 
 
 
 
fa48c72
 
 
bb8f6a5
 
 
 
 
139ad76
bb8f6a5
066233a
 
fa48c72
139ad76
fa48c72
2a995e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
import torchvision.models as models
from PIL import Image
import numpy as np

# Function to load class names from a file
def load_class_names(file_path):
    """Read class labels from a text file, one label per line.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list with one entry per line of the file, in file order, with
        leading and trailing whitespace removed from each entry.  Blank
        lines are kept as empty strings, so list positions keep matching
        line positions in the file.
    """
    # Decode explicitly as UTF-8 so labels containing non-ASCII characters
    # load the same way regardless of the platform's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]

# Function to load Wikipedia links from a file
def load_class_info(file_path):
    """Read per-class info entries (links) from a text file, one per line.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list with one entry per line of the file, in file order, with
        leading and trailing whitespace removed from each entry.  Blank
        lines are kept as empty strings so that the position of every
        entry still matches its line position in the file.
    """
    # Decode explicitly as UTF-8 so entries containing non-ASCII characters
    # load the same way regardless of the platform's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]

# Function to load the model from a .pkl file
def load_model(model_path, model_type):
    """Build a classifier of the requested architecture and load its weights.

    The size of the final linear layer is taken from the module-level
    ``num_classes``, which must be defined before this function is called.

    Args:
        model_path: Path of the file holding the saved state dictionary.
        model_type: One of 'mobilenetv2', 'resnet18' or 'densenet121'.

    Returns:
        The model with the state dictionary loaded, switched to
        evaluation mode.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # Build the architecture first: an unsupported model_type is then
    # rejected before any time is spent reading the checkpoint from disk.
    # NOTE: newer torchvision releases deprecate `pretrained=` in favour of
    # `weights=`; the original keyword is kept for compatibility.
    if model_type == 'mobilenetv2':
        model = models.mobilenet_v2(pretrained=False)
        # Replace the final classifier layer so it outputs num_classes scores.
        model.classifier[1] = nn.Linear(model.last_channel, num_classes)
    elif model_type == 'resnet18':
        model = models.resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_type == 'densenet121':
        model = models.densenet121(pretrained=False)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    # NOTE(security): torch.load unpickles the file, so only checkpoints
    # from a trusted source should be loaded here.  Passing
    # weights_only=True would restrict what can be unpickled -- TODO confirm
    # the minimum supported torch version before enabling it.
    # map_location keeps all tensors on the CPU.
    model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    # Load the state dictionary into the model
    model.load_state_dict(model_state_dict)
    # Set the model to evaluation mode
    model.eval()
    return model

# Preprocessing applied to every image before inference: resize to
# 224x224, convert to a tensor, then normalize each channel with the
# commonly used ImageNet mean / standard deviation values.
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Define the prediction function
def predict_image(image, model_choice):
    """Classify an uploaded image with the selected model.

    Args:
        image: The uploaded picture as a NumPy array or PIL image, or None
            when nothing was uploaded.
        model_choice: Key into the module-level ``model_paths`` naming the
            architecture to use.

    Returns:
        A ``(label_text, html)`` pair.  ``label_text`` is an error message,
        ``"Not identified"`` when the top confidence is below 30%, or
        ``"<class name>: <confidence>%"``.  ``html`` is a "More Info" link
        for the predicted class, or an empty string when there is none.

    Side effects:
        Rebinds the module-level ``model`` and ``current_model`` when
        ``model_choice`` differs from the model currently loaded.
    """
    global model, current_model

    # Check if a model is selected
    if model_choice not in model_paths:
        return "Error: Please select a valid model.", ""
    # Check if an image is provided
    if image is None:
        return "Error: Please upload an image.", ""

    # Load the selected model if it's not already loaded
    if model_choice != current_model:
        model = load_model(model_paths[model_choice], model_choice)
        current_model = model_choice

    # Convert the NumPy array to a PIL Image if needed; the mode is inferred
    # from the array shape so grayscale and RGBA arrays are accepted too.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))
    # The normalization step uses three channel statistics, so force RGB.
    if isinstance(image, Image.Image) and image.mode != 'RGB':
        image = image.convert('RGB')

    batch = val_transform(image).unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        outputs = model(batch)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        confidence, preds = torch.max(probabilities, 0)

    confidence_score = confidence.item() * 100

    # Low-confidence predictions are reported as unidentified.
    if confidence_score < 30:
        return "Not identified", ""

    class_index = preds.item()
    result = f"{class_names[class_index]}: {confidence_score:.2f}%"

    # The info file may have fewer lines than the class list (or a blank
    # line for a class); show the label without a link in that case instead
    # of failing with an IndexError.
    wiki_link = class_info[class_index] if class_index < len(class_info) else ""
    if wiki_link:
        html_result = f"<h1><br><a href='{wiki_link}' target='_blank'>More Info</a></h1>"
    else:
        html_result = ""

    return result, html_result

# Checkpoint file for each selectable architecture
model_paths = {
    'densenet121': 'densenet121_15EpochsPretrainedNoExtractionNoLR_model.pkl',
    'resnet18': 'resnet18_25EpochsPretrainedExtractionNoLR_model.pkl',
    'mobilenetv2': 'mobilenetv2_25EpochsPretrainedExtractionNoLR_model.pkl',
}

# Class labels and the matching per-class info links, both one entry per line
class_file_path = 'classes.txt'
class_info_path = 'classinfo.txt'
class_names = load_class_names(class_file_path)
class_info = load_class_info(class_info_path)

# load_model reads this global to size the final layer, so it has to be
# set before the first model is loaded below.
num_classes = len(class_names)

# Model that is loaded at start-up and used until another one is selected
current_model = 'densenet121'
model = load_model(model_paths[current_model], current_model)

# Create the Gradio interface.
# The dropdown entries and its initial value are taken from model_paths and
# current_model, so the UI cannot drift from the set of checkpoints that
# predict_image is actually able to load.
iface = gr.Interface(
    fn=predict_image,
    inputs=[
        gr.Image(height=500),
        gr.Dropdown(choices=list(model_paths), value=current_model, label="Select Model"),
    ],
    # predict_image returns (label text, html), matched to these two outputs.
    outputs=[gr.Label(num_top_classes=1), gr.HTML()],
    title="Animal Classification",
    description="Upload an image to get the predicted label",
    allow_flagging="never",
)

# Launch the interface
iface.launch()