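# Gradio demo for "CS-Mixer: A Cross-Scale Vision MLP Model with
# Spatial-Channel Mixing": ImageNet classification with Grad-CAM overlays.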
import os

# Make sure cv2 is importable, installing opencv-python on the fly if needed.
while True:
    try:
        import cv2
    except ImportError:
        print("Package cv2 not found. Attempting installation.")
        # `&>` is a bash-ism; /bin/sh (used by os.system) needs the POSIX form.
        os.system("pip install -U opencv-python > /dev/null 2>&1")
        continue
    break
import time, math  # os and cv2 are already imported above
print("=> Loading libraries...")
start = time.time()
import requests, torch, argparse
import gradio as gr
from torchvision import transforms
from datasets import load_dataset
from timm.data import create_transform
from timm.models import create_model, load_checkpoint
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
parser = argparse.ArgumentParser()
parser.add_argument("--local", action='store_true')
args = parser.parse_args()
if not args.local:
    print("=> Logging into huggingface...")
    from huggingface_hub import login
    login(token=os.environ["HF_TOKEN"])
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"=> Libraries loaded in {time.time()- start:.2f} sec(s).")
print("=> Loading model...")
start = time.time()
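# Model configuration: CS-Mixer base ("tpmlp_b") at 224x224 input, with the
# standard ImageNet eval crop ratio and normalization statistics.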
size = "b"
img_size = 224
crop_pct = 0.9
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
model = create_model(f"tpmlp_{size}").to(device)
try:
    load_checkpoint(model, f"../tpmlp_{size}.pth.tar", True)
except FileNotFoundError:
    load_checkpoint(model, f"tpmlp_{size}.pth.tar", True)
model.eval()
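# The git.io shortlink resolves to the standard list of 1,000 ImageNet class
# names, one label per line.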
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")
augs = create_transform(
    input_size=(3, 224, 224),
    is_training=False,
    use_prefetcher=False,
    crop_pct=crop_pct,
)
scale_size = math.floor(img_size / crop_pct)
resize = transforms.Compose([
    transforms.Resize(scale_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
])
normalize = transforms.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN), std=torch.tensor(IMAGENET_DEFAULT_STD))
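# Returns both the un-normalized image tensor (for the CAM overlay) and the
# normalized tensor (as model input).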
def transform(img):
    img = resize(img.convert("RGB"))
    tensor = normalize(img)
    return img, tensor
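# Classify one image and render a Grad-CAM heatmap computed at model.layers[3].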
def predict(inp):
    img, inp = transform(inp)
    inp = inp.unsqueeze(0)
    # NOTE: `use_cuda` and `return_probs` assume the pytorch-grad-cam build this
    # demo was written against; recent releases pick the device automatically
    # and return only the CAM.
    with GradCAM(model=model, target_layers=[model.layers[3]], use_cuda=device == "cuda") as cam:
        grayscale_cam, probs = cam(input_tensor=inp, aug_smooth=False, eigen_smooth=False, return_probs=True)
        # grayscale_cam holds only one image in the batch
        grayscale_cam = grayscale_cam[0, :]
        probs = probs[0, :]
    cam_image = show_cam_on_image(
        img.permute(1, 2, 0).detach().cpu().numpy(), grayscale_cam,
        use_rgb=True, image_weight=0.5, colormap=cv2.COLORMAP_TWILIGHT_SHIFTED,
    )
    confidences = {labels[i]: float(probs[i]) for i in range(1000)}
    return confidences, cam_image
print(f"=> Model (tpmlp_{size}) loaded in {time.time()- start:.2f} sec(s).")
base = "../example-imgs" if args.local else "."
print("=> Loading examples.")
indices = [
    0,   # Coucal
    2,   # Volcano
    7,   # Sombrero
    9,   # Balance beam
    10,  # Sulphur-crested cockatoo
    11,  # Shower cap
    12,  # Petri dish INCORRECTLY CLASSIFIED as lens
    14,  # Angora rabbit
]
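# Stream the ImageNet-1k validation split and save the example images to disk
# so gr.Examples can reference them by path.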
ds = load_dataset("imagenet-1k", split="validation", streaming=True)
idx = 0
start = time.time()
for data in ds:
    if idx in indices:  # was `idx == indices`, an int-vs-list comparison
        data["image"].save(f"{base}/{idx}.png")
    idx += 1
    if idx > max(indices):  # stop only after the last wanted index is saved
        break
del ds
print(f"=> Examples loaded in {time.time() - start:.2f} sec(s).")
# demo = gr.Interface(
# fn=predict,
# inputs=gr.inputs.Image(type="pil"),
# outputs=[gr.outputs.Label(num_top_classes=4), gr.outputs.Image(type="numpy")],
# examples=[f"../example-imgs/{idx}.png" for idx in indices],
# )
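# Equivalent Blocks-based UI; the gr.Interface version above is kept for
# reference but uses the deprecated gr.inputs/gr.outputs namespaces.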
with gr.Blocks(theme=gr.themes.Monochrome(font=[gr.themes.GoogleFont("DM Sans"), "sans-serif"])) as demo:
    gr.HTML("""
        <h1 align="center">Interactive Demo</h1>
        <h2 align="center">CS-Mixer: A Cross-Scale Vision MLP Model with Spatial–Channel Mixing</h2>
        <br><br>
    """)
    with gr.Row():
        input_image = gr.Image(type="pil", min_width=300, label="Input Image")
        softmax = gr.Label(num_top_classes=4, min_width=200, label="Model Predictions")
        grad_cam = gr.Image(type="numpy", min_width=300, label="Grad-CAM")
    with gr.Row():
        gr.Button("Predict").click(fn=predict, inputs=input_image, outputs=[softmax, grad_cam])
        gr.ClearButton(input_image)
    with gr.Row():
        gr.Examples(
            [f"{base}/{idx}.png" for idx in indices],
            inputs=input_image, outputs=[softmax, grad_cam], fn=predict, run_on_click=True,
        )
if args.local:
    demo.launch(
        share=False, debug=False, allowed_paths=[base], server_name="0.0.0.0",
        server_port=8000,
        # ssl_verify=False, ssl_certfile="/workspace/openssl/cert.pem",
        # ssl_keyfile="/workspace/openssl/key.pem",
    )
else:
    demo.launch(allowed_paths=[base])