# classifier/app.py
# Author: djairbee5 -- last commit 883f925 ("fix"), 3.31 kB
import html

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from torchvision import transforms
def load_class_names(file_path):
    """Read class names from a text file, one name per line.

    Surrounding whitespace (including the newline) is stripped from each
    line. File order is preserved, because the prediction code looks names
    up by the model's output index.

    Args:
        file_path: Path to the class-name file, decoded as UTF-8.

    Returns:
        A list of class-name strings in file order.
    """
    # Explicit UTF-8 so non-ASCII class names decode the same on every
    # platform, instead of depending on the locale's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]
def load_class_info(file_path):
    """Read per-class info entries (Wikipedia links) from a text file.

    There is one entry per line, with surrounding whitespace stripped. Line
    i belongs to class i, so this file must stay line-aligned with the
    class-name file.

    Args:
        file_path: Path to the class-info file, decoded as UTF-8.

    Returns:
        A list of info strings in file order.
    """
    # Explicit UTF-8 so non-ASCII characters in links decode the same on
    # every platform, instead of depending on the locale's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]
def load_model(model_path, model_type='resnet', num_classes=None):
    """Build a torchvision classifier and load trained weights into it.

    Args:
        model_path: Path to a saved ``state_dict`` for the chosen
            architecture. Its tensors are mapped onto the CPU.
        model_type: Backbone to instantiate: ``'mobilenet'`` (MobileNetV2),
            ``'resnet'`` (ResNet-50, the default) or ``'densenet'``
            (DenseNet-121).
        num_classes: Output size of the replacement classification head.
            When ``None`` (the default), it is read from the checkpoint's own
            head weight. The function therefore no longer depends on a
            module-level ``num_classes`` global being defined before the call.

    Returns:
        The model with the checkpoint weights loaded, in evaluation mode.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # State-dict key of each architecture's final Linear weight. Linear
    # weights are (out_features, in_features), so dim 0 is the class count.
    head_weight_keys = {
        'mobilenet': 'classifier.1.weight',
        'resnet': 'fc.weight',
        'densenet': 'classifier.weight',
    }
    # Reject unknown architectures before reading a potentially large file.
    if model_type not in head_weight_keys:
        raise ValueError(f"Unsupported model type: {model_type}")

    # NOTE: torch.load unpickles arbitrary Python objects -- only point this
    # at trusted checkpoint files.
    model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    if num_classes is None:
        num_classes = model_state_dict[head_weight_keys[model_type]].shape[0]

    # Start from an untrained backbone (the checkpoint supplies the weights)
    # and swap in a head with num_classes outputs so the shapes line up.
    if model_type == 'mobilenet':
        model = models.mobilenet_v2(pretrained=False)
        model.classifier[1] = nn.Linear(model.last_channel, num_classes)
    elif model_type == 'resnet':
        model = models.resnet50(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:  # 'densenet'
        model = models.densenet121(pretrained=False)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)

    model.load_state_dict(model_state_dict)
    # Switch layers such as BatchNorm/Dropout to inference behavior.
    model.eval()
    return model
# Preprocessing applied to every input image before inference:
# 1. resize to a fixed 224x224;
# 2. convert to a float tensor;
# 3. normalize each RGB channel with the standard ImageNet mean/std values
#    (presumably matching the training pipeline).
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Minimum top-1 softmax probability, in percent, required to report a class.
CONFIDENCE_THRESHOLD = 30


def predict_image(image):
    """Classify one image and build the two Gradio outputs.

    Args:
        image: The value from the ``gr.Image`` input. It is a NumPy array or
            a PIL image, or ``None`` when the user submits without uploading.

    Returns:
        A ``(label, html)`` tuple.

        ``label`` is ``"<class>: <pct>%"`` for the top-1 class. It is
        ``"No image provided"`` when no image was given, and
        ``"Not identified"`` when the top-1 confidence is below
        ``CONFIDENCE_THRESHOLD``.

        ``html`` is a "More Info" link built from the matching ``class_info``
        entry, or ``""`` when no class is reported.
    """
    if image is None:
        return "No image provided", ""

    if isinstance(image, np.ndarray):
        # Let PIL infer the mode from the array shape instead of forcing 'RGB'.
        image = Image.fromarray(image.astype('uint8'))
    # Normalize to 3 channels so grayscale/RGBA inputs match the transform.
    image = image.convert('RGB')

    batch = val_transform(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        outputs = model(batch)
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
    confidence, pred = torch.max(probabilities, 0)
    confidence_score = confidence.item() * 100

    if confidence_score < CONFIDENCE_THRESHOLD:
        return "Not identified", ""

    idx = pred.item()
    class_name = class_names[idx]
    # The link comes from a data file; escape it before putting it in markup.
    wiki_link = html.escape(class_info[idx], quote=True)
    result = f"{class_name}: {confidence_score:.2f}%"
    html_result = (
        f"<h1><br><a href='{wiki_link}' target='_blank' "
        "rel='noopener noreferrer'>More Info</a></h1>"
    )
    return result, html_result
# Artifacts loaded at startup: the trained weights plus two line-aligned
# text files. Line i of each file describes model output i.
model_path, class_file_path, class_info_path = (
    'resnet30EpochsPretrainedNFeatureX_model.pkl',
    'classes.txt',
    'classinfo.txt',
)

class_names = load_class_names(class_file_path)
class_info = load_class_info(class_info_path)
num_classes = len(class_names)  # one output class per line of classes.txt
model = load_model(model_path)
# Web UI: one image input and two outputs -- the top-1 label text and an
# HTML "More Info" link -- matching the tuple returned by predict_image.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(),
    outputs=[gr.Label(num_top_classes=1), gr.HTML()],
    title="Image Classification",
    description="Upload an image to get the predicted label",
    allow_flagging="never",
)

# Start the server only when run as a script (`python app.py`), so importing
# this module from tests or tooling does not start and block on launch().
if __name__ == "__main__":
    iface.launch()