Andrei Boiarov committed
Commit 14bb247
1 Parent(s): 7440015

v1 version of app

Files changed (2)
  1. app.py +24 -63
  2. requirements.txt +1 -3
app.py CHANGED
@@ -1,25 +1,21 @@
-from transformers import ViTFeatureExtractor, ViTMAEForPreTraining
+from transformers import ViTMAEForPreTraining, ViTImageProcessor
 import numpy as np
 import torch
-from PIL import Image
-import requests
-
 import gradio as gr
 
-feature_extractor = ViTFeatureExtractor.from_pretrained('andrewbo29/vit-mae-base-formula1')
+image_processor = ViTImageProcessor.from_pretrained('andrewbo29/vit-mae-base-formula1')
 model = ViTMAEForPreTraining.from_pretrained('andrewbo29/vit-mae-base-formula1')
 
-imagenet_mean = np.array(feature_extractor.image_mean)
-imagenet_std = np.array(feature_extractor.image_std)
+imagenet_mean = np.array(image_processor.image_mean)
+imagenet_std = np.array(image_processor.image_std)
 
 
 def prep_image(image):
     return torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int().cpu().numpy()
 
 
-def reconstruct(img):
-    image = Image.fromarray(img)
-    pixel_values = feature_extractor(image, return_tensors='pt').pixel_values
+def reconstruct(image):
+    pixel_values = image_processor.preprocess(image, return_tensors='pt').pixel_values
 
     outputs = model(pixel_values)
     y = model.unpatchify(outputs.logits)
@@ -39,75 +35,40 @@ def reconstruct(img):
     # MAE reconstruction pasted with visible patches
     im_paste = x * (1 - mask) + y * mask
 
+    out_orig = prep_image(x[0])
     out_masked = prep_image(im_masked[0])
     out_rec = prep_image(y[0])
     out_rec_vis = prep_image(im_paste[0])
 
-    return [(out_masked, 'masked'), (out_rec, 'reconstruction'), (out_rec_vis, 'reconstruction + visible')]
+    return [(out_orig, 'original'),
+            (out_masked, 'masked'),
+            (out_rec, 'reconstruction'),
+            (out_rec_vis, 'reconstruction + visible')]
 
 
 with gr.Blocks() as demo:
-    with gr.Column(variant="panel"):
-        with gr.Row():
+    with gr.Column(variant='panel'):
+        with gr.Column():
             img = gr.Image(
-                label="Enter your prompt",
                 container=False,
+                type='pil'
+            )
+            btn = gr.Button(
+                'Apply F1 MAE',
+                scale=0
             )
-            btn = gr.Button("Generate image", scale=0)
-
-    # gallery = gr.Gallery(
-    #     label="Generated images", show_label=False, elem_id="gallery"
-    #     , columns=[3], rows=[1], height='auto', container=True)
 
-    gallery = gr.Gallery(columns=3,
-                         rows=1,
-                         height='800px',
-                         object_fit='none')
+    gallery = gr.Gallery(
+        columns=4,
+        rows=1,
+        height='300px',
+        object_fit='none'
+    )
 
     btn.click(reconstruct, img, gallery)
 
 if __name__ == "__main__":
     demo.launch()
 
-# import random
-#
-# import gradio as gr
-#
-#
-# def fake_gan():
-#     images = [
-#         (random.choice(
-#             [
-#                 "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
-#                 "https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80",
-#                 "https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80",
-#                 "https://images.unsplash.com/photo-1546456073-92b9f0a8d413?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
-#                 "https://images.unsplash.com/photo-1601412436009-d964bd02edbc?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=464&q=80",
-#             ]
-#         ), f"label {i}" if i != 0 else "label" * 50)
-#         for i in range(3)
-#     ]
-#     return images
-#
-#
-# with gr.Blocks() as demo:
-#     with gr.Column(variant="panel"):
-#         with gr.Row():
-#             text = gr.Textbox(
-#                 label="Enter your prompt",
-#                 max_lines=1,
-#                 placeholder="Enter your prompt",
-#                 container=False,
-#             )
-#             btn = gr.Button("Generate image", scale=0)
-#
-#     gallery = gr.Gallery(
-#         label="Generated images", show_label=False, elem_id="gallery"
-#         , columns=[2], rows=[2], object_fit="contain", height="auto")
-#
-#     btn.click(fake_gan, None, gallery)
-#
-# if __name__ == "__main__":
-#     demo.launch()
 
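Note: the diff elides lines 22–34 of reconstruct, where mask, x, and im_masked must be built before the pasting step shown above. A minimal sketch of that elided middle, assuming it follows the standard ViT-MAE visualization recipe from the transformers examples (not the author's exact code):

# Assumption: mirrors the stock ViT-MAE demo; the commit never shows these lines.
y = torch.einsum('nchw->nhwc', y).detach().cpu()

# outputs.mask is 1 for masked patches and 0 for visible ones; expand it from
# per-patch to per-pixel so it can multiply whole images.
mask = outputs.mask.detach()
mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size ** 2 * 3)
mask = model.unpatchify(mask)
mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

# Original image in NHWC layout, plus its masked version.
x = torch.einsum('nchw->nhwc', pixel_values)
im_masked = x * (1 - mask)

With these definitions, the visible line im_paste = x * (1 - mask) + y * mask pastes the model's reconstruction into the masked patches only, which is what the 'reconstruction + visible' gallery entry displays.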
requirements.txt CHANGED
@@ -1,4 +1,2 @@
 torch
-transformers
-numpy
-pillow
+transformers
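Note: the trimmed requirements lean on transitive and runtime-provided packages: numpy ships with torch and transformers, pillow comes in with gradio (still needed, since gr.Image now uses type='pil'), and gradio itself is preinstalled by the Hugging Face Spaces runtime. A quick import sanity check, assuming the app runs inside such a Space:

# Assumption: executed in the Space's environment, where gradio is provided by
# the runtime and numpy/pillow arrive as transitive dependencies.
import importlib

for pkg in ('torch', 'transformers', 'numpy', 'PIL', 'gradio'):
    importlib.import_module(pkg)
    print(f'{pkg}: ok')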