milyiyo committed
Commit ec81c7a
1 Parent(s): e5a3796

Update app.py

Files changed (1):
  1. app.py +171 -0
app.py CHANGED
@@ -0,0 +1,171 @@
+ import os
+ import json
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import pytorch_lightning as pl
+ import gradio as gr
+ import gdown
+
+ import captioning.utils.opts as opts
+ import captioning.models as models
+ import captioning.utils.misc as utils
+
+
+ # Checkpoint callback that also saves a checkpoint on keyboard interrupt
+ class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
+     def on_keyboard_interrupt(self, trainer, pl_module):
+         # Save the model when training is interrupted from the keyboard
+         filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
+         self._save_model(filepath)
+
+ device = 'cpu'  # or 'cuda'
+ reward = 'clips_grammar'
+
+ cfg = f'./configs/phase2/clipRN50_{reward}.yml'
+
+ print("Loading cfg from", cfg)
+
+ opt = opts.parse_opt(parse=False, cfg=cfg)
+
+ # Download the pretrained captioning checkpoint(s) into save/
+ url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
+ gdown.download_folder(url, quiet=True, use_cookies=False, output="save/")
+
+ # Download the preprocessed vocabulary data into data/
+ url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
+ gdown.download(url, quiet=True, use_cookies=False, output="data/")
+
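+ # Note: the rest of this script assumes the downloads above yield ./data/cocotalk.json
+ # and a checkpoint file at opt.checkpoint_path + '-last.ckpt' (checkpoint_path comes
+ # from the cfg); the exact contents of the shared Drive folder are not verified here.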
+ dict_json = json.load(open('./data/cocotalk.json'))
+ print(dict_json.keys())
+
+ ix_to_word = dict_json['ix_to_word']
+ vocab_size = len(ix_to_word)
+ print('vocab size:', vocab_size)
+
+ seq_length = 1
+
+ opt.vocab_size = vocab_size
+ opt.seq_length = seq_length
+
+ opt.batch_size = 1
+ opt.vocab = ix_to_word
+ # opt.use_grammar = False
+
+ model = models.setup(opt)
+ del opt.vocab
+
+ ckpt_path = opt.checkpoint_path + '-last.ckpt'
+
+ print("Loading checkpoint from", ckpt_path)
+ raw_state_dict = torch.load(
+     ckpt_path,
+     map_location=device)
+
+ strict = True
+
+ state_dict = raw_state_dict['state_dict']
+
+ if '_vocab' in state_dict:
+     model.vocab = utils.deserialize(state_dict['_vocab'])
+     del state_dict['_vocab']
+ elif strict:
+     raise KeyError
+ if '_opt' in state_dict:
+     saved_model_opt = utils.deserialize(state_dict['_opt'])
+     del state_dict['_opt']
+     # Make sure the saved opt is compatible with the current opt
+     need_be_same = ["caption_model",
+                     "rnn_type", "rnn_size", "num_layers"]
+     for checkme in need_be_same:
+         if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
+                 getattr(opt, checkme) in ['updown', 'topdown']:
+             continue
+         assert getattr(saved_model_opt, checkme) == getattr(
+             opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
+ elif strict:
+     raise KeyError
+ res = model.load_state_dict(state_dict, strict)
+ print(res)
+
+ model = model.to(device)
+ model.eval()
+
+ # CLIP image encoder used to extract features for the captioning model
+ import clip
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+ from PIL import Image
+ from timm.models.vision_transformer import resize_pos_embed
+
+ clip_model, clip_transform = clip.load("RN50", jit=False, device=device)
+
+ preprocess = Compose([
+     Resize((448, 448), interpolation=Image.BICUBIC),
+     CenterCrop((448, 448)),
+     ToTensor()
+ ])
+
+ # CLIP normalization statistics
+ image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
+ image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)
+
+ # Resize the attention-pool positional embedding so the RN50 backbone accepts
+ # 448x448 inputs (448 // 32 = 14, so 14 * 14 = 196 patches plus the CLS position)
+ num_patches = 196
+ pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1], device=device))
+ pos_embed.weight = resize_pos_embed(clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
+ clip_model.visual.attnpool.positional_embedding = pos_embed
+
+
+ # Inference: generate a caption for an uploaded image
+ def generate_image(img, steps=100, seed=42, guidance_scale=6.0):
+     # Returns a caption string for the input PIL image. The steps/seed/guidance_scale
+     # parameters are exposed as UI sliders but are not used by the captioning model.
+     with torch.no_grad():
+         image = preprocess(img)
+         image = torch.tensor(np.stack([image])).to(device)
+         image -= image_mean
+         image /= image_std
+
+         # Extract CLIP attention (spatial) and fc (global) features
+         tmp_att, tmp_fc = clip_model.encode_image(image)
+         tmp_att = tmp_att[0].permute(1, 2, 0)
+         tmp_fc = tmp_fc[0]
+
+     att_feat = tmp_att
+     fc_feat = tmp_fc
+
+     # Inference configurations
+     eval_kwargs = {}
+     eval_kwargs.update(vars(opt))
+
+     verbose = eval_kwargs.get('verbose', True)
+     verbose_beam = eval_kwargs.get('verbose_beam', 0)
+     verbose_loss = eval_kwargs.get('verbose_loss', 1)
+
+     # dataset = eval_kwargs.get('dataset', 'coco')
+     beam_size = eval_kwargs.get('beam_size', 1)
+     sample_n = eval_kwargs.get('sample_n', 1)
+     remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
+
+     with torch.no_grad():
+         fc_feats = torch.zeros((1, 0)).to(device)
+         att_feats = att_feat.view(1, 196, 2048).float().to(device)
+         att_masks = None
+
+         # Forward the model to get a generated caption for the image.
+         # Only leave one feature for each image, in case of duplicate samples.
+         tmp_eval_kwargs = eval_kwargs.copy()
+         tmp_eval_kwargs.update({'sample_n': 1})
+         seq, seq_logprobs = model(
+             fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
+         seq = seq.data
+
+         sents = utils.decode_sequence(model.vocab, seq)
+
+     return sents[0]
+
+ gr.Interface(
+     generate_image,
+     inputs=[
+         gr.Image(type="pil"),
+         gr.Slider(1, 100, value=50, step=1, label='Inference Steps'),
+         gr.Slider(0, 2147483647, value=42, step=1, label='Seed'),
+         gr.Slider(1.0, 20.0, value=6.0, step=0.1, label='Guidance Scale - how much the prompt will influence the results'),
+     ],
+     outputs=gr.Textbox(label="Generated caption"),
+ ).launch()
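+
+ # A minimal sketch of how the caption function could be exercised without the Gradio UI.
+ # The image path is hypothetical and not part of this Space:
+ #
+ #   from PIL import Image
+ #   print(generate_image(Image.open('example.jpg').convert('RGB')))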