Spaces:
Sleeping
Sleeping
Andrei Boiarov
commited on
Commit
•
7440015
1
Parent(s):
688eb37
Update app file
Browse files
app.py
CHANGED
@@ -18,38 +18,32 @@ def prep_image(image):
|
|
18 |
|
19 |
|
20 |
def reconstruct(img):
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
#
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
#
|
37 |
-
|
38 |
-
|
39 |
-
#
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
#
|
48 |
-
# return [(out_masked, 'masked'), (out_rec, 'reconstruction'), (out_rec_vis, 'reconstruction + visible')]
|
49 |
-
url = "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80"
|
50 |
-
im = Image.open(requests.get(url, stream=True).raw)
|
51 |
-
im = np.array(im)
|
52 |
-
return [(im, 'label 1')]
|
53 |
|
54 |
|
55 |
with gr.Blocks() as demo:
|
@@ -65,7 +59,7 @@ with gr.Blocks() as demo:
|
|
65 |
# label="Generated images", show_label=False, elem_id="gallery"
|
66 |
# , columns=[3], rows=[1], height='auto', container=True)
|
67 |
|
68 |
-
gallery = gr.Gallery(columns=
|
69 |
rows=1,
|
70 |
height='800px',
|
71 |
object_fit='none')
|
|
|
18 |
|
19 |
|
20 |
def reconstruct(img):
|
21 |
+
image = Image.fromarray(img)
|
22 |
+
pixel_values = feature_extractor(image, return_tensors='pt').pixel_values
|
23 |
+
|
24 |
+
outputs = model(pixel_values)
|
25 |
+
y = model.unpatchify(outputs.logits)
|
26 |
+
y = torch.einsum('nchw->nhwc', y).detach().cpu()
|
27 |
+
|
28 |
+
# visualize the mask
|
29 |
+
mask = outputs.mask.detach()
|
30 |
+
mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size ** 2 * 3) # (N, H*W, p*p*3)
|
31 |
+
mask = model.unpatchify(mask) # 1 is removing, 0 is keeping
|
32 |
+
mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
|
33 |
+
|
34 |
+
x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()
|
35 |
+
|
36 |
+
# masked image
|
37 |
+
im_masked = x * (1 - mask)
|
38 |
+
|
39 |
+
# MAE reconstruction pasted with visible patches
|
40 |
+
im_paste = x * (1 - mask) + y * mask
|
41 |
+
|
42 |
+
out_masked = prep_image(im_masked[0])
|
43 |
+
out_rec = prep_image(y[0])
|
44 |
+
out_rec_vis = prep_image(im_paste[0])
|
45 |
+
|
46 |
+
return [(out_masked, 'masked'), (out_rec, 'reconstruction'), (out_rec_vis, 'reconstruction + visible')]
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
|
49 |
with gr.Blocks() as demo:
|
|
|
59 |
# label="Generated images", show_label=False, elem_id="gallery"
|
60 |
# , columns=[3], rows=[1], height='auto', container=True)
|
61 |
|
62 |
+
gallery = gr.Gallery(columns=3,
|
63 |
rows=1,
|
64 |
height='800px',
|
65 |
object_fit='none')
|