Andrei Boiarov commited on
Commit
859c3ef
1 Parent(s): 3d04a5c

Update app file

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. app.py +61 -4
  3. requirements.txt +2 -0
.gitignore CHANGED
@@ -1 +1,2 @@
1
- .idea/
 
 
1
+ .idea/
2
+ flagged/
app.py CHANGED
@@ -1,7 +1,64 @@
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
1
+ from transformers import ViTFeatureExtractor, ViTMAEForPreTraining
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+
6
  import gradio as gr
7
 
8
+ feature_extractor = ViTFeatureExtractor.from_pretrained('andrewbo29/vit-mae-base-formula1')
9
+ model = ViTMAEForPreTraining.from_pretrained('andrewbo29/vit-mae-base-formula1')
10
+
11
+ imagenet_mean = np.array(feature_extractor.image_mean)
12
+ imagenet_std = np.array(feature_extractor.image_std)
13
+
14
+
15
+ def prep_image(image):
16
+ return torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int().cpu().numpy()
17
+
18
+
19
+ def reconstruct(img):
20
+ image = Image.fromarray(img)
21
+ pixel_values = feature_extractor(image, return_tensors='pt').pixel_values
22
+
23
+ outputs = model(pixel_values)
24
+ y = model.unpatchify(outputs.logits)
25
+ y = torch.einsum('nchw->nhwc', y).detach().cpu()
26
+
27
+ # visualize the mask
28
+ mask = outputs.mask.detach()
29
+ mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size ** 2 * 3) # (N, H*W, p*p*3)
30
+ mask = model.unpatchify(mask) # 1 is removing, 0 is keeping
31
+ mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
32
+
33
+ x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()
34
+
35
+ # masked image
36
+ im_masked = x * (1 - mask)
37
+
38
+ # MAE reconstruction pasted with visible patches
39
+ im_paste = x * (1 - mask) + y * mask
40
+
41
+ out_masked = prep_image(im_masked[0])
42
+ out_rec = prep_image(y[0])
43
+ out_rec_vis = prep_image(im_paste[0])
44
+
45
+ return [(out_masked, 'masked'), (out_rec, 'reconstruction'), (out_rec_vis, 'reconstruction + visible')]
46
+
47
+
48
+ with gr.Blocks() as demo:
49
+ with gr.Column(variant="panel"):
50
+ with gr.Row():
51
+ img = gr.Image(
52
+ label="Enter your prompt",
53
+ container=False,
54
+ )
55
+ btn = gr.Button("Generate image", scale=0)
56
+
57
+ gallery = gr.Gallery(
58
+ label="Generated images", show_label=False, elem_id="gallery"
59
+ , columns=[3], rows=[1], object_fit="contain", height='auto', container=True)
60
+
61
+ btn.click(reconstruct, img, gallery)
62
 
63
+ if __name__ == "__main__":
64
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ transformers