"""Gradio demo for SWAG (Revisiting Weakly Supervised Pre-Training of Visual
Perception Models).

Loads the ``vit_h14_in1k`` model from ``facebookresearch/swag`` via torch.hub
and serves a Gradio interface that returns the top-5 ImageNet class names for
an uploaded image.
"""
import json
import os
import urllib.request

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

# Side length, in pixels, of the square crop fed to the model.
# NOTE(review): 518 presumably matches the input size vit_h14_in1k expects — confirm.
resolution = 518

CLASS_INDEX_URL = (
    "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
)
CLASS_INDEX_PATH = "in_cls_idx.json"
EXAMPLE_IMAGE_URL = "https://github.com/pytorch/hub/raw/master/images/dog.jpg"
EXAMPLE_IMAGE_PATH = "dog.jpg"


def _download(url, path):
    """Download ``url`` to ``path`` unless ``path`` already exists.

    ``urlretrieve`` raises on network/HTTP errors, so a failed download stops
    the script here (an unchecked ``wget -O`` shell-out would instead leave an
    empty file behind and fail later with an unrelated-looking error). The
    data is written to a temporary name and moved into place only on success,
    so an interrupted download is never mistaken for a complete file.
    """
    if os.path.exists(path):
        return
    tmp_path = path + ".part"
    urllib.request.urlretrieve(url, tmp_path)
    os.replace(tmp_path, path)


model = torch.hub.load("facebookresearch/swag", model="vit_h14_in1k")
# Switch to eval mode so layers use inference-time behavior (e.g. no dropout).
model.eval()

_download(CLASS_INDEX_URL, CLASS_INDEX_PATH)
with open(CLASS_INDEX_PATH, "r", encoding="utf-8") as f:
    # Each entry is "<class index>": [<label id>, <readable name>]; keep only the name.
    imagenet_id_to_name = {
        int(cls_id): name for cls_id, (_, name) in json.load(f).items()
    }


def load_image(image_path):
    """Open the image at ``image_path`` and return it as an RGB PIL image."""
    # The context manager closes the underlying file; convert() returns a new,
    # fully loaded image that stays valid after the source is closed.
    with Image.open(image_path) as img:
        return img.convert("RGB")


def transform_image(image, resolution):
    """Preprocess a PIL image into a batched, normalized tensor.

    The shorter side is resized to ``resolution`` (bicubic), then the image is
    center-cropped to ``resolution`` x ``resolution``, converted to a tensor
    and normalized with the usual ImageNet mean/std.

    Returns:
        A tensor of shape (1, C, resolution, resolution).
    """
    transform = transforms.Compose([
        transforms.Resize(
            resolution,
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ])
    image = transform(image)
    # we also add a batch dimension to the image since that is what the model expects
    image = image[None, :]
    return image


def visualize_and_predict(model, resolution, image_path):
    """Run ``model`` on the image at ``image_path`` and return top-5 class ids.

    Despite the name, nothing is visualized; the name is kept for callers.

    Returns:
        A list of 5 integer class indices, highest-scoring first
        (``topk`` sorts its output in descending order by default).
    """
    image = load_image(image_path)
    image = transform_image(image, resolution)

    # we do not need to track gradients for inference
    with torch.no_grad():
        _, preds = model(image).topk(5)
    # convert preds to a Python list and remove the batch dimension
    preds = preds.tolist()[0]

    return preds


def inference(img):
    """Gradio callback: map an uploaded image file path to its top-5 class names."""
    preds = visualize_and_predict(model, resolution, img)
    return [imagenet_id_to_name[cls_id] for cls_id in preds]


# Example image shown (and cached) in the Gradio UI.
_download(EXAMPLE_IMAGE_URL, EXAMPLE_IMAGE_PATH)

# NOTE(review): gr.inputs / gr.outputs and launch(enable_queue=..., cache_examples=...)
# belong to the legacy Gradio 2.x API and are deprecated/removed in later
# releases — keep gradio pinned accordingly or migrate these calls together.
inputs = gr.inputs.Image(type='filepath')
outputs = gr.outputs.Textbox(label="Output")

title = "SWAG"
description = (
    "Gradio demo for Revisiting Weakly Supervised Pre-Training of Visual Perception Models. "
    "To use it, simply upload your image, or click one of the examples to load them. "
    "Read more at the links below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2201.08371' target='_blank'>"
    "Revisiting Weakly Supervised Pre-Training of Visual Perception Models</a>"
    " | "
    "<a href='https://github.com/facebookresearch/SWAG' target='_blank'>Github Repo</a>"
    "</p>"
)

gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=[[EXAMPLE_IMAGE_PATH]],
).launch(enable_queue=True, cache_examples=True)