# Spaces: Paused  (presumably hosting-page status text captured with this file; not part of the program)
import json
import warnings
from typing import List

import gradio as gr
import pandas as pd
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import models
# Load the label mapping; process_input looks up each predicted class name in
# this dict to obtain the text shown in the "Prediction" column.
with open('cat_to_name.json', 'r', encoding='utf-8') as f:
    cat_to_name = json.load(f)
# Helpers for adapting a pretrained model's final layer.
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of ``model`` when ``feature_extracting`` is truthy.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        feature_extracting: When truthy, every parameter gets
            ``requires_grad = False``. When falsy, the model is left untouched
            (parameters are never re-enabled here).

    Returns:
        None. The model is modified in place.
    """
    if not feature_extracting:
        return
    for param in model.parameters():
        param.requires_grad = False
def update_last_layer_pretrained_model(pretrained_model, num_classes, feature_extract):
    """Swap the classification head of a torchvision model for ``num_classes`` outputs.

    The existing parameters are optionally frozen first (via
    ``set_parameter_requires_grad``); the freshly created head keeps the
    default ``requires_grad=True``. The architecture family is detected from
    the lower-cased class name together with the attribute that holds its head.

    Args:
        pretrained_model: Model instance to modify in place.
        num_classes: Number of outputs the new head must produce.
        feature_extract: When truthy, freeze all pre-existing parameters.

    Returns:
        The same ``pretrained_model`` object. If no known family matches, a
        ``UserWarning`` is emitted and the model is returned with its head
        unchanged.
    """
    set_parameter_requires_grad(pretrained_model, feature_extract)
    arch = pretrained_model.__class__.__name__.lower()

    if hasattr(pretrained_model, 'fc') and 'resnet' in arch:  # resnet
        num_ftrs = pretrained_model.fc.in_features
        pretrained_model.fc = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'classifier') and ('alexnet' in arch or 'vgg' in arch):  # alexnet, vgg
        num_ftrs = pretrained_model.classifier[6].in_features
        pretrained_model.classifier[6] = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'classifier') and 'squeezenet' in arch:  # squeezenet
        # Read the channel count from the conv being replaced instead of
        # hard-coding it.
        in_channels = pretrained_model.classifier[1].in_channels
        pretrained_model.classifier[1] = nn.Conv2d(in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1))
        pretrained_model.num_classes = num_classes
    elif hasattr(pretrained_model, 'classifier') and ('efficientnet' in arch or 'mobilenet' in arch):  # efficientnet, mobilenet
        # Replace the LAST classifier entry: the final linear layer is not at
        # index 1 for every variant that matches this branch (e.g. MobileNetV3).
        num_ftrs = pretrained_model.classifier[-1].in_features
        pretrained_model.classifier[-1] = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'AuxLogits') and 'inception' in arch:  # inception
        # The auxiliary head may be absent (None) when the model was built
        # without auxiliary logits; only replace it when it exists.
        if pretrained_model.AuxLogits is not None:
            num_ftrs = pretrained_model.AuxLogits.fc.in_features
            pretrained_model.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)  # auxiliary net
        num_ftrs = pretrained_model.fc.in_features
        pretrained_model.fc = nn.Linear(num_ftrs, num_classes)  # primary net
    elif hasattr(pretrained_model, 'classifier') and 'densenet' in arch:  # densenet
        num_ftrs = pretrained_model.classifier.in_features
        pretrained_model.classifier = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'heads') and 'visiontransformer' in arch:  # vision transformer
        num_ftrs = pretrained_model.heads.head.in_features
        pretrained_model.heads.head = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'head') and 'swin' in arch:  # swin transformer
        num_ftrs = pretrained_model.head.in_features
        pretrained_model.head = nn.Linear(num_ftrs, num_classes, bias=True)
    else:
        # Previously this case passed silently, leaving a head of the wrong size.
        warnings.warn(
            f"update_last_layer_pretrained_model: unsupported architecture "
            f"'{pretrained_model.__class__.__name__}'; classification head left unchanged",
            stacklevel=2,
        )
    return pretrained_model
def pred_image(model, image_path, class_names=None, transform=None,
               device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
               top_k: int = 3):
    """Classify one image and return its ``top_k`` most probable classes.

    Args:
        model: Classifier called on a batch of one image; its output is
            soft-maxed over ``dim=1``.
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        class_names: Optional lookup from class index to class name. When
            falsy, raw class indices are returned instead.
        transform: Callable turning the PIL image into a ``torch.Tensor``.
            Effectively required: a PIL image cannot be batched directly.
        device: Device on which the model and the image are evaluated.
        top_k: How many predictions to return (default 3, as before). Clamped
            to the number of classes the model outputs.

    Returns:
        Tuple ``(probs, idxs)``: ``probs`` is a NumPy array with the top
        probabilities in descending order; ``idxs`` is a list of class names
        (when ``class_names`` is given) or a NumPy array of class indices.

    Raises:
        TypeError: If the image is not a ``torch.Tensor`` after ``transform``.
    """
    # Convert to RGB so images with an alpha channel, a palette or a single
    # grayscale channel still match a 3-channel transform; the context manager
    # releases the file handle once the pixels are loaded.
    with Image.open(image_path) as raw_image:
        target_image = raw_image.convert('RGB')
    if transform:
        target_image = transform(target_image)
    if not isinstance(target_image, torch.Tensor):
        raise TypeError(
            "pred_image needs a `transform` that returns a torch.Tensor; "
            f"got {type(target_image).__name__}"
        )

    model.to(device)
    model.eval()
    with torch.inference_mode():
        batch = target_image.unsqueeze(dim=0).to(device)
        target_image_pred = model(batch)
        target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
        k = min(top_k, target_image_pred_probs.shape[1])
        top_probs, top_idxs = target_image_pred_probs.topk(k)

    # Both results must be moved to the CPU before .numpy(); the indices were
    # previously converted straight from the compute device, which fails on CUDA.
    ps_numpy = top_probs.cpu().numpy()[0]
    idx_numpy = top_idxs.cpu().numpy()[0]
    idxs = [class_names[i] for i in idx_numpy] if class_names else idx_numpy
    return (ps_numpy, idxs)
def process_input(image_path):
    """Run every configured flower classifier on one image and tabulate the results.

    For each ``(constructor name, weights enum name, checkpoint file)`` entry
    the torchvision model is built, its head is resized to 102 classes, the
    fine-tuned checkpoint is loaded, and the image is classified with
    ``pred_image``.

    Args:
        image_path: Path of the image to classify.

    Returns:
        ``pandas.DataFrame`` with one row per model and the columns
        ``"Prediction"`` (name of the top class, looked up in ``cat_to_name``)
        and ``"Probability"`` (its soft-max probability). The table is also
        printed to stdout.
    """
    list_of_models_and_weights = [
        ('efficientnet_b2', 'EfficientNet_B2_Weights', 'flowers_efficientnet_b2_model.pth'),
        ('alexnet', 'AlexNet_Weights', 'flowers_alexnet_model.pth'),
        # Further entries, currently disabled:
        # ('mobilenet_v2', 'MobileNet_V2_Weights', 'flowers_mobilenet_v2_model.pth'),
        # ('densenet121', 'DenseNet121_Weights', 'flowers_densenet121_model.pth'),
        # ('inception_v3', 'Inception_V3_Weights', 'flowers_inception_v3_model.pth'),
        # ('squeezenet1_1', 'SqueezeNet1_1_Weights', 'flowers_squeezenet1_1_model.pth'),
        # ('vgg16', 'VGG16_Weights', 'flowers_vgg16_model.pth'),
        # ('resnet18', 'ResNet18_Weights', 'flowers_resnet18_model.pth'),
        # ('swin_b', 'Swin_B_Weights', 'flowers_swin_b_model.pth'),
        # ('vit_b_16', 'ViT_B_16_Weights', 'flowers_vit_b_16_model.pth'),
    ]

    list_of_outputs = []
    for model_name, model_weights, model_path in list_of_models_and_weights:
        checkpoint = torch.load(model_path, map_location='cpu')

        # Plain attribute lookup instead of eval() on a formatted string.
        pretrained_weights = getattr(models, model_weights).DEFAULT
        auto_transforms = pretrained_weights.transforms()

        # `weights=` replaces the deprecated `pretrained=True`; these weights
        # only seed the architecture and are overwritten by the checkpoint's
        # state dict below.
        pretrained_model = getattr(models, model_name)(weights=pretrained_weights)
        pretrained_model = update_last_layer_pretrained_model(pretrained_model, 102, True)
        pretrained_model.class_to_idx = checkpoint['class_to_idx']
        pretrained_model.class_names = checkpoint['class_names']
        pretrained_model.load_state_dict(checkpoint['state_dict'])
        pretrained_model.to('cpu')

        probs, idxs = pred_image(
            model=pretrained_model,
            image_path=image_path,
            class_names=pretrained_model.class_names,
            transform=auto_transforms,
        )
        names = [cat_to_name[i] for i in idxs]
        list_of_outputs.append({"Prediction": names[0], "Probability": probs[0]})

    # Build the table once and reuse it for both the log line and the result.
    results = pd.DataFrame(list_of_outputs)
    print(results)
    return results
# Gradio UI: takes one image (passed to process_input as a file path) and
# shows the returned DataFrame.
examples = ['16_image_06670.jpg', '33_image_06460.jpg', '80_image_02020.jpg', 'Flowers.png', 'inference_example.png']
title = "Image Classifier - Species of Flower predicted by different Models"
description = "Image classifiers to recognize different species of flowers trained on 102 Category Flower Dataset"
article = "<p style='text-align: center'><a href='https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html' target='_blank'>Source 102 Flower Dataset</a></p>"
interpretation = 'default'
enable_queue = True

# NOTE(review): gr.inputs.Image, interpretation= and enable_queue= look like
# legacy Gradio API — confirm the pinned Gradio version before upgrading.
iface = gr.Interface(
    fn=process_input,
    inputs=gr.inputs.Image(type='filepath'),
    outputs="dataframe",
    examples=examples,
    title=title,
    description=description,
    article=article,
    interpretation=interpretation,
    enable_queue=enable_queue,
)
iface.launch()