akhaliq (HF staff) committed on
Commit 40ad524
1 Parent(s): c80917c

Create new file

Files changed (1)
  1. app.py +193 -0
app.py ADDED
@@ -0,0 +1,193 @@
+ import os
+ import torch
+ import torch.nn as nn
+
+ import numpy as np
+
+ import json
+
+ import captioning.utils.opts as opts
+ import captioning.models as models
+ import captioning.utils.misc as utils
+
+ import pytorch_lightning as pl
+
+ import gradio as gr
+
+
+ # Checkpoint class
+ class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
+
+     def on_keyboard_interrupt(self, trainer, pl_module):
+         # Save model when keyboard interrupt
+         filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
+         self._save_model(filepath)
+
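+ # The "#@param" lines below are Colab form-field annotations; in this script they are plain comments.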
+ device = 'cpu' #@param ["cuda", "cpu"] {allow-input: true}
+
+ reward = 'clips_grammar' #@param ["mle", "cider", "clips", "cider_clips", "clips_grammar"] {allow-input: true}
+
+ if reward == 'mle':
+     cfg = f'./configs/phase1/clipRN50_{reward}.yml'
+ else:
+     cfg = f'./configs/phase2/clipRN50_{reward}.yml'
+
+ print("Loading cfg from", cfg)
+
+ opt = opts.parse_opt(parse=False, cfg=cfg)
+
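+ # Download the pretrained captioning checkpoint for the selected reward into save/.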
+ import gdown
+
+ if reward == "mle":
+     url = "https://drive.google.com/drive/folders/1hfHWDn5iXsdjB63E5zdZBAoRLWHQC3LD"
+ elif reward == "cider":
+     url = "https://drive.google.com/drive/folders/1MnSmCd8HFnBvQq_4K-q4vsVkzEw0OIOs"
+ elif reward == "clips":
+     url = "https://drive.google.com/drive/folders/1toceycN-qilHsbYjKalBLtHJck1acQVe"
+ elif reward == "cider_clips":
+     url = "https://drive.google.com/drive/folders/1toceycN-qilHsbYjKalBLtHJck1acQVe"
+ elif reward == "clips_grammar":
+     url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
+ gdown.download_folder(url, quiet=True, use_cookies=False, output="save/")
+
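+ # Fetch the cocotalk vocabulary json (ix_to_word) into data/.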
+ url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
+ gdown.download(url, quiet=True, use_cookies=False, output="data/")
+
+ dict_json = json.load(open('./data/cocotalk.json'))
+ print(dict_json.keys())
+
+ ix_to_word = dict_json['ix_to_word']
+ vocab_size = len(ix_to_word)
+ print('vocab size:', vocab_size)
+
+ seq_length = 1
+
+ opt.vocab_size = vocab_size
+ opt.seq_length = seq_length
+
+ opt.batch_size = 1
+ opt.vocab = ix_to_word
+ # opt.use_grammar = False
+
+ model = models.setup(opt)
+ del opt.vocab
+
+ ckpt_path = opt.checkpoint_path + '-last.ckpt'
+
+ print("Loading checkpoint from", ckpt_path)
+ raw_state_dict = torch.load(
+     ckpt_path,
+     map_location=device)
+
+ strict = True
+
+ state_dict = raw_state_dict['state_dict']
+
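+ # The checkpoint's state_dict also carries a serialized vocab and the training opts;
+ # strip them out (and sanity-check the opts) before load_state_dict.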
+ if '_vocab' in state_dict:
+     model.vocab = utils.deserialize(state_dict['_vocab'])
+     del state_dict['_vocab']
+ elif strict:
+     raise KeyError
+ if '_opt' in state_dict:
+     saved_model_opt = utils.deserialize(state_dict['_opt'])
+     del state_dict['_opt']
+     # Make sure the saved opt is compatible with the current opt
+     need_be_same = ["caption_model",
+                     "rnn_type", "rnn_size", "num_layers"]
+     for checkme in need_be_same:
+         if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
+                 getattr(opt, checkme) in ['updown', 'topdown']:
+             continue
+         assert getattr(saved_model_opt, checkme) == getattr(
+             opt, checkme), "Command line argument and saved model disagree on '%s'" % checkme
+ elif strict:
+     raise KeyError
+ res = model.load_state_dict(state_dict, strict)
+ print(res)
+
+ model = model.to(device)
+ model.eval()
+
+ import clip
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+ from PIL import Image
+ from timm.models.vision_transformer import resize_pos_embed
+
+ clip_model, clip_transform = clip.load("RN50", jit=False, device=device)
+
+ preprocess = Compose([
+     Resize((448, 448), interpolation=Image.BICUBIC),
+     CenterCrop((448, 448)),
+     ToTensor()
+ ])
+
+ image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
+ image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)
+
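+ # CLIP RN50 downsamples by 32x, so a 448x448 input gives a 14x14 = 196-cell grid;
+ # resize the attention-pool positional embedding to match.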
+ num_patches = 196 # 448 * 448 // 32 // 32
+ pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1], device=device),)
+ pos_embed.weight = resize_pos_embed(clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
+ clip_model.visual.attnpool.positional_embedding = pos_embed
+ def inference(img):
+
+     with torch.no_grad():
+         image = preprocess(img)
+         image = torch.tensor(np.stack([image])).to(device)
+         image -= image_mean
+         image /= image_std
+
+         tmp_att, tmp_fc = clip_model.encode_image(image)
+         tmp_att = tmp_att[0].permute(1, 2, 0)
+         tmp_fc = tmp_fc[0]
+
+         att_feat = tmp_att
+         fc_feat = tmp_fc
+
+
+     # Inference configurations
+     eval_kwargs = {}
+     eval_kwargs.update(vars(opt))
+
+     verbose = eval_kwargs.get('verbose', True)
+     verbose_beam = eval_kwargs.get('verbose_beam', 0)
+     verbose_loss = eval_kwargs.get('verbose_loss', 1)
+
+     # dataset = eval_kwargs.get('dataset', 'coco')
+     beam_size = eval_kwargs.get('beam_size', 1)
+     sample_n = eval_kwargs.get('sample_n', 1)
+     remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
+
+     with torch.no_grad():
+         fc_feats = torch.zeros((1, 0)).to(device)
+         att_feats = att_feat.view(1, 196, 2048).float().to(device)
+         att_masks = None
+
+         # forward the model to also get generated samples for each image
+         # Only leave one feature for each image, in case duplicate sample
+         tmp_eval_kwargs = eval_kwargs.copy()
+         tmp_eval_kwargs.update({'sample_n': 1})
+         seq, seq_logprobs = model(
+             fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
+         seq = seq.data
+
+         sents = utils.decode_sequence(model.vocab, seq)
+
+     return sents
+
+ demo = gr.Blocks()
+
+ with demo:
+     gr.Markdown(
+         """
+         # Demo for CLIP-Caption-Reward
+         """)
+     inp = gr.Image(type="pil")
+     out = gr.Textbox()
+
+     image_button = gr.Button("Run")
+     image_button.click(fn=inference,
+                        inputs=inp,
+                        outputs=out)
+
+
+ demo.launch()
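
For a quick sanity check outside the Gradio UI, the inference function above can be called directly once the module-level setup has run. A minimal sketch, assuming a local example.jpg (the file name is illustrative and not part of this commit):

    from PIL import Image

    img = Image.open("example.jpg").convert("RGB")
    print(inference(img))  # expected: a list containing the sampled caption string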