File size: 3,011 Bytes
0f36ee2
 
 
 
 
 
 
 
 
 
 
 
 
 
1b63bdd
0f36ee2
1b63bdd
0f36ee2
1b63bdd
0f36ee2
1b63bdd
0f36ee2
1b63bdd
 
 
0f36ee2
 
 
 
 
 
 
 
1b63bdd
 
0f36ee2
 
 
1b63bdd
 
 
0f36ee2
 
 
 
 
 
 
 
1b63bdd
 
0f36ee2
 
1b63bdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f36ee2
1b63bdd
 
 
 
0f36ee2
2b73cea
1b63bdd
4b80b24
0f36ee2
1b63bdd
 
0f36ee2
 
77f2530
0f36ee2
 
a40046b
1b63bdd
0f36ee2
 
1b63bdd
0f36ee2
a4885d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import torch.nn as nn
import numpy as np
from torchvision import models, transforms
import time
import os
import copy
import pickle
from PIL import Image
import datetime
import gdown
import urllib.request
import gradio as gr

# Model artifacts are expected next to this script. They were originally
# fetched from Google Drive with gdown; those calls are kept commented out
# for reference.
# url = 'https://drive.google.com/uc?id=1VMLpE5ojF9fq0GtBKaqcMVWUIfJUfKbc'
path_class_names = "./class_names_restnet_catsVSdogs.pkl"
# gdown.download(url, path_class_names, quiet=False, use_cookies=False)

# url = 'https://drive.google.com/uc?id=1jorQB1mpPCLH097M8paxut3v5XwVlKqp'
path_model = "./model_state_restnet_catsVSdogs.pth"
# gdown.download(url, path_model, quiet=False, use_cookies=False)

# Example images shown in the Gradio UI. Each file is downloaded only when no
# local copy exists, so restarting the app does not require network access.
# After the loop, ``url``/``path_input`` hold the last (dog) entry.
for url, path_input in (
    (
        "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg",
        "./cat.jpg",
    ),
    ("https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg", "./dog.jpg"),
):
    if not os.path.exists(path_input):
        urllib.request.urlretrieve(url, filename=path_input)

# Evaluation-time preprocessing: resize (shorter side to 256), center-crop to
# 224x224, convert to a tensor and normalize with the per-channel mean/std
# values commonly used with torchvision's ImageNet-pretrained models.
data_transforms_val = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load class-name files obtained from a trusted source.
with open(path_class_names, "rb") as _class_names_file:
    class_names = pickle.load(_class_names_file)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet-18 with its final fully connected layer resized to the number of
# classes. pretrained=False: every parameter and buffer is overwritten by the
# strict load_state_dict call below, so downloading ImageNet weights first
# would be wasted network traffic and startup time.
model_ft = models.resnet18(pretrained=False)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_ft.to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
# The loaded dict is only read (copied into the model's tensors), so no
# defensive copy is needed.
model_ft.load_state_dict(torch.load(path_model, map_location=device))
# The model is only used for inference in this script.
model_ft.eval()


def do_inference(img):
    """Classify one image with the fine-tuned ResNet-18.

    Args:
        img: Image delivered by the Gradio input component (configured in this
            script with ``type="pil"`` and ``image_mode="RGB"``).

    Returns:
        dict mapping every class name to its softmax probability rounded to
        two decimals, inserted in descending order of probability.
    """
    # Preprocess and prepend a batch dimension of size 1.
    batch_t = torch.unsqueeze(data_transforms_val(img), 0)
    model_ft.eval()
    # Gradients are not needed for inference; no_grad saves memory.
    with torch.no_grad():
        batch_t = batch_t.to(device)
        output = model_ft(batch_t)
        probs = torch.nn.functional.softmax(output, dim=1)
        # sort yields the ranked probabilities and their class indices in one
        # call, instead of argsort followed by a separate gather.
        sorted_probs, order = torch.sort(probs, dim=1, descending=True)
    sorted_probs = sorted_probs.cpu().numpy()[0]
    order = order.cpu().numpy()[0].astype(int)
    labels = np.array(class_names)[order]
    return {label: round(float(p), 2) for label, p in zip(labels, sorted_probs)}


# Image upload component: inputs are delivered to do_inference as RGB PIL
# images at the configured shape.
im = gr.inputs.Image(
    shape=(512, 512), image_mode="RGB", invert_colors=False, source="upload", type="pil"
)
title = "CatsVsDogs Classifier"
description = "Playground: Inference of Object Classification (Binary) using ResNet18 model and CatsVsDogs dataset. Libraries: PyTorch, Gradio."
# Example images downloaded earlier in this script.
examples = [["./cat.jpg"], ["./dog.jpg"]]
article = "<p style='text-align: center'><a href='https://github.com/mawady' target='_blank'>By Dr. Mohamed Elawady</a></p>"
# Show both classes in the label output (the task is binary).
iface = gr.Interface(
    do_inference,
    im,
    gr.outputs.Label(num_top_classes=2),
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article,
    examples=examples,
)

iface.launch()