import torch
import torch.nn as nn
import numpy as np
from torchvision import models, transforms
import time
import os
import copy
import pickle
from PIL import Image
import datetime
import gdown
import urllib.request
import gradio as gr
import markdown

# load model state and class names from gdrive
# (there was an issue accessing the link; permissions have been updated)
# https://drive.google.com/file/d/1m9C-WMfKRDCmScxTh8JmcoFtymxAqjS3/view?usp=sharing
url = 'https://drive.google.com/uc?id=1m9C-WMfKRDCmScxTh8JmcoFtymxAqjS3'
path_class_names = "./class_names_restnet_leeds_butterfly.pkl"
gdown.download(url, path_class_names, quiet=False)

# https://drive.google.com/file/d/1qxaWnYwLIwWGrGg9uehG7h2W227SXGKq/view?usp=sharing
url = 'https://drive.google.com/uc?id=1qxaWnYwLIwWGrGg9uehG7h2W227SXGKq'
path_model = "./model_state_restnet_leeds_butterfly.pth"
gdown.download(url, path_model, quiet=False)

# example images
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f8/Red_postman_butterfly_%28Heliconius_erato%29.jpg/1599px-Red_postman_butterfly_%28Heliconius_erato%29.jpg"
path_input = "./h_erato.jpg"
urllib.request.urlretrieve(url, filename=path_input)

url = "https://upload.wikimedia.org/wikipedia/commons/thumb/6/63/Monarch_In_May.jpg/1024px-Monarch_In_May.jpg"
path_input = "./d_plexippus.jpg"
urllib.request.urlretrieve(url, filename=path_input)

url = "https://drive.google.com/uc?id=1A7WgDrQ_RLO6JOQiYhkH_hj_EKcbpmOl"
path_input = "./v_cardui.jpg"
urllib.request.urlretrieve(url, filename=path_input)

url = "https://drive.google.com/uc?id=1CiWShQYIm2N0fkVaWJpftlXZFqwjsXhA"
path_input = "./p_cresphontes.jpg"
urllib.request.urlretrieve(url, filename=path_input)

url = "https://drive.google.com/uc?id=1r8rbkUwTSIZL0MQVgU-WjDGwvLXuwYPG"
path_input = "./p_rapae.jpg"
urllib.request.urlretrieve(url, filename=path_input)

# normalisation (same preprocessing as the validation/test transforms used during training)
data_transforms_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

with open(path_class_names, "rb") as f:
    class_names = pickle.load(f)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# rebuild the fine-tuned ResNet18 head and load the trained weights
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_ft.to(device)
model_ft.load_state_dict(copy.deepcopy(torch.load(path_model, map_location=device)))

# Proper labeling: map dataset folder names to display names
id_to_name = {
    '001_Danaus Plexippus': 'Danaus plexippus - Monarch',
    '002_Heliconius Charitonius': 'Heliconius charitonius - Zebra Longwing',
    '003_Heliconius Erato': 'Heliconius erato - Red Postman',
    '004_Junonia Coenia': 'Junonia coenia - Common Buckeye',
    '005_Lycaena Phlaeas': 'Lycaena phlaeas - Small Copper',
    '006_Nymphalis Antiopa': 'Nymphalis antiopa - Mourning Cloak',
    '007_Papilio Cresphontes': 'Papilio cresphontes - Giant Swallowtail',
    '008_Pieris Rapae': 'Pieris rapae - Cabbage White',
    '009_Vanessa Atalanta': 'Vanessa atalanta - Red Admiral',
    '010_Vanessa Cardui': 'Vanessa cardui - Painted Lady',
}


def do_inference(img):
    img_t = data_transforms_test(img)
    batch_t = torch.unsqueeze(img_t, 0)
    model_ft.eval()
    # We don't need gradients for inference, so wrap in
    # no_grad to save memory
    with torch.no_grad():
        batch_t = batch_t.to(device)
        # forward propagation
        output = model_ft(batch_t)
        # get predictions, sorted by descending probability
        probs = torch.nn.functional.softmax(output, dim=1)
        output = torch.argsort(probs, dim=1, descending=True).cpu().numpy()[0].astype(int)
        probs = probs.cpu().numpy()[0]
        probs = probs[output]
        labels = np.array(class_names)[output]
    return {id_to_name[labels[i]]: round(float(probs[i]), 2) for i in range(len(labels))}
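
# Optional sanity check (a minimal sketch, not part of the original demo flow):
# run the model once on one of the example images downloaded above, before the
# Gradio UI is wired up. It assumes the earlier downloads succeeded and that
# ./h_erato.jpg exists on disk; if not, the check is simply skipped.
if os.path.exists("./h_erato.jpg"):
    sample = Image.open("./h_erato.jpg").convert("RGB")
    preds = do_inference(sample)
    # show the three most confident classes with their softmax scores
    top3 = sorted(preds.items(), key=lambda kv: kv[1], reverse=True)[:3]
    print("Sanity check on h_erato.jpg:", top3)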
im = gr.inputs.Image(shape=(512, 512), image_mode='RGB', invert_colors=False, source="upload", type="pil")

title = "Butterfly Classification Demo"
description = "A pretrained ResNet18 CNN fine-tuned on the Leeds Butterfly Dataset. Libraries: PyTorch, Gradio."
examples = [['./h_erato.jpg'], ['./d_plexippus.jpg'], ['./v_cardui.jpg'], ['./p_cresphontes.jpg'], ['./p_rapae.jpg']]

article_text = markdown.markdown('''

PyTorch image classification - a pretrained ResNet18 CNN fine-tuned on the Leeds Butterfly Dataset


The Leeds Butterfly Dataset consists of 832 images in 10 classes:

* *Danaus plexippus* - Monarch
* *Heliconius charitonius* - Zebra Longwing
* *Heliconius erato* - Red Postman
* *Junonia coenia* - Common Buckeye
* *Lycaena phlaeas* - Small Copper
* *Nymphalis antiopa* - Mourning Cloak
* *Papilio cresphontes* - Giant Swallowtail
* *Pieris rapae* - Cabbage White
* *Vanessa atalanta* - Red Admiral
* *Vanessa cardui* - Painted Lady


Part of a dissertation project. Author: ttheland

''')

iface = gr.Interface(
    do_inference,
    im,
    gr.outputs.Label(num_top_classes=3),
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article_text,
    examples=examples,
    theme="dark-peach"
)

iface.test_launch()
iface.launch(share=True, enable_queue=True)