# which-flower / app.py
# (Hugging Face Space file header: author rdkulkarni, last commit "Update app.py", 5e5d09e)
import gradio as gr
import torch
import torchvision
from PIL import Image
from torchvision import models
from torch import nn
from typing import List
import json
import pandas as pd
# Load the category-id -> flower-name mapping used to label predictions.
with open('cat_to_name.json', 'r') as labels_file:
    cat_to_name = json.load(labels_file)
# Freeze the backbone when the model is used as a fixed feature extractor.
def set_parameter_requires_grad(model, feature_extracting):
    """Disable gradient tracking on every parameter of *model* when
    *feature_extracting* is True; leave the model untouched otherwise."""
    if not feature_extracting:
        return
    for parameter in model.parameters():
        parameter.requires_grad = False
def update_last_layer_pretrained_model(pretrained_model, num_classes, feature_extract):
    """Replace the classification head of *pretrained_model* with a fresh
    layer producing *num_classes* outputs.

    When *feature_extract* is True the backbone parameters are frozen first,
    so only the new head remains trainable. The architecture family is
    detected from the model's class name; unrecognized architectures are
    returned unchanged.
    """
    set_parameter_requires_grad(pretrained_model, feature_extract)
    arch = pretrained_model.__class__.__name__.lower()
    if hasattr(pretrained_model, 'fc') and 'resnet' in arch:  # resnet
        in_features = pretrained_model.fc.in_features
        pretrained_model.fc = nn.Linear(in_features, num_classes, bias=True)
    elif hasattr(pretrained_model, 'classifier') and ('alexnet' in arch or 'vgg' in arch):  # alexnet, vgg
        in_features = pretrained_model.classifier[6].in_features
        pretrained_model.classifier[6] = nn.Linear(in_features, num_classes, bias=True)
    elif hasattr(pretrained_model, 'classifier') and 'squeezenet' in arch:  # squeezenet
        # SqueezeNet classifies with a 1x1 conv rather than a Linear layer.
        pretrained_model.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        pretrained_model.num_classes = num_classes
    elif hasattr(pretrained_model, 'classifier') and ('efficientnet' in arch or 'mobilenet' in arch):  # efficientnet, mobilenet
        in_features = pretrained_model.classifier[1].in_features
        pretrained_model.classifier[1] = nn.Linear(in_features, num_classes, bias=True)
    elif hasattr(pretrained_model, 'AuxLogits') and 'inception' in arch:  # inception
        # Inception v3 has two heads; replace both.
        in_features = pretrained_model.AuxLogits.fc.in_features
        pretrained_model.AuxLogits.fc = nn.Linear(in_features, num_classes)  # auxiliary net
        in_features = pretrained_model.fc.in_features
        pretrained_model.fc = nn.Linear(in_features, num_classes)  # primary net
    elif hasattr(pretrained_model, 'classifier') and 'densenet' in arch:  # densenet
        in_features = pretrained_model.classifier.in_features
        pretrained_model.classifier = nn.Linear(in_features, num_classes, bias=True)
    elif hasattr(pretrained_model, 'heads') and 'visiontransformer' in arch:  # ViT transformer
        in_features = pretrained_model.heads.head.in_features
        pretrained_model.heads.head = nn.Linear(in_features, num_classes, bias=True)
    elif hasattr(pretrained_model, 'head') and 'swin' in arch:  # swin transformer
        in_features = pretrained_model.head.in_features
        pretrained_model.head = nn.Linear(in_features, num_classes, bias=True)
    return pretrained_model
#pred_image
def pred_image(model, image_path, class_names=None, transform=None,
               device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
               top_k: int = 3):
    """Predict the *top_k* most probable classes for the image at *image_path*.

    Args:
        model: classifier whose forward pass returns per-class logits.
        image_path: path to the image file to classify.
        class_names: optional sequence mapping class index -> label; when
            given, labels are returned instead of raw indices.
        transform: optional preprocessing callable applied to the PIL image
            (expected to produce a CHW tensor — the model is called on it).
        device: device to run inference on (defaults to CUDA when available).
        top_k: number of top predictions to return (default 3, matching the
            original hard-coded behavior).

    Returns:
        Tuple ``(probs, labels)``: numpy array of the top-k probabilities and
        the corresponding labels (or index array when *class_names* is None).
    """
    target_image = Image.open(image_path)
    if transform:
        target_image = transform(target_image)
    model.to(device)
    model.eval()
    with torch.inference_mode():
        # Add the batch dimension the model expects.
        target_image = target_image.unsqueeze(dim=0)
        logits = model(target_image.to(device))
        probabilities = torch.softmax(logits, dim=1)
        top = probabilities.topk(top_k)
    top_probs = top.values.cpu().numpy()[0]
    # BUG FIX: the indices tensor must also be moved to the CPU before
    # .numpy(); the original called ps[1].numpy() directly, which raises
    # "can't convert cuda tensor" whenever inference runs on a GPU.
    top_indices = top.indices.cpu().numpy()[0]
    labels = [class_names[i] for i in top_indices] if class_names else top_indices
    return (top_probs, labels)
