Spaces:
Runtime error
Runtime error
File size: 3,508 Bytes
6ab04f7 68f7ba1 6ab04f7 df59928 6ab04f7 df59928 43382ea df59928 a58085c df59928 6ab04f7 7e3982a 6ab04f7 68f7ba1 6ab04f7 df59928 d2d4aba 6ab04f7 df59928 6ab04f7 3a5cbeb 6ab04f7 a6dcd84 892e4ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import argparse
import requests
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
import torchvision
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from timmvit import timmvit
import json
from timm.models.hub import download_cached_file
from PIL import Image
def pil_loader(filepath):
    """Open the image at ``filepath`` and return it converted to RGB.

    The conversion happens while the file is still open. The opened image is
    closed when the ``with`` block exits, and the converted image is returned.
    """
    with Image.open(filepath) as handle:
        rgb = handle.convert('RGB')
    return rgb
def build_transforms(input_size, center_crop=True):
    """Build the evaluation preprocessing pipeline.

    The pipeline converts its input to a PIL image (presumably an HxWxC uint8
    numpy array, as accepted by torchvision ``ToPILImage`` -- TODO confirm for
    all callers), resizes it, converts it to a float tensor and normalises it
    with the ImageNet mean/std.

    :param input_size: Side length, in pixels, of the square output.
    :param center_crop: If True (default), resize the shorter side to
        ``input_size * 8 // 7`` (e.g. 256 for 224) and then centre-crop to
        ``input_size``. If False, resize directly to
        ``(input_size, input_size)`` without cropping.
    :returns: A ``torchvision.transforms.Compose`` callable.
    """
    if center_crop:
        # Classic "resize then crop" evaluation recipe (crop ratio 7/8).
        resize_steps = [
            torchvision.transforms.Resize(input_size * 8 // 7),
            torchvision.transforms.CenterCrop(input_size),
        ]
    else:
        # Previously this flag was silently ignored; now it skips the crop.
        resize_steps = [torchvision.transforms.Resize((input_size, input_size))]
    return torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        *resize_steps,
        torchvision.transforms.ToTensor(),
        # Same values as the previously hard-coded ImageNet statistics.
        torchvision.transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN,
                                         std=IMAGENET_DEFAULT_STD),
    ])
# Load the human-readable Bamboo labels from the local JSON file. Keys are
# class ids as strings; recognize_image displays the first entry of each value.
# An explicit encoding avoids depending on the platform's default locale.
with open('./trainid2name.json', encoding='utf-8') as f:
    id2name = json.load(f)

# Build the classifier from the converted Bamboo ViT-B/16 checkpoint and put it
# in evaluation mode for inference.
model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
model.eval()
# Adapted from pytorch-grad-cam:
# https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """Overlay a CAM mask on an image as a colour heatmap.

    OpenCV produces the heatmap in BGR order unless ``use_rgb`` is set.

    :param img: Base image in RGB or BGR order, as floats in [0, 1].
    :param mask: CAM mask. It is multiplied by 255 and cast to uint8 before
        colouring, so it is presumably expected to lie in [0, 1].
    :param use_rgb: Convert the heatmap to RGB. Set this when ``img`` is RGB.
    :param colormap: OpenCV colormap id used to colour the mask.
    :returns: uint8 image blending 70% heatmap with 30% base image.
    :raises Exception: If ``img`` has any value above 1.
    """
    colored = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    heat = np.float32(colored) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should np.float32 in the range [0, 1]")

    blended = heat * 0.7 + img * 0.3
    return np.uint8(255 * blended)
def recognize_image(image):
    """Classify ``image`` and return its five most likely Bamboo labels.

    :param image: Image from the Gradio input component (presumably an HxWxC
        uint8 numpy array -- TODO confirm), or ``None`` when no image was
        supplied.
    :returns: Dict mapping label name to softmax probability for the top-5
        classes, or an empty dict when ``image`` is ``None``.
    """
    if image is None:
        # Submitting without an image would otherwise crash in the transform.
        return {}
    img_t = eval_transforms(image)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model(img_t.unsqueeze(0))
    probs = output.softmax(-1).flatten()
    _, top5_idx = torch.topk(probs, 5)
    return {id2name[str(i)][0]: float(probs[i]) for i in top5_idx.tolist()}
# 224x224 evaluation preprocessing used by recognize_image.
eval_transforms = build_transforms(224)

# Top-level components replace the deprecated gr.inputs / gr.outputs
# namespaces, which were removed in Gradio 4.
image = gr.Image()
label = gr.Label(num_top_classes=5)

gr.Interface(
    description="Bamboo for Image Recognition Demo (https://github.com/Davidzhangyuanhan/Bamboo). Bamboo knows what this object is and what you are doing in a very fine-grain granularity: fratercula arctica (fig.5) and dribbler (fig.2)).",
    fn=recognize_image,
    inputs=[image],
    outputs=[
        label,
    ],
    examples=[
        ["./examples/playing_mahjong.jpg"],
        ["./examples/dribbler.jpg"],
        ["./examples/Ferrari-F355.jpg"],
        ["./examples/northern_oriole.jpg"],
        ["./examples/fratercula_arctica.jpg"],
        ["./examples/husky.jpg"],
        ["./examples/taraxacum_erythrospermum.jpg"],
    ],
    # The visitor badge belongs in the page itself. The old module-level
    # gr.markdown(...) call ran only after launch() returned, and gr.markdown
    # does not exist.
    article='<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=cvpr.bamboo_vit-b16_demo" />',
).launch()