|
import torch |
|
|
|
|
|
# Target size handed to ``transform_image`` (used there for both the resize
# and the center crop) whenever ``inference`` runs a prediction.
resolution = 518

# Load the "vit_h14_in1k" entry point of the "facebookresearch/swag" hub
# repository and switch it to evaluation mode; this script only predicts.
model = torch.hub.load(
    "facebookresearch/swag",
    model="vit_h14_in1k",
)
model.eval()
|
|
|
import os |
|
# Download the ImageNet class-index JSON to ``in_cls_idx.json``.
# torch.hub's downloader replaces the former ``wget`` shell command: no
# command string is handed to a shell, no external ``wget`` binary is needed,
# and a failed download raises here instead of being silently ignored.
torch.hub.download_url_to_file(
    "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json",
    "in_cls_idx.json",
)
|
|
|
import gradio as gr |
|
|
|
from PIL import Image |
|
from torchvision import transforms |
|
|
|
import json |
|
|
|
|
|
# Build the lookup from integer class id to class name. Every JSON entry is
# unpacked as ``"<id>": [label, name]``; only the name is kept.
with open("in_cls_idx.json", "r") as f:
    imagenet_id_to_name = {
        int(raw_id): class_name
        for raw_id, (_label, class_name) in json.load(f).items()
    }
|
|
|
|
|
|
|
|
|
def load_image(image_path):
    """Open the image at ``image_path`` and return it converted to RGB.

    Args:
        image_path: Location of the image, in any form ``Image.open`` accepts.

    Returns:
        The result of ``convert("RGB")`` on the opened image.
    """
    # ``Image.open`` keeps the underlying file open; the context manager
    # closes it deterministically once ``convert`` has produced its result,
    # rather than leaving the handle to the garbage collector.
    with Image.open(image_path) as img:
        return img.convert("RGB")
|
|
|
|
|
|
|
def transform_image(image, resolution):
    """Preprocess ``image`` into a normalized tensor with a leading batch axis.

    The steps, in order: bicubic resize to ``resolution``, center crop to
    ``resolution``, conversion to a tensor, and per-channel normalization
    with the fixed mean/std values below.

    Args:
        image: Image to preprocess (``load_image`` supplies an RGB PIL image).
        resolution: Size passed to both the resize and the center crop.

    Returns:
        The transformed tensor with one extra dimension prepended.
    """
    resize = transforms.Resize(
        resolution,
        interpolation=transforms.InterpolationMode.BICUBIC,
    )
    center_crop = transforms.CenterCrop(resolution)
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    pipeline = transforms.Compose(
        [resize, center_crop, transforms.ToTensor(), normalize]
    )

    tensor = pipeline(image)

    # Indexing with ``None`` prepends a new axis, so the model receives a
    # batch containing this single image.
    return tensor[None, :]
|
|
|
def visualize_and_predict(model, resolution, image_path):
    """Return the indices of the five highest-scoring model outputs.

    Args:
        model: Callable model applied to the preprocessed image batch.
        resolution: Size forwarded to ``transform_image``.
        image_path: Path of the image to classify.

    Returns:
        A list of five integer indices for the single input image, as
        ordered by ``topk``.
    """
    batch = transform_image(load_image(image_path), resolution)

    # Prediction only, so gradient tracking is disabled.
    with torch.no_grad():
        top_indices = model(batch).topk(5).indices

    # ``tolist`` yields one inner list per batch row; there is only one row.
    return top_indices.tolist()[0]
|
|
|
# Download the example image referenced by the demo's ``examples`` list.
# torch.hub's downloader replaces the former ``wget`` shell command: no
# command string is handed to a shell, no external ``wget`` binary is needed,
# and a failed download raises here instead of being silently ignored.
torch.hub.download_url_to_file(
    "https://github.com/pytorch/hub/raw/master/images/dog.jpg",
    "dog.jpg",
)
|
|
|
|
|
|
|
def inference(img):
    """Classify the image at path ``img`` and return its predicted class names.

    Uses the module-level ``model`` and ``resolution``, then maps each
    predicted class id through ``imagenet_id_to_name``.

    Args:
        img: File path of the image to classify.

    Returns:
        A list of class names, in the order the predictions were returned.
    """
    class_ids = visualize_and_predict(model, resolution, img)

    class_names = []
    for class_id in class_ids:
        class_names.append(imagenet_id_to_name[class_id])
    return class_names
|
|
|
# Text shown on the demo page.
title = "SWAG"
description = "Gradio demo for Revisiting Weakly Supervised Pre-Training of Visual Perception Models. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.08371' target='_blank'>Revisiting Weakly Supervised Pre-Training of Visual Perception Models</a> | <a href='https://github.com/facebookresearch/SWAG' target='_blank'>Github Repo</a></p>"

# The image input is configured with type 'filepath', matching ``inference``,
# which opens its argument as a path; the result goes to a text box.
inputs = gr.inputs.Image(type='filepath')
outputs = gr.outputs.Textbox(label="Output")

# Assemble the interface around ``inference`` and start serving it, with
# ``dog.jpg`` offered as the single example.
gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
    examples=[['dog.jpg']],
).launch(enable_queue=True, cache_examples=True)