Spaces:
Sleeping
Sleeping
import html

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from torchvision import transforms
def load_class_names(file_path):
    """Read one class name per line from a text file.

    Line ``i`` (0-based) is the name for model output index ``i``, so lines are
    kept in order and interior blank lines are preserved to keep that mapping.

    Args:
        file_path: Path to a UTF-8 text file; a leading byte-order mark is
            tolerated.

    Returns:
        List of class names with surrounding whitespace stripped. Blank lines
        at the very end of the file are dropped.
    """
    # 'utf-8-sig' decodes exactly like UTF-8 but discards a leading BOM, which
    # str.strip() would otherwise leave glued to the first name. Being explicit
    # also stops the result from depending on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        class_names = [line.strip() for line in f]
    # Trim trailing blank lines only: they cannot be real classes, while
    # removing interior ones would shift the index -> class mapping.
    while class_names and not class_names[-1]:
        class_names.pop()
    return class_names
def load_class_info(file_path):
    """Read one "More Info" link (e.g. a Wikipedia URL) per line from a file.

    Line ``i`` (0-based) is the link for model output index ``i``, so lines are
    kept in order and interior blank lines are preserved to keep that mapping.

    Args:
        file_path: Path to a UTF-8 text file; a leading byte-order mark is
            tolerated.

    Returns:
        List of links with surrounding whitespace stripped. Blank lines at the
        very end of the file are dropped.
    """
    # 'utf-8-sig' discards a leading BOM that would otherwise corrupt the first
    # URL; explicit encoding avoids depending on the platform locale.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        class_info = [line.strip() for line in f]
    # Trim trailing blank lines only, so the index -> link mapping is unchanged.
    while class_info and not class_info[-1]:
        class_info.pop()
    return class_info
def load_model(model_path, model_type='resnet', num_classes=None):
    """Build a torchvision classifier and load trained weights into it.

    Args:
        model_path: Path to a checkpoint written with ``torch.save`` that holds
            a state dict for the selected architecture.
        model_type: ``'mobilenet'`` (MobileNetV2), ``'resnet'`` (ResNet-50) or
            ``'densenet'`` (DenseNet-121).
        num_classes: Output size of the replacement classification head. When
            omitted, a module-level ``num_classes`` is used if one is defined
            (the original behaviour); otherwise the value is read from the
            shape of the head's weight tensor in the checkpoint.

    Returns:
        The model with the checkpoint's weights loaded, on the CPU, in
        evaluation mode.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # State-dict key of the final Linear layer's weight for each architecture;
    # each matches the attribute that is replaced further down.
    head_weight_keys = {
        'mobilenet': 'classifier.1.weight',
        'resnet': 'fc.weight',
        'densenet': 'classifier.weight',
    }
    # Validate before reading the (potentially large) checkpoint from disk.
    if model_type not in head_weight_keys:
        raise ValueError(f"Unsupported model type: {model_type}")

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from a trusted source.
    model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    if num_classes is None:
        # Backward compatibility: callers used to rely on a module-level
        # ``num_classes`` global being defined before this call.
        num_classes = globals().get('num_classes')
    if num_classes is None:
        # nn.Linear stores its weight as (out_features, in_features).
        num_classes = model_state_dict[head_weight_keys[model_type]].shape[0]

    # Build an untrained backbone and swap in a head with the right output size.
    if model_type == 'mobilenet':
        model = models.mobilenet_v2(pretrained=False)
        model.classifier[1] = nn.Linear(model.last_channel, num_classes)
    elif model_type == 'resnet':
        model = models.resnet50(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:  # 'densenet' (the only remaining validated option)
        model = models.densenet121(pretrained=False)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)

    model.load_state_dict(model_state_dict)
    # Switch layers such as dropout and batch norm to inference behaviour.
    model.eval()
    return model
# Preprocessing applied to every image before it is fed to the model.
# NOTE(review): presumably mirrors the validation transform used at training
# time; changing any step would skew predictions.
val_transform = transforms.Compose(
    [
        # Fixed (height, width) output; aspect ratio is not preserved.
        transforms.Resize((224, 224)),
        # PIL image -> float tensor (C, H, W) scaled to [0, 1].
        transforms.ToTensor(),
        # Per-channel (R, G, B) standardisation with the ImageNet statistics.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict_image(image):
    """Classify one image and build the two Gradio outputs for it.

    Relies on the module-level ``model``, ``class_names``, ``class_info`` and
    ``val_transform`` being initialised before the first call.

    Args:
        image: The uploaded image, as a NumPy array (what ``gr.Image()``
            delivers by default) or a PIL image. ``None`` when the user submits
            without uploading anything.

    Returns:
        A ``(label_text, html)`` tuple:

        * ``"<class>: <confidence>%"`` plus a "More Info" link built from the
          matching line of ``class_info`` when the top softmax probability is
          at least 30%;
        * ``("Not identified", "")`` when it is below 30%;
        * ``("No image provided", "")`` when ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None when Submit is pressed with an empty input.
        return "No image provided", ""

    if isinstance(image, np.ndarray):
        # Let Pillow infer the mode from the array shape instead of forcing
        # 'RGB' (the explicit mode argument is deprecated and mis-reads
        # grayscale or RGBA arrays).
        image = Image.fromarray(image.astype('uint8'))
    # Force three channels: grayscale/RGBA input would otherwise produce a
    # tensor that the 3-channel Normalize step rejects.
    image = image.convert('RGB')

    batch = val_transform(image).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():
        outputs = model(batch)
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
    confidence, preds = torch.max(probabilities, 0)
    confidence_score = confidence.item() * 100

    if confidence_score < 30:
        return "Not identified", ""

    index = preds.item()
    class_name = class_names[index]
    wiki_link = class_info[index]
    result = f"{class_name}: {confidence_score:.2f}%"
    # Escape the link: an apostrophe in a URL (e.g. ".../Cooper's_hawk") would
    # otherwise terminate the single-quoted href attribute.
    html_result = (
        f"<h1><br><a href='{html.escape(wiki_link)}' "
        "target='_blank'>More Info</a></h1>"
    )
    return result, html_result
# ---------------------------------------------------------------------------
# App start-up: load labels and weights, then build and launch the UI.
# NOTE(review): these paths are relative, so they resolve against the current
# working directory rather than this file's directory.
# ---------------------------------------------------------------------------
model_path = 'resnet30EpochsPretrainedNFeatureX_model.pkl'
class_file_path = 'classes.txt'
class_info_path = 'classinfo.txt'

# Entry i of each list describes model output index i (see predict_image).
class_names = load_class_names(class_file_path)
class_info = load_class_info(class_info_path)

# Must exist before load_model() runs: it reads this global to size the
# classification head.
num_classes = len(class_names)
model = load_model(model_path)  # default architecture: 'resnet'

# One image in; one output component per value returned by predict_image
# (label text, then the "More Info" HTML link).
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(),
    outputs=[
        gr.Label(num_top_classes=1),
        gr.HTML(),
    ],
    title="Image Classification",
    description="Upload an image to get the predicted label",
    allow_flagging="never",
)

# Start the Gradio web server.
iface.launch()