def process_input(image_path):
    """Classify the image at *image_path* with every configured model.

    For each (model, weights, checkpoint) triple: rebuilds the architecture
    with a 102-class head, loads the fine-tuned checkpoint, predicts the top
    flower species, and collects one row per model.

    Returns:
        pandas.DataFrame with "Prediction" (flower name) and "Probability"
        columns, one row per model.
    """
    # (constructor name in torchvision.models, weights-enum name, checkpoint file)
    list_of_models_and_weights = [
        ('efficientnet_b2', 'EfficientNet_B2_Weights', 'flowers_efficientnet_b2_model.pth'),
        ('alexnet', 'AlexNet_Weights', 'flowers_alexnet_model.pth')
        # Disabled candidates — re-enable once their checkpoints are deployed:
        #('mobilenet_v2','MobileNet_V2_Weights','flowers_mobilenet_v2_model.pth'),
        #('densenet121','DenseNet121_Weights','flowers_densenet121_model.pth'),
        #('inception_v3','Inception_V3_Weights','flowers_inception_v3_model.pth'),
        #('squeezenet1_1','SqueezeNet1_1_Weights','flowers_squeezenet1_1_model.pth'),
        #('vgg16','VGG16_Weights','flowers_vgg16_model.pth'),
        #('resnet18','ResNet18_Weights','flowers_resnet18_model.pth'),
        #('swin_b','Swin_B_Weights','flowers_swin_b_model.pth'),
        #('vit_b_16', 'ViT_B_16_Weights','flowers_vit_b_16_model.pth')
    ]
    list_of_outputs = []
    for model_name, model_weights, model_path in list_of_models_and_weights:
        checkpoint = torch.load(model_path, map_location='cpu')
        # getattr() replaces eval(): same attribute lookup, without executing
        # arbitrary strings as Python code.
        pretrained_weights = getattr(models, model_weights).DEFAULT
        auto_transforms = pretrained_weights.transforms()
        # The `pretrained=True` keyword was removed in torchvision >= 0.15;
        # the `weights=` API is equivalent here since the checkpoint's
        # state_dict overwrites the downloaded weights below anyway.
        pretrained_model = getattr(models, model_name)(weights=pretrained_weights)
        pretrained_model = update_last_layer_pretrained_model(pretrained_model, 102, True)
        pretrained_model.class_to_idx = checkpoint['class_to_idx']
        pretrained_model.class_names = checkpoint['class_names']
        pretrained_model.load_state_dict(checkpoint['state_dict'])
        pretrained_model.to('cpu')
        # Top-3 predictions come back; only the best one is reported per model.
        probs, idxs = pred_image(model=pretrained_model, image_path=image_path,
                                 class_names=pretrained_model.class_names,
                                 transform=auto_transforms)
        names = [cat_to_name[i] for i in idxs]
        list_of_outputs.append({"Prediction": names[0], "Probability": probs[0]})
    return pd.DataFrame(list_of_outputs)
# --- Gradio UI wiring ------------------------------------------------------
examples = ['16_image_06670.jpg', '33_image_06460.jpg', '80_image_02020.jpg',
            'Flowers.png', 'inference_example.png']
title = "Image Classifier - Species of Flower predicted by different Models"
description = "Image classifiers to recognize different species of flowers trained on 102 Category Flower Dataset"
# BUG FIX: the original read `article = article="<p ...>"`, which is a
# SyntaxError (keyword-argument syntax outside a call) and prevented the
# whole module from importing.
article = "<p style='text-align: center'><a href='https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html' target='_blank'>Source 102 Flower Dataset</a></p>"
interpretation = 'default'
enable_queue = True
iface = gr.Interface(fn=process_input,
                     inputs=gr.inputs.Image(type='filepath'),
                     outputs="dataframe",
                     examples=examples,
                     title=title,
                     description=description,
                     article=article,
                     interpretation=interpretation,
                     enable_queue=enable_queue)
iface.launch()