which-flower / app.py
rdkulkarni's picture
Update app.py
6d36038
raw
history blame
9.37 kB
import functools
import json
import math
from typing import List

import gradio as gr
import numpy as np
import skimage
import torch
import torchvision
import torchvision.models as models
from PIL import Image
from torch import nn

print(torchvision.__version__)
# Startup diagnostic: log which AlexNet_Weights member torchvision treats as DEFAULT.
print(torchvision.models.AlexNet_Weights.DEFAULT)
def pred_and_plot_image(
    model: torch.nn.Module,
    image_path: str,
    class_names: List[str] = None,
    transform=None,
    device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Classify one image file and return the top-1 probability and label.

    Despite its name, nothing is plotted; the function only predicts.

    Args:
        model: Classifier whose output has one score per class along dim 1.
        image_path: Path of the image, read with ``torchvision.io.read_image``.
        class_names: Optional sequence mapping class index -> label. When given,
            the label is returned instead of the raw index tensor.
        transform: Optional callable applied to the (C, H, W) float image with
            values in [0, 1] before the batch dimension is added
            (e.g. ``weights.transforms()``).
        device: Device for the forward pass; the model is moved there.

    Returns:
        ``(prob, label)``. ``prob`` is a 0-d CPU tensor with the highest softmax
        probability. ``label`` is ``class_names[idx]``, or the predicted index
        tensor when ``class_names`` is not given.
    """
    # 1. Read the image forcing 3 RGB channels. PNGs with alpha (4 channels) or
    #    grayscale files (1 channel) would otherwise not match the 3-channel
    #    input the model and its normalisation transform expect.
    target_image = torchvision.io.read_image(
        str(image_path), mode=torchvision.io.ImageReadMode.RGB
    ).type(torch.float32)
    # 2. Scale 0..255 pixel values into [0, 1].
    target_image = target_image / 255.0
    # 3. Optional preprocessing (resize / crop / normalise).
    if transform:
        target_image = transform(target_image)
    # 4. Make sure the model is on the target device.
    model.to(device)
    # 5. Evaluation mode + inference mode (no autograd bookkeeping).
    model.eval()
    with torch.inference_mode():
        # Add the batch dimension -> (1, C, H, W).
        target_image = target_image.unsqueeze(dim=0)
        target_image_pred = model(target_image.to(device))
        # 6. Logits -> probabilities.
        target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
        # 7. Probabilities -> predicted class index.
        target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)
    if class_names:
        # Use a plain int index; indexing a Python sequence with a tensor
        # relies on Tensor.__index__.
        idxs = class_names[target_image_pred_label.item()]
    else:
        idxs = target_image_pred_label
    ps = target_image_pred_probs.max().cpu()
    print(ps)
    print(idxs)
    return (ps, idxs)
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of ``model`` when ``feature_extracting`` is truthy.

    When ``feature_extracting`` is falsy, nothing happens. Parameters that were
    already frozen are not re-enabled.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False
def update_last_layer_pretrained_model(pretrained_model, num_classes, feature_extract):
    """Replace the final classification layer of a torchvision model.

    The existing parameters are optionally frozen first. The new head is
    created afterwards, so its parameters keep the default ``requires_grad``.
    The architecture is detected from the model's class name.

    Args:
        pretrained_model: torchvision classifier (ResNet, AlexNet, VGG,
            SqueezeNet, EfficientNet, MobileNet, Inception3, DenseNet,
            VisionTransformer or SwinTransformer).
        num_classes: Number of outputs of the new head.
        feature_extract: If truthy, freeze all pre-existing parameters.

    Returns:
        The same model object, modified in place. A model of an unrecognised
        architecture is returned with only the optional freezing applied.
    """
    set_parameter_requires_grad(pretrained_model, feature_extract)
    arch = pretrained_model.__class__.__name__.lower()

    if hasattr(pretrained_model, 'fc') and 'resnet' in arch:  # resnet
        num_ftrs = pretrained_model.fc.in_features
        pretrained_model.fc = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'classifier') and ('alexnet' in arch or 'vgg' in arch):  # alexnet, vgg
        # Index 6 is the final Linear of the 7-module AlexNet/VGG classifier.
        num_ftrs = pretrained_model.classifier[6].in_features
        pretrained_model.classifier[6] = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'classifier') and 'squeezenet' in arch:  # squeezenet
        # SqueezeNet classifies with a 1x1 conv. Reuse its input channel count
        # rather than hard-coding 512.
        in_channels = pretrained_model.classifier[1].in_channels
        pretrained_model.classifier[1] = nn.Conv2d(in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1))
        pretrained_model.num_classes = num_classes
    elif hasattr(pretrained_model, 'classifier') and ('efficientnet' in arch or 'mobilenet' in arch):  # efficientnet, mobilenet
        # The head is the LAST classifier module. That is index 1 for
        # EfficientNet / MobileNetV2 (Dropout, Linear), but index 3 for
        # MobileNetV3 (Linear, Hardswish, Dropout, Linear), where index 1 is an
        # activation without ``in_features``.
        num_ftrs = pretrained_model.classifier[-1].in_features
        pretrained_model.classifier[-1] = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'AuxLogits') and 'inception' in arch:  # inception
        # AuxLogits is None when the model was built with aux_logits=False.
        if pretrained_model.AuxLogits is not None:
            num_ftrs = pretrained_model.AuxLogits.fc.in_features
            pretrained_model.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)  # auxiliary net
        num_ftrs = pretrained_model.fc.in_features
        pretrained_model.fc = nn.Linear(num_ftrs, num_classes)  # primary net
    elif hasattr(pretrained_model, 'classifier') and 'densenet' in arch:  # densenet
        num_ftrs = pretrained_model.classifier.in_features
        pretrained_model.classifier = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'heads') and 'visiontransformer' in arch:  # vit transformer
        num_ftrs = pretrained_model.heads.head.in_features
        pretrained_model.heads.head = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'head') and 'swin' in arch:  # swin transformer
        num_ftrs = pretrained_model.head.in_features
        pretrained_model.head = nn.Linear(num_ftrs, num_classes, bias=True)
    return pretrained_model
def process_image(image_path):
    """Load an image and return it as a normalised (3, 224, 224) numpy array.

    The steps are:

    1. Convert the image to RGB.
    2. Shrink it so the shorter side is at most 255 px, keeping the aspect
       ratio (``thumbnail`` never upscales).
    3. Centre-crop 224x224.
    4. Scale values to [0, 1] and normalise per channel with the ImageNet
       mean/std.
    5. Move channels first.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``numpy.ndarray`` of shape (3, 224, 224) and dtype float64.
    """
    # Close the file once decoded. convert() returns an independent copy, and
    # forcing RGB keeps 3 channels for RGBA / palette / grayscale sources, which
    # the mean/std broadcast below requires.
    with Image.open(image_path) as src:
        im = src.convert('RGB')

    # Resize: bound the SHORTER side to 255 px. The other bound is large enough
    # to be inactive, so the aspect ratio fixes the longer side. (The bounds
    # were previously swapped, which shrank the longer side to 255 and left a
    # shorter side < 224 for the crop.)
    unbounded = 255 ** 2
    width, height = im.size
    if height < width:  # landscape: height is the shorter side
        im.thumbnail((unbounded, 255))
    else:  # portrait / square: width is the shorter side
        im.thumbnail((255, unbounded))

    # Centre crop 224 x 224.
    width, height = im.size
    left = (width - 224) / 2
    top = (height - 224) / 2
    right = (width + 224) / 2
    bottom = (height + 224) / 2
    im = im.crop((left, top, right, bottom))

    # Scale to [0, 1], then apply ImageNet mean / std normalisation.
    np_image = np.array(im) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    np_image = (np_image - mean) / std
    # HWC -> CHW, the channel-first layout used by PyTorch conv layers.
    return np.transpose(np_image, (2, 0, 1))
def predict(image_path, model, topk=5):
    """Return the ``topk`` most likely classes for the image at ``image_path``.

    The image is preprocessed with ``process_image``.

    Args:
        image_path: Path of the image file.
        model: Classifier returning one score per class along dim 1. Raw logits
            and log-probabilities both work, because softmax of a log-softmax
            output equals the original probabilities.
        topk: Number of classes to return. It is clamped to the number of
            classes the model outputs.

    Returns:
        ``(probs, idxs)``: numpy arrays with the class probabilities in
        descending order and the corresponding class indices.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    tensor_img = torch.as_tensor(process_image(image_path), dtype=torch.float32)
    tensor_img = tensor_img.unsqueeze(0).to(device)
    # Inference only: skip building an autograd graph.
    with torch.inference_mode():
        outputs = model(tensor_img)
        # softmax rather than exp(): exp() of raw logits (e.g. from a plain
        # nn.Linear head) is not a probability distribution.
        probs = torch.softmax(outputs, dim=1)
        result = probs.topk(min(topk, probs.shape[1]))
    print(result)
    # .cpu() is a no-op for tensors already on the CPU, so no device branch.
    ps = result.values.cpu().numpy()[0]
    idxs = result.indices.cpu().numpy()[0]
    return (ps, idxs)
# Model served by the app: (torchvision builder name, weights enum name, checkpoint path).
_MODEL_SPEC = ('alexnet', 'AlexNet_Weights', 'flowers_alexnet_model.pth')


@functools.lru_cache(maxsize=None)
def _load_model(model_name, model_weights, model_path):
    """Build ``model_name`` with the fine-tuned checkpoint and its matching transforms.

    The result is cached, so the checkpoint is read and the network built once
    per process instead of once per request.

    Returns:
        ``(model, transforms)``. ``model.class_names`` holds the labels stored
        in the checkpoint.
    """
    checkpoint = torch.load(model_path, map_location='cpu')
    # Only the preprocessing preset is taken from the weights enum.
    auto_transforms = getattr(torchvision.models, model_weights).DEFAULT.transforms()
    # load_state_dict (strict by default) overwrites every parameter, so the
    # ImageNet weights need not be downloaded (the old ``pretrained=True`` did).
    model = getattr(torchvision.models, model_name)(weights=None)
    # 102 outputs: the 102 Category Flower Dataset.
    model = update_last_layer_pretrained_model(model, 102, True)
    model.class_names = checkpoint['class_names']
    model.load_state_dict(checkpoint['state_dict'])
    model.to('cpu')
    return model, auto_transforms


def process_input(image_path):
    """Gradio callback: classify the uploaded image with the fine-tuned model.

    Args:
        image_path: Path of the uploaded image (the interface uses an Image
            input with ``type='filepath'``).

    Returns:
        ``{label: probability}`` for the top-1 prediction, as consumed by the
        Gradio ``Label`` output.
    """
    model, auto_transforms = _load_model(*_MODEL_SPEC)
    probs, label = pred_and_plot_image(model=model,
                                       image_path=image_path,
                                       class_names=model.class_names,
                                       transform=auto_transforms)
    return {label: float(probs)}
# Example images offered below the interface.
examples = ['16_image_06670.jpg', '33_image_06460.jpg', '80_image_02020.jpg', 'Flowers.png', 'inference_example.png']
title = "Image Classifier - Species of Flower predicted by different Models"
description = "Image classifiers to recognize different species of flowers trained on 102 Category Flower Dataset"
# Single assignment; the original ``article = article = ...`` chained assignment was a typo.
article = "<p style='text-align: center'><a href='https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html' target='_blank'>Source 102 Flower Dataset</a></p>"
interpretation = 'default'
enable_queue = True
iface = gr.Interface(
    fn=process_input,
    inputs=gr.inputs.Image(type='filepath'),
    outputs=gr.outputs.Label(num_top_classes=3),
    examples=examples,
    title=title,
    description=description,
    article=article,
    interpretation=interpretation,
    enable_queue=enable_queue,
)
iface.launch()