# classifier / app.py
# djairbee5's picture
# more info displayed
# fda2cd7
# raw
# history blame
# 3.21 kB
# Import required packages
import html

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from torchvision import transforms
def load_class_names(file_path):
    """Read class labels from a text file, one label per line.

    Args:
        file_path: Path of the label file to read.

    Returns:
        A list of strings, one per line of the file in file order, with
        leading and trailing whitespace removed. Blank lines are kept as
        empty strings so list positions keep matching line positions.
    """
    # Explicit UTF-8 so labels containing non-ASCII characters decode the
    # same way on every platform instead of following the locale default.
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]
def load_class_info(file_path):
    """Read per-class info entries from a text file, one entry per line.

    predict_image() looks entries up by predicted class index and uses the
    entry as the target of its "More Info" link, so line order here has to
    follow the order of the class-name file.

    Args:
        file_path: Path of the info file to read.

    Returns:
        A list of strings, one per line of the file in file order, with
        leading and trailing whitespace removed. Blank lines are kept as
        empty strings so list positions keep matching line positions.
    """
    # Explicit UTF-8 so entries containing non-ASCII characters decode the
    # same way on every platform instead of following the locale default.
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]
def load_model(model_path, model_type='resnet', n_classes=None):
    """Build a torchvision classifier and load trained weights into it.

    Args:
        model_path: Path of the file holding the saved model state dictionary.
        model_type: Architecture to instantiate: 'mobilenet' (mobilenet_v2),
            'resnet' (resnet50) or 'densenet' (densenet121).
        n_classes: Number of outputs of the final linear layer. When None,
            the module-level ``num_classes`` is used.

    Returns:
        The model with the state dictionary loaded (mapped to the CPU) and
        switched to evaluation mode.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # Reject a bad model_type up front so the caller gets this error rather
    # than whatever reading the checkpoint might raise first.
    if model_type not in ('mobilenet', 'resnet', 'densenet'):
        raise ValueError(f"Unsupported model type: {model_type}")

    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    if n_classes is None:
        # Fall back to the module-level class count.
        n_classes = num_classes

    # pretrained=False builds the bare architecture; the weights come from
    # the state dictionary loaded above. In each branch the final linear
    # layer is replaced so its output size equals n_classes.
    # NOTE(review): newer torchvision releases prefer weights=None over
    # pretrained=False - confirm against the pinned torchvision version.
    if model_type == 'mobilenet':
        model = models.mobilenet_v2(pretrained=False)
        model.classifier[1] = nn.Linear(model.last_channel, n_classes)
    elif model_type == 'resnet':
        model = models.resnet50(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, n_classes)
    else:  # 'densenet'
        model = models.densenet121(pretrained=False)
        model.classifier = nn.Linear(model.classifier.in_features, n_classes)

    model.load_state_dict(model_state_dict)
    # Evaluation mode for inference.
    model.eval()
    return model
# Preprocessing applied to every image before inference: resize to 224x224,
# convert to a tensor, then normalize each channel with the mean/std below
# (the ImageNet statistics used by torchvision's pretrained models).
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict_image(image):
    """Classify one image and build the label text and "More Info" link.

    Relies on the module-level ``model``, ``val_transform``, ``class_names``
    and ``class_info``.

    Args:
        image: The value delivered by the Gradio image input. NumPy arrays
            are converted to PIL images; anything else is handed to
            ``val_transform`` as received (presumably a PIL image - confirm
            against the input component's configuration). ``None`` means no
            image was supplied.

    Returns:
        A ``(result, html_result)`` tuple. ``result`` is
        ``"<class name>: <confidence>%"`` with two decimals; ``html_result``
        is an HTML snippet linking to the class's info entry, or an empty
        string when there is no entry for the predicted class.
    """
    # Submitting without an image would otherwise crash inside the transform.
    if image is None:
        return "No image provided", ""

    if isinstance(image, np.ndarray):
        # Let Pillow infer the mode from the array shape instead of forcing
        # 'RGB', which misreads grayscale and RGBA arrays.
        image = Image.fromarray(image.astype('uint8'))
    if isinstance(image, Image.Image):
        # The normalization in val_transform is defined for three channels.
        image = image.convert('RGB')

    batch = val_transform(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        outputs = model(batch)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        confidence, preds = torch.max(probabilities, 0)

    index = preds.item()
    class_name = class_names[index]
    confidence_score = confidence.item() * 100
    result = f"{class_name}: {confidence_score:.2f}%"

    # The info list may be shorter than the class list; treat a missing or
    # empty entry as "no link" rather than raising or emitting a dead link.
    wiki_link = class_info[index] if index < len(class_info) else ""
    if not wiki_link:
        return result, ""

    # Escape the link so quotes or '&' in it cannot break out of the
    # single-quoted href attribute.
    safe_link = html.escape(wiki_link, quote=True)
    html_result = f"<h1><br><a href='{safe_link}' target='_blank'>More Info</a></h1>"
    return result, html_result
# Artifacts expected next to this script.
class_file_path = 'classes.txt'
class_info_path = 'classinfo.txt'
model_path = 'resnet30EpochsPretrainedNFeatureX_model.pkl'

# Class labels first: num_classes must exist before load_model() runs,
# because the model's output layer is sized from it.
class_names = load_class_names(class_file_path)
num_classes = len(class_names)

# Per-class info entries, looked up by predicted class index.
class_info = load_class_info(class_info_path)

# Trained model (default architecture of load_model).
model = load_model(model_path)
# Gradio UI: a single image input feeds predict_image, whose two return
# values are rendered as a label and an HTML snippet, in that order.
iface = gr.Interface(
    title="Image Classification",
    description="Upload an image to get the predicted label",
    fn=predict_image,
    inputs=gr.Image(),
    outputs=[
        gr.Label(num_top_classes=1),
        gr.HTML(),
    ],
    allow_flagging="never",
)

# Start serving the interface.
iface.launch()