Andrei Boiarov commited on
Commit
581506e
1 Parent(s): aa94b73

Update app filew

Browse files
Files changed (1) hide show
  1. app.py +95 -97
app.py CHANGED
@@ -1,117 +1,115 @@
1
- # from transformers import ViTFeatureExtractor, ViTMAEForPreTraining
2
- # import numpy as np
3
- # import torch
4
- # from PIL import Image
5
- #
6
- # import gradio as gr
7
- #
8
- # feature_extractor = ViTFeatureExtractor.from_pretrained('andrewbo29/vit-mae-base-formula1')
9
- # model = ViTMAEForPreTraining.from_pretrained('andrewbo29/vit-mae-base-formula1')
10
- #
11
- # imagenet_mean = np.array(feature_extractor.image_mean)
12
- # imagenet_std = np.array(feature_extractor.image_std)
13
- #
14
- #
15
- # def prep_image(image):
16
- # return torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int().cpu().numpy()
17
- #
18
- #
19
- # def reconstruct(img):
20
- # image = Image.fromarray(img)
21
- # pixel_values = feature_extractor(image, return_tensors='pt').pixel_values
22
- #
23
- # outputs = model(pixel_values)
24
- # y = model.unpatchify(outputs.logits)
25
- # y = torch.einsum('nchw->nhwc', y).detach().cpu()
26
- #
27
- # # visualize the mask
28
- # mask = outputs.mask.detach()
29
- # mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size ** 2 * 3) # (N, H*W, p*p*3)
30
- # mask = model.unpatchify(mask) # 1 is removing, 0 is keeping
31
- # mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
32
- #
33
- # x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()
34
- #
35
- # # masked image
36
- # im_masked = x * (1 - mask)
37
- #
38
- # # MAE reconstruction pasted with visible patches
39
- # im_paste = x * (1 - mask) + y * mask
40
- #
41
- # out_masked = prep_image(im_masked[0])
42
- # out_rec = prep_image(y[0])
43
- # out_rec_vis = prep_image(im_paste[0])
44
- #
45
- # # out_masked, out_rec, out_rec_vis = img, img, img
46
- #
47
- # return [(out_masked, 'masked'), (out_rec, 'reconstruction'), (out_rec_vis, 'reconstruction + visible')]
48
- # # return [(img, '1')]
49
- #
50
- #
51
- # with gr.Blocks() as demo:
52
- # with gr.Column(variant="panel"):
53
- # with gr.Row():
54
- # img = gr.Image(
55
- # label="Enter your prompt",
56
- # container=False,
57
- # )
58
- # btn = gr.Button("Generate image", scale=0)
59
- #
60
- # # gallery = gr.Gallery(
61
- # # label="Generated images", show_label=False, elem_id="gallery"
62
- # # , columns=[3], rows=[1], height='auto', container=True)
63
- #
64
- # gallery = gr.Gallery(columns=3,
65
- # rows=1,
66
- # height='800px',
67
- # object_fit='none')
68
- #
69
- # btn.click(reconstruct, img, gallery)
70
- #
71
- # if __name__ == "__main__":
72
- # demo.launch()
73
-
74
- # This demo needs to be run from the repo folder.
75
- # python demo/fake_gan/run.py
76
- import random
77
 
78
  import gradio as gr
79
 
 
 
 
 
 
 
 
 
 
80
 
81
- def fake_gan():
82
- images = [
83
- (random.choice(
84
- [
85
- "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
86
- "https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80",
87
- "https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80",
88
- "https://images.unsplash.com/photo-1546456073-92b9f0a8d413?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
89
- "https://images.unsplash.com/photo-1601412436009-d964bd02edbc?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=464&q=80",
90
- ]
91
- ), f"label {i}" if i != 0 else "label" * 50)
92
- for i in range(3)
93
- ]
94
- return images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
97
  with gr.Blocks() as demo:
98
  with gr.Column(variant="panel"):
99
  with gr.Row():
100
- text = gr.Textbox(
101
  label="Enter your prompt",
102
- max_lines=1,
103
- placeholder="Enter your prompt",
104
  container=False,
105
  )
106
  btn = gr.Button("Generate image", scale=0)
107
 
108
- gallery = gr.Gallery(
109
- label="Generated images", show_label=False, elem_id="gallery"
110
- , columns=[2], rows=[2], object_fit="contain", height="auto")
 
 
 
 
 
111
 
