tennant committed on
Commit
35d1065
1 Parent(s): 1ceff34
Files changed (2) hide show
  1. app.py +11 -7
  2. cat.jpeg +0 -0
app.py CHANGED
@@ -25,12 +25,12 @@ if torch.cuda.is_available():
25
  model.eval()
26
 
27
  @torch.no_grad()
28
- def visual_recon(x, model):
29
  target = model.patchify(x)
30
  mean = target.mean(dim=-1, keepdim=True)
31
  var = target.var(dim=-1, keepdim=True)
32
 
33
- latent, mask, ids_restore, _ = model.forward_encoder(x, mask_ratio=0.75)
34
  y, _ = model.forward_decoder(latent, ids_restore)
35
  y = y * (var + 1.e-6)**.5 + mean
36
  y = model.unpatchify(y)
@@ -82,7 +82,7 @@ def caption(max_len, latent, model, tokenizer, prefix='a photo of a'):
82
  return ' '.join(words)
83
 
84
 
85
- def gr_caption(x):
86
  imagenet_mean = np.array([0.485, 0.456, 0.406])
87
  imagenet_std = np.array([0.229, 0.224, 0.225])
88
  x = np.array(x) / 255.
@@ -100,8 +100,8 @@ def gr_caption(x):
100
  img = img * imagenet_std + imagenet_mean
101
  return np.clip(img, a_min=0., a_max=1.)
102
 
103
- masked, masked_recon, recon, latent = visual_recon(x, model)
104
- caption_from_model = caption(20, latent, model, tokenizer, )
105
 
106
  masked, masked_recon, recon = map(unnorm_pix, (masked, masked_recon, recon))
107
  return_img = np.concatenate([masked, masked_recon, recon], axis=1)
@@ -111,8 +111,12 @@ def gr_caption(x):
111
  import gradio as gr
112
 
113
  demo = gr.Interface(gr_caption,
114
- inputs=[gr.Image(shape=(224, 224))],
 
 
 
115
  outputs=[gr.Image(shape=(224, 224 * 3)),
116
- 'text'])
 
117
  demo.launch()
118
 
25
  model.eval()
26
 
27
  @torch.no_grad()
28
+ def visual_recon(x, model, mask_ratio=0.75):
29
  target = model.patchify(x)
30
  mean = target.mean(dim=-1, keepdim=True)
31
  var = target.var(dim=-1, keepdim=True)
32
 
33
+ latent, mask, ids_restore, _ = model.forward_encoder(x, mask_ratio=mask_ratio)
34
  y, _ = model.forward_decoder(latent, ids_restore)
35
  y = y * (var + 1.e-6)**.5 + mean
36
  y = model.unpatchify(y)
82
  return ' '.join(words)
83
 
84
 
85
+ def gr_caption(x, mask_ratio=0.75, max_len=20, prefix='a'):
86
  imagenet_mean = np.array([0.485, 0.456, 0.406])
87
  imagenet_std = np.array([0.229, 0.224, 0.225])
88
  x = np.array(x) / 255.
100
  img = img * imagenet_std + imagenet_mean
101
  return np.clip(img, a_min=0., a_max=1.)
102
 
103
+ masked, masked_recon, recon, latent = visual_recon(x, model, mask_ratio=mask_ratio)
104
+ caption_from_model = caption(max_len, latent, model, tokenizer, prefix=prefix)
105
 
106
  masked, masked_recon, recon = map(unnorm_pix, (masked, masked_recon, recon))
107
  return_img = np.concatenate([masked, masked_recon, recon], axis=1)
111
  import gradio as gr
112
 
113
  demo = gr.Interface(gr_caption,
114
+ inputs=[gr.Image(shape=(224, 224)),
115
+ 'number',
116
+ 'number',
117
+ 'text'],
118
  outputs=[gr.Image(shape=(224, 224 * 3)),
119
+ 'text'],
120
+ examples=[['cat.jpeg', 0.75, 20, 'a photo of a']],)
121
  demo.launch()
122
 
cat.jpeg ADDED