classifier / app.py
djairbee5's picture
added error handling
4f172d5
raw
history blame contribute delete
No virus
4.26 kB
import html

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from torchvision import transforms
def load_class_names(file_path):
    """Read class labels from a text file, one label per line.

    The position of each label in the returned list is used elsewhere in this
    file as the class index, so line order is preserved and no line is
    dropped (blank lines become empty strings).

    Args:
        file_path: Path to the text file to read.

    Returns:
        list[str]: One entry per line, with surrounding whitespace stripped.
    """
    # Explicit encoding: the default is locale-dependent and can mis-decode
    # non-ASCII labels on some platforms.
    with open(file_path, 'r', encoding='utf-8') as f:
        # Iterate the file object directly instead of materialising
        # readlines() first.
        return [line.strip() for line in f]
def load_class_info(file_path):
    """Read per-class info links from a text file, one link per line.

    Entries are looked up by predicted class index elsewhere in this file, so
    line order is preserved and no line is dropped (blank lines become empty
    strings).

    Args:
        file_path: Path to the text file to read.

    Returns:
        list[str]: One entry per line, with surrounding whitespace stripped.
    """
    # Explicit encoding: the default is locale-dependent and can mis-decode
    # non-ASCII characters in links on some platforms.
    with open(file_path, 'r', encoding='utf-8') as f:
        # Iterate the file object directly instead of materialising
        # readlines() first.
        return [line.strip() for line in f]
def load_model(model_path, model_type):
    """Build a classifier of the requested architecture and load its weights.

    The architecture's final classification layer is replaced with a linear
    layer sized to the module-level ``num_classes`` before the saved state
    dictionary is applied.

    Args:
        model_path: Path to the saved state dictionary.
        model_type: One of 'mobilenetv2', 'resnet18' or 'densenet121'.

    Returns:
        The model with weights loaded (mapped to CPU), in evaluation mode.

    Raises:
        ValueError: If ``model_type`` is not a supported architecture.
    """
    # Build the architecture first so an unsupported type fails fast, before
    # any file is read. No weights argument is passed: the torchvision default
    # is an untrained network, which is what the deprecated
    # ``pretrained=False`` keyword requested.
    if model_type == 'mobilenetv2':
        model = models.mobilenet_v2()
        model.classifier[1] = nn.Linear(model.last_channel, num_classes)
    elif model_type == 'resnet18':
        model = models.resnet18()
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_type == 'densenet121':
        model = models.densenet121()
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(model_state_dict)

    # Evaluation mode: disables dropout and uses running batch-norm statistics.
    model.eval()
    return model
# Preprocessing applied to every image before inference: resize to 224x224,
# convert to a tensor, then normalise each of the three channels. The mean/std
# values match the commonly used ImageNet channel statistics.
val_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def predict_image(image, model_choice):
    """Classify an uploaded image with the selected model.

    Args:
        image: The uploaded image as a NumPy array or a PIL image, or ``None``
            when nothing was uploaded.
        model_choice: Key into the module-level ``model_paths`` mapping.

    Returns:
        tuple[str, str]: ``(label_text, html)``. ``label_text`` is an error
        message, ``"Not identified"`` when the top confidence is below 30%,
        or ``"<class name>: <confidence>%"``. ``html`` is a "More Info" link
        for the predicted class, or an empty string when there is no
        prediction or no link for that class.
    """
    global model, current_model

    # Validate the inputs before doing any work.
    if model_choice not in model_paths:
        return "Error: Please select a valid model.", ""
    if image is None:
        return "Error: Please upload an image.", ""

    # Reload from disk only when the selection differs from the cached model.
    if model_choice != current_model:
        model = load_model(model_paths[model_choice], model_choice)
        current_model = model_choice

    if isinstance(image, np.ndarray):
        # Let Pillow infer the mode from the array shape instead of forcing
        # 'RGB', which fails for grayscale or RGBA arrays.
        image = Image.fromarray(image.astype('uint8'))
    # The normalisation step expects exactly three channels, so convert any
    # other mode (RGBA, L, P, ...) to RGB.
    image = image.convert('RGB')

    batch = val_transform(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        outputs = model(batch)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        confidence, preds = torch.max(probabilities, 0)
    confidence_score = confidence.item() * 100

    # Below this confidence the prediction is reported as unknown.
    if confidence_score < 30:
        return "Not identified", ""

    class_index = preds.item()
    result = f"{class_names[class_index]}: {confidence_score:.2f}%"

    # Tolerate an info file with fewer lines than there are classes, or a
    # blank line: return the label without a link rather than failing.
    wiki_link = class_info[class_index] if class_index < len(class_info) else ""
    if not wiki_link:
        return result, ""

    # Escape the URL so characters such as ' or & cannot break out of the
    # single-quoted href attribute.
    safe_link = html.escape(wiki_link, quote=True)
    html_result = f"<h1><br><a href='{safe_link}' target='_blank'>More Info</a></h1>"
    return result, html_result
# Load the class names and the per-class info links. predict_image indexes
# both lists with the same predicted class index, so the two files are
# expected to list classes in the same order — TODO confirm they are kept in
# sync.
class_file_path = 'classes.txt'
class_info_path = 'classinfo.txt'
class_names = load_class_names(class_file_path)
class_info = load_class_info(class_info_path)
# Read by load_model to size the replacement output layer of each architecture.
num_classes = len(class_names)
# Maps each selectable model name to the file holding its saved weights.
model_paths = {
'densenet121': 'densenet121_15EpochsPretrainedNoExtractionNoLR_model.pkl',
'resnet18': 'resnet18_25EpochsPretrainedExtractionNoLR_model.pkl',
'mobilenetv2': 'mobilenetv2_25EpochsPretrainedExtractionNoLR_model.pkl'
}
# Load the default model at import time. predict_image rebinds `model` and
# `current_model` when a different model is selected.
current_model = 'densenet121'
model = load_model(model_paths[current_model], current_model)
# Build the Gradio UI: an image upload and a model selector feed
# predict_image, whose two return values fill the label and HTML outputs.
image_input = gr.Image(height=500)
model_selector = gr.Dropdown(
    choices=["densenet121", "resnet18", "mobilenetv2"],
    value="densenet121",
    label="Select Model",
)
iface = gr.Interface(
    fn=predict_image,
    inputs=[image_input, model_selector],
    outputs=[gr.Label(num_top_classes=1), gr.HTML()],
    title="Animal Classification",
    description="Upload an image to get the predicted label",
    allow_flagging="never",
)
# Start serving the interface.
iface.launch()