112
- btn.click(fake_gan, None, gallery)
113
 
114
  if __name__ == "__main__":
115
  demo.launch()
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
 
1
+ from transformers import ViTFeatureExtractor, ViTMAEForPreTraining
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  import gradio as gr
7
 
8
+ feature_extractor = ViTFeatureExtractor.from_pretrained('andrewbo29/vit-mae-base-formula1')
9
+ model = ViTMAEForPreTraining.from_pretrained('andrewbo29/vit-mae-base-formula1')
10
+
11
+ imagenet_mean = np.array(feature_extractor.image_mean)
12
+ imagenet_std = np.array(feature_extractor.image_std)
13
+
14
+
15
+ def prep_image(image):
16
+ return torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int().cpu().numpy()
17
 
18
+
19
+ def reconstruct(img):
20
+ # image = Image.fromarray(img)
21
+ # pixel_values = feature_extractor(image, return_tensors='pt').pixel_values
22
+ #
23
+ # outputs = model(pixel_values)
24
+ # y = model.unpatchify(outputs.logits)
25
+ # y = torch.einsum('nchw->nhwc', y).detach().cpu()
26
+ #
27
+ # # visualize the mask
28
+ # mask = outputs.mask.detach()
29
+ # mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size ** 2 * 3) # (N, H*W, p*p*3)
30
+ # mask = model.unpatchify(mask) # 1 is removing, 0 is keeping
31
+ # mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
32
+ #
33
+ # x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()
34
+ #
35
+ # # masked image
36
+ # im_masked = x * (1 - mask)
37
+ #
38
+ # # MAE reconstruction pasted with visible patches
39
+ # im_paste = x * (1 - mask) + y * mask
40
+ #
41
+ # out_masked = prep_image(im_masked[0])
42
+ # out_rec = prep_image(y[0])
43
+ # out_rec_vis = prep_image(im_paste[0])
44
+ #
45
+ # # out_masked, out_rec, out_rec_vis = img, img, img
46
+ #
47
+ # return [(out_masked, 'masked'), (out_rec, 'reconstruction'), (out_rec_vis, 'reconstruction + visible')]
48
+ return [(img, 'label 1')]
49
 
50
 
51
  with gr.Blocks() as demo:
52
  with gr.Column(variant="panel"):
53
  with gr.Row():
54
+ img = gr.Image(
55
  label="Enter your prompt",
 
 
56
  container=False,
57
  )
58
  btn = gr.Button("Generate image", scale=0)
59
 
60
+ # gallery = gr.Gallery(
61
+ # label="Generated images", show_label=False, elem_id="gallery"
62
+ # , columns=[3], rows=[1], height='auto', container=True)
63
+
64
+ gallery = gr.Gallery(columns=1,
65
+ rows=1,
66
+ height='800px',
67
+ object_fit='none')
68
 
69
+ btn.click(reconstruct, img, gallery)
70
 
71
  if __name__ == "__main__":
72
  demo.launch()
73
 
74
+ # import random
75
+ #
76
+ # import gradio as gr
77
+ #
78
+ #
79
+ # def fake_gan():
80
+ # images = [
81
+ # (random.choice(
82
+ # [
83
+ # "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
84
+ # "https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80",
85
+ # "https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80",
86
+ # "https://images.unsplash.com/photo-1546456073-92b9f0a8d413?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
87
+ # "https://images.unsplash.com/photo-1601412436009-d964bd02edbc?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=464&q=80",
88
+ # ]
89
+ # ), f"label {i}" if i != 0 else "label" * 50)
90
+ # for i in range(3)
91
+ # ]
92
+ # return images
93
+ #
94
+ #
95
+ # with gr.Blocks() as demo:
96
+ # with gr.Column(variant="panel"):
97
+ # with gr.Row():
98
+ # text = gr.Textbox(
99
+ # label="Enter your prompt",
100
+ # max_lines=1,
101
+ # placeholder="Enter your prompt",
102
+ # container=False,
103
+ # )
104
+ # btn = gr.Button("Generate image", scale=0)
105
+ #
106
+ # gallery = gr.Gallery(
107
+ # label="Generated images", show_label=False, elem_id="gallery"
108
+ # , columns=[2], rows=[2], object_fit="contain", height="auto")
109
+ #
110
+ # btn.click(fake_gan, None, gallery)
111
+ #
112
+ # if __name__ == "__main__":
113
+ # demo.launch()
114
+
115