import argparse
import requests
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
import torchvision
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from timmvit import timmvit
import json
from timm.models.hub import download_cached_file
from PIL import Image


def pil_loader(filepath) -> Image.Image:
    """Open the image at *filepath* and return it converted to RGB."""
    with Image.open(filepath) as img:
        img = img.convert('RGB')
    return img


def build_transforms(input_size: int, center_crop: bool = True):
    """Build the evaluation preprocessing pipeline.

    The pipeline converts the input to a PIL image, resizes it to
    ``input_size * 8 // 7``, center-crops to ``input_size``, converts it to a
    tensor and normalises it channel-wise.

    :param input_size: Side length of the final center crop.
    :param center_crop: Currently ignored; the center-crop pipeline is always
        built. NOTE(review): confirm whether a non-cropping variant was
        intended before wiring this flag up.
    :returns: A ``torchvision.transforms.Compose`` instance.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize(input_size * 8 // 7),
        torchvision.transforms.CenterCrop(input_size),
        torchvision.transforms.ToTensor(),
        # These are the commonly used ImageNet channel statistics.
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])
    return transform


# Load the human-readable labels for Bamboo from the local JSON mapping.
with open('./trainid2name.json') as f:
    id2name = json.load(f)

# Build the model from the local checkpoint and switch it to eval mode.
model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
model.eval()


# Borrowed from:
# https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """Overlay the CAM mask on the image as a heatmap.

    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format; its maximum value must
        not exceed 1.
    :param mask: The CAM mask. It is scaled by 255 and cast to ``uint8``
        before the colormap is applied.
    :param use_rgb: Whether to use an RGB or BGR heatmap; set this to True if
        ``img`` is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The image with the CAM overlay, as ``uint8``.
    :raises Exception: If the maximum of ``img`` is greater than 1.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should np.float32 in the range [0, 1]")

    # Fixed-weight blend of heatmap and image; dividing by the maximum of the
    # blend is deliberately left disabled, as in the original demo code.
    cam = 0.7 * heatmap + 0.3 * img
    return np.uint8(255 * cam)


def recognize_image(image):
    """Classify *image* and return the five most probable labels.

    :param image: The image handed over by the Gradio image input; it is fed
        straight into ``eval_transforms`` (presumably an H x W x 3 uint8
        array -- confirm against the installed Gradio version).
    :returns: A dict mapping the first name listed in ``id2name`` for each of
        the five highest-scoring class ids to its softmax probability.
    """
    img_t = eval_transforms(image)

    # Pure inference: disable autograd so no graph is built and no
    # activations are retained for a backward pass on every request.
    with torch.no_grad():
        output = model(img_t.unsqueeze(0))
        prediction = output.softmax(-1).flatten()
        top5_prob, top5_idx = torch.topk(prediction, 5)

    return {
        id2name[str(idx)][0]: float(prob)
        for idx, prob in zip(top5_idx.tolist(), top5_prob.tolist())
    }


eval_transforms = build_transforms(224)

# NOTE: `image` is created but not passed to the Interface below, which uses
# the "image" string shortcut for its input instead.
image = gr.inputs.Image()
label = gr.outputs.Label(num_top_classes=5)

gr.Interface(
    description=(
        "Bamboo for Image Recognition Demo "
        "(https://github.com/Davidzhangyuanhan/Bamboo). "
        "Bamboo knows what this object is and what you are doing in a very "
        "fine-grain granularity: fratercula arctica (fig.5) and dribbler (fig.2))."
    ),
    fn=recognize_image,
    inputs=["image"],
    outputs=[
        label,
    ],
    examples=[
        ["./examples/playing_mahjong.jpg"],
        ["./examples/dribbler.jpg"],
        ["./examples/Ferrari-F355.jpg"],
        ["./examples/northern_oriole.jpg"],
        ["./examples/fratercula_arctica.jpg"],
        ["./examples/husky.jpg"],
        ["./examples/taraxacum_erythrospermum.jpg"],
    ],
).launch()