Andrei Boiarov commited on
Commit
7440015
1 Parent(s): 688eb37

Update app file

Browse files
Files changed (1) hide show
  1. app.py +27 -33
app.py CHANGED
@@ -18,38 +18,32 @@ def prep_image(image):
18
 
19
 
20
  def reconstruct(img):
21
- # image = Image.fromarray(img)
22
- # pixel_values = feature_extractor(image, return_tensors='pt').pixel_values
23
- #
24
- # outputs = model(pixel_values)
25
- # y = model.unpatchify(outputs.logits)
26
- # y = torch.einsum('nchw->nhwc', y).detach().cpu()
27
- #
28
- # # visualize the mask
29
- # mask = outputs.mask.detach()
30
- # mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size ** 2 * 3) # (N, H*W, p*p*3)
31
- # mask = model.unpatchify(mask) # 1 is removing, 0 is keeping
32
- # mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
33
- #
34
- # x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()
35
- #
36
- # # masked image
37
- # im_masked = x * (1 - mask)
38
- #
39
- # # MAE reconstruction pasted with visible patches
40
- # im_paste = x * (1 - mask) + y * mask
41
- #
42
- # out_masked = prep_image(im_masked[0])
43
- # out_rec = prep_image(y[0])
44
- # out_rec_vis = prep_image(im_paste[0])
45
- #
46
- # # out_masked, out_rec, out_rec_vis = img, img, img
47
- #
48
- # return [(out_masked, 'masked'), (out_rec, 'reconstruction'), (out_rec_vis, 'reconstruction + visible')]
49
- url = "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80"
50
- im = Image.open(requests.get(url, stream=True).raw)
51
- im = np.array(im)
52
- return [(im, 'label 1')]
53
 
54
 
55
  with gr.Blocks() as demo:
@@ -65,7 +59,7 @@ with gr.Blocks() as demo:
65
  # label="Generated images", show_label=False, elem_id="gallery"
66
  # , columns=[3], rows=[1], height='auto', container=True)
67
 
68
- gallery = gr.Gallery(columns=1,
69
  rows=1,
70
  height='800px',
71
  object_fit='none')
 
18
 
19
 
20
  def reconstruct(img):
21
+ image = Image.fromarray(img)
22
+ pixel_values = feature_extractor(image, return_tensors='pt').pixel_values
23
+
24
+ outputs = model(pixel_values)
25
+ y = model.unpatchify(outputs.logits)
26
+ y = torch.einsum('nchw->nhwc', y).detach().cpu()
27
+
28
+ # visualize the mask
29
+ mask = outputs.mask.detach()
30
+ mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size ** 2 * 3) # (N, H*W, p*p*3)
31
+ mask = model.unpatchify(mask) # 1 is removing, 0 is keeping
32
+ mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
33
+
34
+ x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()
35
+
36
+ # masked image
37
+ im_masked = x * (1 - mask)
38
+
39
+ # MAE reconstruction pasted with visible patches
40
+ im_paste = x * (1 - mask) + y * mask
41
+
42
+ out_masked = prep_image(im_masked[0])
43
+ out_rec = prep_image(y[0])
44
+ out_rec_vis = prep_image(im_paste[0])
45
+
46
+ return [(out_masked, 'masked'), (out_rec, 'reconstruction'), (out_rec_vis, 'reconstruction + visible')]
 
 
 
 
 
 
47
 
48
 
49
  with gr.Blocks() as demo:
 
59
  # label="Generated images", show_label=False, elem_id="gallery"
60
  # , columns=[3], rows=[1], height='auto', container=True)
61
 
62
+ gallery = gr.Gallery(columns=3,
63
  rows=1,
64
  height='800px',
65
  object_fit='none')