# Origin: mawady's Hugging Face Space, commit 4b80b24 ("format code").
import torch
import torch.nn as nn
import numpy as np
from torchvision import models, transforms
import time
import os
import copy
import pickle
from PIL import Image
import datetime
import gdown
import urllib.request
import gradio as gr
# The class-name list and fine-tuned weights are expected to exist locally.
# They were originally fetched from Google Drive; uncomment the gdown calls
# below to re-fetch them.
# url = 'https://drive.google.com/uc?id=1VMLpE5ojF9fq0GtBKaqcMVWUIfJUfKbc'
path_class_names = "./class_names_restnet_catsVSdogs.pkl"
# gdown.download(url, path_class_names, quiet=False, use_cookies=False)
# url = 'https://drive.google.com/uc?id=1jorQB1mpPCLH097M8paxut3v5XwVlKqp'
path_model = "./model_state_restnet_catsVSdogs.pth"
# gdown.download(url, path_model, quiet=False, use_cookies=False)

# Identify this client explicitly: Wikimedia's User-Agent policy asks clients
# not to use generic library agents, which may be refused with HTTP 403.
_USER_AGENT = "CatsVsDogsClassifier/1.0 (Gradio demo)"


def _download_if_missing(url, path):
    """Download ``url`` to ``path`` unless ``path`` already exists.

    Skipping files that are already present avoids a network round-trip on
    every restart and lets the app start offline once the files are cached.
    The data is written to ``path + ".part"`` first and renamed only after the
    transfer completes, so an interrupted download never leaves a truncated
    file at ``path`` that later runs would treat as valid.
    """
    if os.path.exists(path):
        return
    request = urllib.request.Request(url, headers={"User-Agent": _USER_AGENT})
    tmp_path = path + ".part"
    with urllib.request.urlopen(request) as response, open(tmp_path, "wb") as fh:
        fh.write(response.read())
    os.replace(tmp_path, path)


# Example images offered below the Gradio interface.
for url, path_input in (
    (
        "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg",
        "./cat.jpg",
    ),
    (
        "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg",
        "./dog.jpg",
    ),
):
    _download_if_missing(url, path_input)
# Evaluation-time preprocessing applied to every input image:
# - resize so the shorter side is 256 px;
# - take the central 224x224 crop;
# - convert to a float tensor;
# - normalize each channel with the mean/std values conventionally used for
#   ImageNet-pretrained torchvision models.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
data_transforms_val = transforms.Compose(_val_steps)
# Class labels, indexed in the same order as the classifier's output logits.
# NOTE: pickle can execute arbitrary code on load; only use a trusted file here.
with open(path_class_names, "rb") as fh:
    class_names = pickle.load(fh)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the ResNet-18 architecture with a head sized to our classes. No
# pretrained weights are requested. load_state_dict below is strict, so every
# parameter and buffer is overwritten by the fine-tuned checkpoint. Fetching
# ImageNet weights first would only add a download, and ``pretrained=`` is
# deprecated in newer torchvision releases.
model_ft = models.resnet18()
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
# map_location remaps tensors saved on another device (e.g. GPU) onto the
# device chosen above, so the checkpoint also loads on CPU-only hosts.
model_ft.load_state_dict(torch.load(path_model, map_location=device))
model_ft = model_ft.to(device)
model_ft.eval()
def do_inference(img):
    """Classify an image with the fine-tuned ResNet-18.

    Args:
        img: PIL image (the Gradio input is configured with ``type="pil"``).

    Returns:
        Dict mapping every class name to its softmax probability rounded to
        two decimals, inserted from most to least likely.
    """
    # Gradio already requests RGB (image_mode="RGB"). Converting here as well
    # keeps the 3-channel Normalize step from failing on RGBA or greyscale
    # images when this function is called directly.
    img = img.convert("RGB")
    batch_t = torch.unsqueeze(data_transforms_val(img), 0).to(device)
    model_ft.eval()
    # Gradients are not needed at inference time; no_grad saves memory.
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model_ft(batch_t), dim=1)
        # Class indices of the single batch item, most likely first.
        order = torch.argsort(probs, dim=1, descending=True)[0].cpu().numpy()
    probs = probs[0].cpu().numpy()
    names = np.array(class_names)
    return {names[idx]: round(float(probs[idx]), 2) for idx in order}
# Gradio UI: a 512x512 RGB upload handed to do_inference as a PIL image, and a
# label widget showing both class probabilities.
im = gr.inputs.Image(
    shape=(512, 512), image_mode="RGB", invert_colors=False, source="upload", type="pil"
)
title = "CatsVsDogs Classifier"
description = (
    "Playground: Inference of Object Classification (Binary) using ResNet18 "
    "model and CatsVsDogs dataset. Libraries: PyTorch, Gradio."
)
# Paths of the example images downloaded at startup (./cat.jpg, ./dog.jpg).
examples = [["./cat.jpg"], ["./dog.jpg"]]
article = "<p style='text-align: center'><a href='https://github.com/mawady' target='_blank'>By Dr. Mohamed Elawady</a></p>"
iface = gr.Interface(
    do_inference,
    im,
    gr.outputs.Label(num_top_classes=2),
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
iface.launch()