"""Gradio demo: classify a flower image with several fine-tuned torchvision models.

Loads one fine-tuned checkpoint per backbone, predicts the species for the
uploaded image with each model, and shows the top-1 prediction per model as a
DataFrame (102 Category Flower Dataset labels).
"""
import gradio as gr
import torch
import torchvision
from PIL import Image
from torchvision import models
from torch import nn
from typing import List
import json
import pandas as pd

# Category-id (string) -> human-readable flower name.
with open('cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)


def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of `model` when feature-extracting (backbone is not trained)."""
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


def update_last_layer_pretrained_model(pretrained_model, num_classes, feature_extract):
    """Swap the classification head of a torchvision model for a fresh `num_classes`-way layer.

    Handles the head layout of each supported family (resnet, alexnet/vgg,
    squeezenet, efficientnet/mobilenet, inception, densenet, ViT, swin).
    Unrecognized architectures are returned unchanged.

    Args:
        pretrained_model: a torchvision model instance.
        num_classes: size of the new output layer.
        feature_extract: if True, freeze all existing parameters first.

    Returns:
        The same model instance, with its head replaced.
    """
    set_parameter_requires_grad(pretrained_model, feature_extract)
    arch = pretrained_model.__class__.__name__.lower()

    if hasattr(pretrained_model, 'fc') and 'resnet' in arch:
        num_ftrs = pretrained_model.fc.in_features
        pretrained_model.fc = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'classifier') and ('alexnet' in arch or 'vgg' in arch):
        num_ftrs = pretrained_model.classifier[6].in_features
        pretrained_model.classifier[6] = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'classifier') and 'squeezenet' in arch:
        # SqueezeNet classifies with a 1x1 conv, not a Linear layer.
        pretrained_model.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        pretrained_model.num_classes = num_classes
    elif hasattr(pretrained_model, 'classifier') and ('efficientnet' in arch or 'mobilenet' in arch):
        num_ftrs = pretrained_model.classifier[1].in_features
        pretrained_model.classifier[1] = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'AuxLogits') and 'inception' in arch:
        # Inception v3 has an auxiliary training head in addition to the primary head.
        num_ftrs = pretrained_model.AuxLogits.fc.in_features
        pretrained_model.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)  # auxiliary net
        num_ftrs = pretrained_model.fc.in_features
        pretrained_model.fc = nn.Linear(num_ftrs, num_classes)  # primary net
    elif hasattr(pretrained_model, 'classifier') and 'densenet' in arch:
        num_ftrs = pretrained_model.classifier.in_features
        pretrained_model.classifier = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'heads') and 'visiontransformer' in arch:
        num_ftrs = pretrained_model.heads.head.in_features
        pretrained_model.heads.head = nn.Linear(num_ftrs, num_classes, bias=True)
    elif hasattr(pretrained_model, 'head') and 'swin' in arch:
        num_ftrs = pretrained_model.head.in_features
        pretrained_model.head = nn.Linear(num_ftrs, num_classes, bias=True)
    return pretrained_model


def pred_image(model, image_path, class_names=None, transform=None,
               device: torch.device = "cuda" if torch.cuda.is_available() else "cpu"):
    """Run a single-image forward pass and return the top-3 predictions.

    Args:
        model: classification model returning per-class logits.
        image_path: path to the image file.
        class_names: optional list mapping class index -> label; if omitted,
            raw indices are returned.
        transform: optional preprocessing transform applied to the PIL image.
        device: device to run inference on.

    Returns:
        (probs, labels): numpy array of the top-3 probabilities and the
        corresponding labels (or indices).
    """
    # FIX: force RGB so RGBA/grayscale inputs (e.g. PNG examples) don't break
    # 3-channel preprocessing transforms.
    target_image = Image.open(image_path).convert('RGB')
    if transform:
        target_image = transform(target_image)
    model.to(device)
    model.eval()
    with torch.inference_mode():
        target_image = target_image.unsqueeze(dim=0)  # add batch dimension
        logits = model(target_image.to(device))
        pred_probs = torch.softmax(logits, dim=1)
        top = pred_probs.topk(3)
        top_probs = top[0].cpu().numpy()[0]
        # FIX: indices must also come back to CPU before .numpy() — the
        # original called top[1].numpy() directly, which raises on CUDA.
        top_idxs = top[1].cpu().numpy()[0]
        labels = [class_names[i] for i in top_idxs] if class_names else top_idxs
    return (top_probs, labels)


def process_input(image_path):
    """Classify `image_path` with every configured model.

    Returns:
        pandas.DataFrame with one row per model: top-1 "Prediction" name and
        its "Probability".
    """
    # (architecture name, weights-enum name, fine-tuned checkpoint path).
    # Disabled entries are additional trained models that can be re-enabled.
    list_of_models_and_weights = [
        ('efficientnet_b2', 'EfficientNet_B2_Weights', 'flowers_efficientnet_b2_model.pth'),
        ('alexnet', 'AlexNet_Weights', 'flowers_alexnet_model.pth'),
        # ('mobilenet_v2', 'MobileNet_V2_Weights', 'flowers_mobilenet_v2_model.pth'),
        # ('densenet121', 'DenseNet121_Weights', 'flowers_densenet121_model.pth'),
        # ('inception_v3', 'Inception_V3_Weights', 'flowers_inception_v3_model.pth'),
        # ('squeezenet1_1', 'SqueezeNet1_1_Weights', 'flowers_squeezenet1_1_model.pth'),
        # ('vgg16', 'VGG16_Weights', 'flowers_vgg16_model.pth'),
        # ('resnet18', 'ResNet18_Weights', 'flowers_resnet18_model.pth'),
        # ('swin_b', 'Swin_B_Weights', 'flowers_swin_b_model.pth'),
        # ('vit_b_16', 'ViT_B_16_Weights', 'flowers_vit_b_16_model.pth'),
    ]

    list_of_outputs = []
    for model_name, model_weights, model_path in list_of_models_and_weights:
        checkpoint = torch.load(model_path, map_location='cpu')
        # FIX: resolve names with getattr instead of eval() — no need to
        # evaluate code strings for attribute lookup.
        pretrained_weights = getattr(models, model_weights).DEFAULT
        auto_transforms = pretrained_weights.transforms()
        # FIX: use the modern `weights=` API instead of the deprecated
        # `pretrained=True` (the file's own commented-out variant already
        # used `weights=pretrained_weights`).
        pretrained_model = getattr(models, model_name)(weights=pretrained_weights)
        pretrained_model = update_last_layer_pretrained_model(pretrained_model, 102, True)
        pretrained_model.class_to_idx = checkpoint['class_to_idx']
        pretrained_model.class_names = checkpoint['class_names']
        pretrained_model.load_state_dict(checkpoint['state_dict'])
        pretrained_model.to('cpu')

        probs, idxs = pred_image(model=pretrained_model,
                                 image_path=image_path,
                                 class_names=pretrained_model.class_names,
                                 transform=auto_transforms)
        names = [cat_to_name[i] for i in idxs]
        list_of_outputs.append({"Prediction": names[0], "Probability": probs[0]})

    result = pd.DataFrame(list_of_outputs)
    print(result)
    return result


examples = ['16_image_06670.jpg', '33_image_06460.jpg', '80_image_02020.jpg',
            'Flowers.png', 'inference_example.png']
title = "Image Classifier - Species of Flower predicted by different Models"
description = "Image classifiers to recognize different species of flowers trained on 102 Category Flower Dataset"
# FIX: original read `article = article="..."` (redundant chained assignment).
# NOTE(review): the article's string content appears to have been lost when the
# file was reformatted — restore the original HTML/text if available.
article = ""
interpretation = 'default'
enable_queue = True

# NOTE(review): gr.inputs.Image, `interpretation` and `enable_queue` belong to
# the Gradio 2.x/3.x API (newer versions use gr.Image and iface.queue()).
# Kept as-is for compatibility with the pinned environment — confirm the
# installed gradio version before upgrading.
iface = gr.Interface(fn=process_input,
                     inputs=gr.inputs.Image(type='filepath'),
                     outputs="dataframe",
                     examples=examples,
                     title=title,
                     description=description,
                     article=article,
                     interpretation=interpretation,
                     enable_queue=enable_queue)
iface.launch()