"""Gradio demo: classify an uploaded image as cat or dog with a fine-tuned ResNet18.

Expects the pickled class-name list at ``path_class_names`` and the model
state dict at ``path_model``; two example images are downloaded from
Wikimedia Commons at startup.
"""

import copy
import datetime
import os
import pickle
import time
import urllib.request

import gdown
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

# Model artifacts. They were originally fetched from Google Drive with gdown;
# uncomment the download lines to fetch them again.
# url = 'https://drive.google.com/uc?id=1VMLpE5ojF9fq0GtBKaqcMVWUIfJUfKbc'
path_class_names = "./class_names_restnet_catsVSdogs.pkl"
# gdown.download(url, path_class_names, quiet=False, use_cookies=False)

# url = 'https://drive.google.com/uc?id=1jorQB1mpPCLH097M8paxut3v5XwVlKqp'
path_model = "./model_state_restnet_catsVSdogs.pth"
# gdown.download(url, path_model, quiet=False, use_cookies=False)

# Example images shown under the interface: local path -> source URL.
# This single mapping drives both the downloads and the `examples` list.
example_images = {
    "./cat.jpg": "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg",
    "./dog.jpg": "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg",
}
for path_input, url in example_images.items():
    urllib.request.urlretrieve(url, filename=path_input)

# Validation-time preprocessing: resize, center-crop to 224x224, convert to a
# tensor and normalize with the ImageNet channel mean/std.
data_transforms_val = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# NOTE(review): pickle.load can execute arbitrary code -- only load a trusted file.
with open(path_class_names, "rb") as f:
    class_names = pickle.load(f)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet18 with its final layer resized to the number of classes. No ImageNet
# weights are downloaded: the strict load_state_dict below overwrites every
# parameter and buffer anyway.
model_ft = models.resnet18()
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_ft.to(device)
model_ft.load_state_dict(torch.load(path_model, map_location=device))


def do_inference(img):
    """Classify one image and return the per-class probabilities.

    Args:
        img: PIL image as delivered by the Gradio ``Image`` input
            (configured below with ``type="pil"``, ``image_mode="RGB"``).

    Returns:
        dict mapping each class name to its softmax probability rounded to
        two decimals, inserted in descending order of probability.
    """
    # Add a leading batch dimension of size 1 and move to the model's device.
    batch_t = data_transforms_val(img).unsqueeze(0).to(device)

    model_ft.eval()
    # Inference only: no gradients needed, which also saves memory.
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model_ft(batch_t), dim=1)[0]

    # Class indices sorted from most to least probable.
    order = torch.argsort(probs, descending=True).cpu().numpy()
    probs = probs.cpu().numpy()[order]
    labels = np.array(class_names)[order]
    return {label: round(float(p), 2) for label, p in zip(labels, probs)}


im = gr.inputs.Image(
    shape=(512, 512), image_mode="RGB", invert_colors=False, source="upload", type="pil"
)

title = "CatsVsDogs Classifier"
description = "Playground: Inference of Object Classification (Binary) using ResNet18 model and CatsVsDogs dataset. Libraries: PyTorch, Gradio."
examples = [[path] for path in example_images]
article = "\n\nBy Dr. Mohamed Elawady\n\n"

iface = gr.Interface(
    do_inference,
    im,
    gr.outputs.Label(num_top_classes=len(class_names)),
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article,
    examples=examples,
)

iface.launch()