Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import numpy as np | |
from torchvision import models, transforms | |
import time | |
import os | |
import copy | |
import pickle | |
from PIL import Image | |
import datetime | |
import gdown | |
import urllib.request | |
import gradio as gr | |
import markdown | |
# Fetch the files the app needs at start-up: the pickled class-name list and
# the fine-tuned weights (Google Drive, via gdown) plus two example photos
# (Wikimedia Commons).  Files already on disk are reused, so restarting the
# app does not download them again.
path_class_names = "./class_names_restnet_leeds_butterfly.pkl"
path_model = "./model_state_restnet_leeds_butterfly.pth"
path_example_erato = "./h_erato.jpg"
path_example_plexippus = "./d_plexippus.jpg"

# Descriptive User-Agent for plain HTTP downloads.  Wikimedia's User-Agent
# policy asks clients to identify themselves; requests sent with urllib's
# generic default agent can be rejected with HTTP 403.
_HTTP_USER_AGENT = "ButterflyClassificationDemo/1.0 (https://github.com/ttheland)"


def _gdown_fetch(url, path):
    """Download a Google Drive file to ``path`` with gdown."""
    gdown.download(url, path, quiet=False)


def _http_fetch(url, path):
    """Download ``url`` to ``path`` over HTTP(S) with a descriptive User-Agent.

    The body is written to ``path + ".part"`` and renamed into place only after
    the transfer completes, so an interrupted download never leaves a
    truncated file at ``path`` (which the existence check below would
    otherwise mistake for a finished download).
    """
    request = urllib.request.Request(url, headers={"User-Agent": _HTTP_USER_AGENT})
    partial_path = path + ".part"
    with urllib.request.urlopen(request) as response, open(partial_path, "wb") as out:
        out.write(response.read())
    os.replace(partial_path, path)


def _download_if_missing(url, path, fetch):
    """Ensure ``path`` exists, calling ``fetch(url, path)`` only if it does not.

    Args:
        url: source URL handed to ``fetch``.
        path: local destination file.
        fetch: callable ``fetch(url, path)`` that performs the download.

    Returns:
        ``path``.

    Raises:
        RuntimeError: if ``fetch`` returns without having created ``path``.
            NOTE(review): some gdown versions appear to report a failed
            download without raising — this check turns that into an
            immediate, explicit error; confirm against the pinned gdown.
    """
    if os.path.exists(path):
        return path
    fetch(url, path)
    if not os.path.exists(path):
        raise RuntimeError(f"Downloading {url} to {path} did not produce a file")
    return path


_download_if_missing(
    "https://drive.google.com/uc?id=1qKiyp4r8SqUtz2ZWk3E6oZhyhl6t8lyG",
    path_class_names,
    _gdown_fetch,
)
_download_if_missing(
    "https://drive.google.com/uc?id=1Ep2YWU4M-yVkF7AFP3aD1sVhuriIDzFe",
    path_model,
    _gdown_fetch,
)
_download_if_missing(
    "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f8/Red_postman_butterfly_%28Heliconius_erato%29.jpg/1599px-Red_postman_butterfly_%28Heliconius_erato%29.jpg",
    path_example_erato,
    _http_fetch,
)
_download_if_missing(
    "https://upload.wikimedia.org/wikipedia/commons/thumb/6/63/Monarch_In_May.jpg/1024px-Monarch_In_May.jpg",
    path_example_plexippus,
    _http_fetch,
)
# Test-time preprocessing pipeline:
#   1. scale so the shorter image side is 256 px,
#   2. take the central 224x224 crop,
#   3. convert the PIL image to a float tensor,
#   4. normalise each RGB channel with the ImageNet mean / std used by
#      torchvision's ResNet models.
data_transforms_test = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Labels for the classifier's output units, in output-unit order (the length
# sizes the final layer, and do_inference indexes it by output position).
# NOTE(review): unpickling can execute arbitrary code, and this file is
# downloaded from a remote URL -- only point path_class_names at a trusted
# source.  The ``with`` block makes sure the file handle is closed.
with open(path_class_names, "rb") as class_names_file:
    class_names = pickle.load(class_names_file)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet18 whose final fully-connected layer is replaced by one sized to the
# number of butterfly classes.  The architecture is built WITHOUT ImageNet
# weights: load_state_dict (strict by default) overwrites every parameter
# with the fine-tuned checkpoint, so downloading pretrained weights first
# (via the deprecated ``pretrained=True`` flag) was wasted work.
model_ft = models.resnet18()
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))

