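# Gradio demo for CS-Mixer ("tpmlp"): ImageNet-1k classification with Grad-CAM heatmaps.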
import os
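# Bootstrap: retry until OpenCV imports cleanly, installing it on the fly if missing.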
while True:
    try:
        import cv2
    except ImportError:
        print("Package cv2 not found. Attepting installation.")
        os.system("pip install -U opencv-python &> /dev/null")
        continue
    break

import time, math
print("=> Loading libraries...")
start = time.time()

import requests, torch, argparse
import gradio as gr
from torchvision import transforms
from datasets import load_dataset
from timm.data import create_transform
from timm.models import create_model, load_checkpoint
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

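# --local: run outside the Hugging Face Space (skip HF login, use local example paths and port 8000).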
parser = argparse.ArgumentParser()
parser.add_argument("--local", action='store_true')
args = parser.parse_args()

if not args.local:
    print("=> Logging into huggingface...")
    from huggingface_hub import login
    login(token=os.environ["HF_TOKEN"])

device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"=> Libraries loaded in {time.time()- start:.2f} sec(s).")
print("=> Loading model...")
start = time.time()

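# Evaluation preprocessing: 224x224 inputs with the standard ImageNet center-crop ratio of 0.9.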
size = "b"
img_size = 224
crop_pct = 0.9
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

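# Build the CS-Mixer model and load its checkpoint, trying the parent directory first
# and falling back to the working directory.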
model = create_model(f"tpmlp_{size}").to(device)
try:
    load_checkpoint(model, f"../tpmlp_{size}.pth.tar", True)
except FileNotFoundError:
    load_checkpoint(model, f"tpmlp_{size}.pth.tar", True)
model.eval()

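# Fetch the 1,000 ImageNet-1k class labels (a plain-text list, one label per line).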
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

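# Note: `augs` is currently unused; the explicit resize/normalize pipeline below is applied
# instead, so the un-normalized image stays available for the Grad-CAM overlay.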
augs = create_transform(
    input_size=(3, 224, 224),
    is_training=False,
    use_prefetcher=False,
    crop_pct=0.9,
)


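# Resize the short side to img_size / crop_pct (248 px here), then center-crop to 224x224.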
scale_size = math.floor(img_size / crop_pct)
resize = transforms.Compose([
    transforms.Resize(scale_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor()
])
normalize = transforms.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN), std=torch.tensor(IMAGENET_DEFAULT_STD))

def transform(img):
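    """Return the resized [0, 1] image tensor (for display) and its normalized copy (for the model)."""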
    img = resize(img.convert("RGB"))
    tensor = normalize(img)
    return img, tensor

def predict(inp):
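    """Classify one PIL image; return (label confidences, Grad-CAM overlay as an RGB array)."""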
    img, inp = transform(inp)
    inp = inp.unsqueeze(0)
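    # Grad-CAM on the final stage (model.layers[3]). Note: `return_probs` is not part of
    # upstream pytorch-grad-cam; this assumes a patched build whose __call__ also returns
    # the softmax probabilities alongside the heatmap.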
    with GradCAM(model=model, target_layers=[model.layers[3]], use_cuda=device=="cuda") as cam:
        grayscale_cam, probs = cam(input_tensor=inp, aug_smooth=False, eigen_smooth=False, return_probs=True)
        
        # Here grayscale_cam has only one image in the batch
        grayscale_cam = grayscale_cam[0, :]
        probs = probs[0, :]

        cam_image = show_cam_on_image(
            img.permute(1, 2, 0).detach().cpu().numpy(), grayscale_cam,
            use_rgb=True, image_weight=0.5, colormap=cv2.COLORMAP_TWILIGHT_SHIFTED,
        )
        confidences = {labels[i]: float(probs[i]) for i in range(1000)}
    return confidences, cam_image

print(f"=> Model (tpmlp_{size}) loaded in {time.time()- start:.2f} sec(s).")

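# Example images live next to the script on the Space, or under ../example-imgs locally.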
base = "../example-imgs" if args.local else "."

print("=> Loading examples.")
indices = [
    0,      # Coucal
    2,      # Volcano
    7,      # Sombrero
    9,      # Balance beam
    10,     # Sulphur-crested cockatoo
    11,     # Shower cap
    12,     # Petri dish INCORRECTLY CLASSIFIED as lens
    14,     # Angora rabbit
]
ds = load_dataset("imagenet-1k", split="validation", streaming=True)
idx = 0
start = time.time()
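# Stream the validation split and save only the example images listed above.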
for data in ds:
    if idx in indices:
        data['image'].save(f"{base}/{idx}.png")
    idx += 1
    if idx > max(indices):
        break
del ds
print(f"=> Examples loaded in {time.time()- start:.2f} sec(s).")


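# Three-pane Blocks UI: input image, top-4 class probabilities, and the Grad-CAM overlay.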
with gr.Blocks(theme=gr.themes.Monochrome(font=[gr.themes.GoogleFont("DM Sans"), "sans-serif"])) as demo:
    gr.HTML("""
    <h1 align="center">Interactive Demo</h1>
    <h2 align="center">CS-Mixer: A Cross-Scale Vision MLP Model with Spatial–Channel Mixing</h2>
    <br><br>
    """)
    with gr.Row():
        input_image = gr.Image(type="pil", min_width=300, label="Input Image")
        softmax = gr.Label(num_top_classes=4, min_width=200, label="Model Predictions")
        grad_cam = gr.Image(type="numpy", min_width=300, label="Grad-CAM")
    with gr.Row():
        gr.Button("Predict").click(fn=predict, inputs=input_image, outputs=[softmax, grad_cam])
        gr.ClearButton(input_image)
    with gr.Row():
        gr.Examples(
            [f"{base}/{idx}.png" for idx in indices],
            inputs=input_image, outputs=[softmax, grad_cam], fn=predict, run_on_click=True,
        )

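# Local runs serve on 0.0.0.0:8000; on the Space, Gradio's defaults are used.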
if args.local:
    demo.launch(
        share=False, debug=False, allowed_paths=[f"{base}"], server_name="0.0.0.0", # ssl_verify=False,
        server_port=8000, # ssl_certfile="/workspace/openssl/cert.pem", ssl_keyfile="/workspace/openssl/key.pem"
    )
else:
    demo.launch(allowed_paths=[f"{base}"])