import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
import torchvision.models as models
from PIL import Image
import numpy as np
# Function to load class names from a file
def load_class_names(file_path):
    with open(file_path, 'r') as f:
        class_names = [line.strip() for line in f.readlines()]
    return class_names
# Function to load Wikipedia links from a file
def load_class_info(file_path):
    with open(file_path, 'r') as f:
        class_info = [line.strip() for line in f.readlines()]
    return class_info
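# Expected input file layout (an assumption based on how the two lists are
# indexed in predict_image below): one class name per line in classes.txt and
# one Wikipedia URL per line in classinfo.txt, with line i of classinfo.txt
# describing line i of classes.txt, e.g.
#   classes.txt:   cat
#   classinfo.txt: https://en.wikipedia.org/wiki/Cat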
# Function to load the model from a .pkl file
def load_model(model_path, model_type):
    # Load the model state dictionary
    model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    # Create an instance of the model based on model_type
    # (num_classes is defined at module level below, before load_model is first called)
    if model_type == 'mobilenetv2':
        model = models.mobilenet_v2(pretrained=False)
        model.classifier[1] = nn.Linear(model.last_channel, num_classes)
    elif model_type == 'resnet18':
        model = models.resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_type == 'densenet121':
        model = models.densenet121(pretrained=False)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    # Load the state dictionary into the model
    model.load_state_dict(model_state_dict)

    # Set the model to evaluation mode
    model.eval()
    return model
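# Note: each model path is assumed to point to a plain state dict saved with
# torch.save(model.state_dict(), path); load_state_dict() above expects a dict,
# so fully pickled model objects would not work here.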
# Define the transformation
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
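# The transform above turns an RGB PIL image into a normalized 3x224x224 float
# tensor using the standard ImageNet mean/std, matching the preprocessing the
# torchvision backbones were trained with.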
# Define the prediction function
def predict_image(image, model_choice):
    global model, current_model

    # Check if a model is selected
    if model_choice not in model_paths:
        return "Error: Please select a valid model.", ""

    # Check if an image is provided
    if image is None:
        return "Error: Please upload an image.", ""

    # Load the selected model if it's not already loaded
    if model_choice != current_model:
        model_path = model_paths[model_choice]
        model = load_model(model_path, model_choice)
        current_model = model_choice

    # Convert the NumPy array to a PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'), 'RGB')

    image = val_transform(image)
    image = image.unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        outputs = model(image)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        confidence, preds = torch.max(probabilities, 0)

    confidence_score = confidence.item() * 100
    if confidence_score < 30:
        result = "Not identified"
        html_result = ""
    else:
        class_name = class_names[preds.item()]
        wiki_link = class_info[preds.item()]
        result = f"{class_name}: {confidence_score:.2f}%"
        html_result = f"<h1><br><a href='{wiki_link}' target='_blank'>More Info</a></h1>"

    return result, html_result
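# Example of calling the function directly, outside Gradio (hypothetical file name):
#   img = np.array(Image.open('example.jpg').convert('RGB'))
#   label, link_html = predict_image(img, 'densenet121')
# label is either "Not identified" or "<class name>: <confidence>%", and
# link_html is an HTML snippet linking to the matching Wikipedia page.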
# Load class names and class info
class_file_path = 'classes.txt'
class_info_path = 'classinfo.txt'
class_names = load_class_names(class_file_path)
class_info = load_class_info(class_info_path)
num_classes = len(class_names)
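# Sanity check (assumes the two files are line-aligned, as the indexing in
# predict_image requires)
assert len(class_info) == num_classes, "classes.txt and classinfo.txt must have the same number of lines"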
# Define model paths
model_paths = {
    'densenet121': 'densenet121_15EpochsPretrainedNoExtractionNoLR_model.pkl',
    'resnet18': 'resnet18_25EpochsPretrainedExtractionNoLR_model.pkl',
    'mobilenetv2': 'mobilenetv2_25EpochsPretrainedExtractionNoLR_model.pkl'
}
# Set default model
current_model = 'densenet121'
model = load_model(model_paths[current_model], current_model)
# Create the Gradio interface
iface = gr.Interface(
    fn=predict_image,
    inputs=[
        gr.Image(height=500),
        gr.Dropdown(choices=["densenet121", "resnet18", "mobilenetv2"], value="densenet121", label="Select Model")
    ],
    outputs=[gr.Label(num_top_classes=1), gr.HTML()],
    title="Animal Classification",
    description="Upload an image to get the predicted label",
    allow_flagging="never",
)
# Launch the interface
iface.launch()
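# When running locally rather than on Spaces, iface.launch(share=True) can be
# used instead to expose a temporary public URL.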