Ahsen Khaliq commited on
Commit
bedbc77
1 Parent(s): 2d1f3e5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ model = torch.hub.load("facebookresearch/swag", model="vit_h14_in1k")
5
+
6
+ # we also convert the model to eval mode
7
+ model.eval()
8
+
9
+ resolution = 518
10
+
11
+ import os
12
+ os.system("wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json -O in_cls_idx.json")
13
+
14
+ import gradio as gr
15
+
16
+ from PIL import Image
17
+ from torchvision import transforms
18
+ import matplotlib.pyplot as plt
19
+
20
+
21
+ def load_image(image_path):
22
+ return Image.open(image_path).convert("RGB")
23
+
24
+
25
+
26
+ def transform_image(image, resolution):
27
+ transform = transforms.Compose([
28
+ transforms.Resize(
29
+ resolution,
30
+ interpolation=transforms.InterpolationMode.BICUBIC,
31
+ ),
32
+ transforms.CenterCrop(resolution),
33
+ transforms.ToTensor(),
34
+ transforms.Normalize(
35
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
36
+ ),
37
+ ])
38
+ image = transform(image)
39
+ # we also add a batch dimension to the image since that is what the model expects
40
+ image = image[None, :]
41
+ return image
42
+
43
+ def visualize_and_predict(model, resolution, image_path):
44
+ image = load_image(image_path)
45
+ image = transform_image(image, resolution)
46
+
47
+ # we do not need to track gradients for inference
48
+ with torch.no_grad():
49
+ _, preds = model(image).topk(5)
50
+ # convert preds to a Python list and remove the batch dimension
51
+ preds = preds.tolist()[0]
52
+ return preds
53
+
54
+ os.system("wget https://github.com/pytorch/hub/raw/master/images/dog.jpg -O dog.jpg")
55
+
56
+
57
+
58
+ def inference(img):
59
+ preds = visualize_and_predict(model, resolution, img)
60
+
61
+ return preds
62
+
63
+ inputs = gr.inputs.Image(type='pil')
64
+ outputs = gr.outputs.Textbox(label="Output")
65
+
66
+ title = "SWAG"
67
+
68
+ description = "Gradio demo for Revisiting Weakly Supervised Pre-Training of Visual Perception Models. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
69
+
70
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.08371' target='_blank'>Revisiting Weakly Supervised Pre-Training of Visual Perception Models</a> | <a href='https://github.com/facebookresearch/SWAG' target='_blank'>Github Repo</a></p>"
71
+
72
+ examples = ['dog.jpg']
73
+
74
+ gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, analytics_enabled=False, examples=examples).launch(enable_queue=True)