# map_location remaps tensors saved on another device (e.g. a training GPU)
# onto the device chosen above.
model_ft.load_state_dict(torch.load(path_model, map_location=device))
model_ft = model_ft.to(device)
# The model is only ever used for inference.
model_ft.eval()
# Display labels: maps each class ID found in class_names (folder-style
# "NNN_Genus Species") to a "Genus species - Common name" string.
id_to_name = dict(
    [
        ("001_Danaus Plexippus", "Danaus plexippus - Monarch"),
        ("002_Heliconius Charitonius", "Heliconius charitonius - Zebra Longwing"),
        ("003_Heliconius Erato", "Heliconius erato - Red Postman"),
        ("004_Junonia Coenia", "Junonia coenia - Common Buckeye"),
        ("005_Lycaena Phlaeas", "Lycaena phlaeas - Small Copper"),
        ("006_Nymphalis Antiopa", "Nymphalis antiopa - Mourning Cloak"),
        ("007_Papilio Cresphontes", "Papilio cresphontes - Giant Swallowtail"),
        ("008_Pieris Rapae", "Pieris rapae - Cabbage White"),
        ("009_Vanessa Atalanta", "Vanessa atalanta - Red Admiral"),
        ("010_Vanessa Cardui", "Vanessa cardui - Painted Lady"),
    ]
)
def do_inference(img):
    """Classify a butterfly photo with the fine-tuned ResNet18.

    Args:
        img: PIL image to classify (the Gradio input is configured with
            ``type="pil"``).  Images in a mode other than RGB (e.g. RGBA or
            greyscale) are converted to RGB first, because the normalisation
            step uses three per-channel statistics.

    Returns:
        dict mapping display label -> softmax probability rounded to two
        decimals, for every class, inserted in descending probability order.
        A class ID with no entry in ``id_to_name`` is shown under its raw ID
        instead of raising KeyError.

    Raises:
        ValueError: if ``img`` is None (e.g. the form was submitted without
            an uploaded image).
    """
    if img is None:
        raise ValueError("Please upload an image of a butterfly.")
    if img.mode != "RGB":
        img = img.convert("RGB")

    batch_t = torch.unsqueeze(data_transforms_test(img), 0).to(device)
    model_ft.eval()
    # Inference only: no_grad skips autograd bookkeeping to save memory.
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model_ft(batch_t), dim=1)[0]

    # torch.sort yields the probabilities and their class indices, highest
    # first (the same ordering torch.argsort(..., descending=True) produces).
    sorted_probs, order = torch.sort(probs, descending=True)
    sorted_probs = sorted_probs.cpu().tolist()
    order = order.cpu().tolist()

    return {
        id_to_name.get(class_names[idx], class_names[idx]): round(p, 2)
        for idx, p in zip(order, sorted_probs)
    }
# Upload-only image input for do_inference: delivered as a PIL image in RGB
# mode; shape=(512, 512) asks Gradio to resize/crop the upload to that size.
# NOTE(review): the ``gr.inputs`` namespace was deprecated in Gradio 3.x and
# removed in 4.0 -- pin the gradio version or migrate to ``gr.Image``.
im = gr.inputs.Image(
    type="pil",
    image_mode='RGB',
    shape=(512, 512),
    source="upload",
    invert_colors=False,
)
# Text shown in the header of the Gradio page.
title = "Butterfly Classification Demo"
description = "A pretrained ResNet18 CNN trained on the Leeds Butterfly Dataset. Libraries: PyTorch, Gradio."
# Clickable example inputs: the two Wikimedia photos downloaded at start-up,
# one single-input example per inner list.  Both paths are spelled exactly
# as the files were saved ("./<name>.jpg") so the examples stay consistent.
examples = [['./h_erato.jpg'], ['./d_plexippus.jpg']]
# Source for the long-form "article" section rendered beneath the demo:
# a heading linking to the dataset, the list of the 10 classes, and credits.
_ARTICLE_SOURCE = '''
<h1 style="color:white">PyTorch image classification - A pretrained ResNet18 CNN trained on the <a href="http://www.josiahwang.com/dataset/leedsbutterfly/">Leeds Butterfly Dataset</a></h1>
<br>
<p>The Leeds Butterfly Dataset consists of 832 images in 10 classes:</p>
<ul>
<li>Danaus plexippus - Monarch</li>
<li>Heliconius charitonius - Zebra Longwing</li>
<li>Heliconius erato - Red Postman</li>
<li>Lycaena phlaeas - Small Copper</li>
<li>Junonia coenia - Common Buckeye</li>
<li>Nymphalis antiopa - Mourning Cloak</li>
<li>Papilio cresphontes - Giant Swallowtail</li>
<li>Pieris rapae - Cabbage White</li>
<li>Vanessa atalanta - Red Admiral</li>
<li>Vanessa cardui - Painted Lady</li>
</ul>
<br>
<p>Part of a dissertation project. Author: <a href="https://github.com/ttheland">ttheland</a></p>
'''
# Rendered through Python-Markdown before being handed to gr.Interface.
article_text = markdown.markdown(_ARTICLE_SOURCE)
# Route requests through Gradio's queue (passed to gr.Interface below).
enable_queue = True

# The demo: one image in, a label out showing the two most probable classes.
iface = gr.Interface(
    do_inference,
    im,
    gr.outputs.Label(num_top_classes=2),
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article_text,
    examples=examples,
    enable_queue=enable_queue,
)

# Launch only when run as a script, so importing this module (e.g. for
# testing) does not start a web server.  gr.Interface has no ``test``
# attribute, so no separate test launch is attempted before the real one.
if __name__ == "__main__":
    iface.launch(share=True)