Spaces:
Running
Running
File size: 3,011 Bytes
0f36ee2 1b63bdd 0f36ee2 1b63bdd 0f36ee2 1b63bdd 0f36ee2 1b63bdd 0f36ee2 1b63bdd 0f36ee2 1b63bdd 0f36ee2 1b63bdd 0f36ee2 1b63bdd 0f36ee2 1b63bdd 0f36ee2 1b63bdd 0f36ee2 2b73cea 1b63bdd 4b80b24 0f36ee2 1b63bdd 0f36ee2 77f2530 0f36ee2 a40046b 1b63bdd 0f36ee2 1b63bdd 0f36ee2 a4885d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import torch
import torch.nn as nn
import numpy as np
from torchvision import models, transforms
import time
import os
import copy
import pickle
from PIL import Image
import datetime
import gdown
import urllib.request
import gradio as gr
# Local artifacts produced at training time. They were originally fetched
# from Google Drive with gdown:
#   class names: https://drive.google.com/uc?id=1VMLpE5ojF9fq0GtBKaqcMVWUIfJUfKbc
#   weights:     https://drive.google.com/uc?id=1jorQB1mpPCLH097M8paxut3v5XwVlKqp
path_class_names = "./class_names_restnet_catsVSdogs.pkl"
path_model = "./model_state_restnet_catsVSdogs.pth"

# Example images shown in the demo. Download each one only when it is not
# already on disk, so restarts do not depend on network access.
# NOTE(review): a partially written file left by an interrupted download
# would be reused as-is; delete it manually if that ever happens.
for url, path_input in (
    (
        "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg",
        "./cat.jpg",
    ),
    (
        "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg",
        "./dog.jpg",
    ),
):
    if not os.path.exists(path_input):
        urllib.request.urlretrieve(url, filename=path_input)
# Evaluation-time preprocessing applied to every input image:
# resize (size 256), center-crop to 224x224, convert to a tensor, then
# normalize each channel with the standard ImageNet mean / std values.
data_transforms_val = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Class labels saved at training time. Their order corresponds to the output
# units of the classifier head (do_inference indexes them by output index).
# NOTE: pickle can execute arbitrary code on load; only use a trusted file.
with open(path_class_names, "rb") as class_names_file:
    class_names = pickle.load(class_names_file)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet-18 with the final fully connected layer resized to one output per
# class. ImageNet weights are not downloaded: load_state_dict (strict by
# default) overwrites every parameter and buffer with the fine-tuned values.
model_ft = models.resnet18(pretrained=False)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_ft.to(device)
# map_location puts the saved tensors on whichever device is available here.
model_ft.load_state_dict(torch.load(path_model, map_location=device))
# Inference only: disable dropout / use running batch-norm statistics.
model_ft.eval()
def do_inference(img):
    """Classify one image with the fine-tuned ResNet-18.

    Args:
        img: A PIL image (the Gradio input is configured with type="pil").
            Non-RGB images (e.g. grayscale, RGBA, palette) are converted to
            RGB first, because the normalization step expects three channels.

    Returns:
        dict mapping every class name to its softmax probability rounded to
        two decimals, inserted from most to least likely.
    """
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")
    # Add a batch dimension of size 1.
    batch_t = data_transforms_val(img).unsqueeze(0)
    model_ft.eval()
    # Gradients are not needed for inference; no_grad saves memory.
    with torch.no_grad():
        logits = model_ft(batch_t.to(device))
        probs = torch.nn.functional.softmax(logits, dim=1)
    # Sort the single image's class probabilities in descending order.
    sorted_probs, sorted_idx = torch.sort(probs[0], descending=True)
    labels = np.array(class_names)[sorted_idx.cpu().numpy()]
    return {
        label: round(float(prob), 2)
        for label, prob in zip(labels, sorted_probs.cpu().numpy())
    }
# Upload-only image input, passed to do_inference as an RGB PIL image.
im = gr.inputs.Image(
    type="pil",
    source="upload",
    shape=(512, 512),
    image_mode="RGB",
    invert_colors=False,
)
# Text and example inputs shown on the Gradio page.
title = "CatsVsDogs Classifier"
description = (
    "Playground: Inference of Object Classification (Binary) using ResNet18 "
    "model and CatsVsDogs dataset. Libraries: PyTorch, Gradio."
)
# Example image files expected in the working directory.
examples = [["./cat.jpg"], ["./dog.jpg"]]
article = (
    "<p style='text-align: center'><a href='https://github.com/mawady' "
    "target='_blank'>By Dr. Mohamed Elawady</a></p>"
)
# Wire the classifier into a Gradio UI: one image input and a label output
# listing the top two classes. With live=False it runs only on submit.
iface = gr.Interface(
    do_inference,
    im,
    gr.outputs.Label(num_top_classes=2),
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
iface.launch()