|  | import torch | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | model = torch.hub.load("facebookresearch/swag", model="vit_h14_in1k") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | model.eval() | 
					
						
						|  |  | 
					
						
						|  | resolution = 518 | 
					
						
						|  |  | 
					
						
						|  | import os | 
					
						
						|  | os.system("wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json -O in_cls_idx.json") | 
					
						
						|  |  | 
					
						
						|  | import gradio as gr | 
					
						
						|  |  | 
					
						
						|  | from PIL import Image | 
					
						
						|  | from torchvision import transforms | 
					
						
						|  |  | 
					
						
						|  | import json | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with open("in_cls_idx.json", "r") as f: | 
					
						
						|  | imagenet_id_to_name = {int(cls_id): name for cls_id, (label, name) in json.load(f).items()} | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def load_image(image_path): | 
					
						
						|  | return Image.open(image_path).convert("RGB") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def transform_image(image, resolution): | 
					
						
						|  | transform = transforms.Compose([ | 
					
						
						|  | transforms.Resize( | 
					
						
						|  | resolution, | 
					
						
						|  | interpolation=transforms.InterpolationMode.BICUBIC, | 
					
						
						|  | ), | 
					
						
						|  | transforms.CenterCrop(resolution), | 
					
						
						|  | transforms.ToTensor(), | 
					
						
						|  | transforms.Normalize( | 
					
						
						|  | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] | 
					
						
						|  | ), | 
					
						
						|  | ]) | 
					
						
						|  | image = transform(image) | 
					
						
						|  |  | 
					
						
						|  | image = image[None, :] | 
					
						
						|  | return image | 
					
						
						|  |  | 
					
						
						|  | def visualize_and_predict(model, resolution, image_path): | 
					
						
						|  | image = load_image(image_path) | 
					
						
						|  | image = transform_image(image, resolution) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with torch.no_grad(): | 
					
						
						|  | _, preds = model(image).topk(5) | 
					
						
						|  |  | 
					
						
						|  | preds = preds.tolist()[0] | 
					
						
						|  |  | 
					
						
						|  | return preds | 
					
						
						|  |  | 
					
						
						|  | os.system("wget https://github.com/pytorch/hub/raw/master/images/dog.jpg -O dog.jpg") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def inference(img): | 
					
						
						|  | preds = visualize_and_predict(model, resolution, img) | 
					
						
						|  | return [imagenet_id_to_name[cls_id] for cls_id in preds] | 
					
						
						|  |  | 
					
						
						|  | inputs = gr.inputs.Image(type='filepath') | 
					
						
						|  | outputs = gr.outputs.Textbox(label="Output") | 
					
						
						|  |  | 
					
						
						|  | title = "SWAG" | 
					
						
						|  |  | 
					
						
						|  | description = "Gradio demo for Revisiting Weakly Supervised Pre-Training of Visual Perception Models. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below." | 
					
						
						|  |  | 
					
						
						|  | article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.08371' target='_blank'>Revisiting Weakly Supervised Pre-Training of Visual Perception Models</a> | <a href='https://github.com/facebookresearch/SWAG' target='_blank'>Github Repo</a></p>" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['dog.jpg']]).launch(enable_queue=True,cache_examples=True) |