File size: 6,084 Bytes
2773b59
 
 
 
 
 
 
 
ec81c7a
 
2773b59
cff97d1
2773b59
cff97d1
ec81c7a
2773b59
 
 
 
 
 
ec81c7a
6d1b898
2773b59
ec81c7a
2773b59
ec81c7a
2773b59
ec81c7a
2773b59
ec81c7a
2773b59
ec81c7a
2773b59
 
ec81c7a
2773b59
 
ec81c7a
2773b59
 
ec81c7a
2773b59
 
 
ec81c7a
2773b59
ec81c7a
2773b59
 
ec81c7a
2773b59
 
ec81c7a
2773b59
 
ec81c7a
2773b59
ec81c7a
2773b59
 
 
 
ec81c7a
2773b59
ec81c7a
2773b59
ec81c7a
2773b59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec81c7a
2773b59
 
ec81c7a
2773b59
6d1b898
2773b59
 
ec81c7a
2773b59
ec81c7a
2773b59
2f08c6b
2773b59
 
 
ec81c7a
2773b59
 
ec81c7a
2773b59
 
 
 
ec81c7a
 
2773b59
 
 
cff97d1
2773b59
 
 
 
 
 
 
 
 
 
 
 
 
 
cff97d1
2773b59
 
 
cff97d1
2773b59
ec81c7a
2773b59
 
 
ec81c7a
2773b59
 
 
 
cff97d1
2773b59
 
 
 
6d1b898
2773b59
 
cff97d1
2773b59
cff97d1
2773b59
cff97d1
 
2773b59
 
 
ea753ea
4912cf9
2773b59
 
cff97d1
 
 
ec81c7a
 
2773b59
4912cf9
7ec5667
 
 
 
 
 
 
 
cff97d1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import torch
import torch.nn as nn
import numpy as np
import json
import captioning.utils.opts as opts
import captioning.models as models
import captioning.utils.misc as utils
import pytorch_lightning as pl
import gradio as gr

from diffusers import LDMTextToImagePipeline
import random
import os


# Checkpoint class
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
    """Lightning ModelCheckpoint that also saves when training is interrupted.

    Not referenced elsewhere in this file.
    """

    def on_keyboard_interrupt(self, trainer, pl_module):
        """Save ``<dirpath>/<prefix>interrupt.ckpt`` when training is stopped with Ctrl-C.

        Args:
            trainer: the running Lightning ``Trainer``; used to write the checkpoint.
            pl_module: the LightningModule being trained (unused; part of the
                hook signature).

        NOTE(review): ``on_keyboard_interrupt`` is only invoked by older Lightning
        releases (later ones use ``on_exception``); confirm against the pinned
        version. Assumes ``self.dirpath`` is set (``os.path.join`` fails on None).
        """
        # `prefix` is not an attribute of ModelCheckpoint in newer Lightning
        # releases; fall back to no prefix instead of raising AttributeError.
        prefix = getattr(self, 'prefix', '') or ''
        filepath = os.path.join(self.dirpath, prefix + 'interrupt.ckpt')
        # Trainer.save_checkpoint is the public API; the private _save_model
        # helper takes different arguments in different Lightning versions.
        trainer.save_checkpoint(filepath)
        
# Run configuration: inference on CPU, using the phase-2 CLIP-RN50 config that
# matches the selected reward variant.
device = 'cpu'
reward = 'clips_grammar'

cfg = './configs/phase2/clipRN50_{}.yml'.format(reward)
print("Loading cfg from", cfg)

# parse=False presumably builds the options from the YAML file only, without
# reading command-line arguments — TODO confirm against captioning.utils.opts.
opt = opts.parse_opt(parse=False, cfg=cfg)

import gdown

# Fetch a Google Drive folder into save/ (presumably the pretrained captioning
# checkpoint later read via opt.checkpoint_path — TODO confirm).
url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
gdown.download_folder(url, output="save/", quiet=True, use_cookies=False)

# Fetch a single file into data/ (presumably ./data/cocotalk.json, the
# vocabulary file opened next — TODO confirm).
url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
gdown.download(url, output="data/", quiet=True, use_cookies=False)

# Load the vocabulary file; a context manager closes the handle instead of
# leaking it, and JSON is read as UTF-8 regardless of the platform locale.
with open('./data/cocotalk.json', encoding='utf-8') as vocab_file:
    dict_json = json.load(vocab_file)
print(dict_json.keys())

# ix_to_word maps token indices to words; its size is the model's vocab size.
ix_to_word = dict_json['ix_to_word']
vocab_size = len(ix_to_word)
print('vocab size:', vocab_size)

# NOTE(review): seq_length is fixed to 1 here; how captioning.models uses it
# during sampling is not visible in this file — TODO confirm.
seq_length = 1

opt.vocab_size = vocab_size
opt.seq_length = seq_length

# One image per forward pass in this demo.
opt.batch_size = 1
# models.setup receives the vocab through opt; it is removed again right after
# (presumably so it is not carried along in vars(opt) later) — TODO confirm.
opt.vocab = ix_to_word

model = models.setup(opt)
del opt.vocab

# Checkpoint file: opt.checkpoint_path (from the YAML config) + '-last.ckpt'.
ckpt_path = opt.checkpoint_path + '-last.ckpt'

print("Loading checkpoint from", ckpt_path)
# NOTE(review): torch.load unpickles the file, which was downloaded from a
# remote Google Drive folder above; only load checkpoints from a trusted source.
raw_state_dict = torch.load(
    ckpt_path,
    map_location=device)

# When True, the checkpoint must contain the serialized '_vocab' and '_opt'
# entries, and load_state_dict requires an exact parameter match.
strict = True

state_dict = raw_state_dict['state_dict']

# Restore the vocabulary stored alongside the weights, then remove the entry so
# it is not treated as a model parameter.
if '_vocab' in state_dict:
    model.vocab = utils.deserialize(state_dict['_vocab'])
    del state_dict['_vocab']
elif strict:
    raise KeyError("'_vocab' entry missing from checkpoint state_dict: " + ckpt_path)

if '_opt' in state_dict:
    saved_model_opt = utils.deserialize(state_dict['_opt'])
    del state_dict['_opt']
    # Make sure the saved opt is compatible with the current opt.
    need_be_same = ["caption_model",
                    "rnn_type", "rnn_size", "num_layers"]
    for checkme in need_be_same:
        saved_value = getattr(saved_model_opt, checkme)
        current_value = getattr(opt, checkme)
        # 'updown' and 'topdown' are accepted as compatible with each other.
        if saved_value in ['updown', 'topdown'] and current_value in ['updown', 'topdown']:
            continue
        # Explicit raise instead of assert so the check survives `python -O`.
        if saved_value != current_value:
            raise ValueError(
                "Command line argument and saved model disagree on '%s' " % checkme)
elif strict:
    raise KeyError("'_opt' entry missing from checkpoint state_dict: " + ckpt_path)

res = model.load_state_dict(state_dict, strict)
print(res)

model = model.to(device)
model.eval()

import clip
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor
from PIL import Image
from timm.models.vision_transformer import resize_pos_embed

# CLIP RN50 image encoder, loaded without JIT so its attributes can be replaced
# below. clip_transform is not used: the custom `preprocess` plus the manual
# normalization with image_mean/image_std are applied instead.
clip_model, clip_transform = clip.load("RN50", jit=False, device=device)

# Resize straight to 448x448 (the CenterCrop is then a no-op safeguard) and
# convert to a float tensor; normalization is applied separately.
preprocess = Compose([
    Resize((448, 448), interpolation=Image.Resampling.BICUBIC),
    CenterCrop((448, 448)),
    ToTensor()
])

# Per-channel mean/std, shaped (3, 1, 1) to broadcast over a (3, H, W) image.
image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)

# A 448x448 input at stride 32 gives a 14x14 = 196-cell feature grid.
num_patches = (448 // 32) ** 2

# Build a positional embedding for num_patches + 1 tokens and fill it with the
# pretrained attention-pool embedding resized to the 14x14 grid.
pos_embed = nn.Parameter(torch.zeros(
    1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1],
    device=device))
with torch.no_grad():
    # Copy into the parameter itself. Assigning the result to an ad-hoc
    # `.weight` attribute (which nothing reads) would leave it all zeros.
    pos_embed.copy_(resize_pos_embed(
        clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed))
clip_model.visual.attnpool.positional_embedding = pos_embed


# Latent-diffusion text-to-image pipeline used by generate_image_from_text.
LDM_MODEL_ID = "CompVis/ldm-text2im-large-256"
print('Loading the model: ' + LDM_MODEL_ID)
ldm_pipeline = LDMTextToImagePipeline.from_pretrained(LDM_MODEL_ID)

def generate_image_from_text(prompt, steps=100, seed=42, guidance_scale=6.0):
    """Render *prompt* with the latent-diffusion pipeline and return the first image.

    Args:
        prompt: text prompt (a single string; wrapped in a one-element batch).
        steps: passed to the pipeline as ``num_inference_steps``.
        seed: seeds torch's global RNG; the returned generator drives sampling.
        guidance_scale: guidance strength passed through to the pipeline.

    Returns:
        The first entry of the pipeline output's ``"sample"`` list.
    """
    print('RUN: generate_image_from_text')
    torch.cuda.empty_cache()
    output = ldm_pipeline(
        [prompt],
        generator=torch.manual_seed(seed),
        num_inference_steps=steps,
        eta=0.3,
        guidance_scale=guidance_scale,
    )
    return output["sample"][0]

def generate_text_from_image(img):
    """Caption an image with the CLIP-feature captioning model.

    The image is resized to 448x448, normalized with ``image_mean`` /
    ``image_std`` and encoded with ``clip_model.encode_image``. The spatial
    feature map is flattened to a (1, 196, 2048) attention-feature tensor and
    fed to ``model`` in ``'sample'`` mode with an empty fc-feature tensor.

    Args:
        img: input PIL image (as delivered by ``gr.Image(type="pil")``).
            Converted to RGB first, so grayscale/RGBA inputs are accepted.

    Returns:
        str: the decoded caption for the image.

    Raises:
        ValueError: if no image was provided.
    """
    print('RUN: generate_text_from_image')
    if img is None:
        raise ValueError("No input image provided")

    # Inference configurations; sample_n=1 keeps a single caption per image.
    eval_kwargs = dict(vars(opt))
    eval_kwargs['sample_n'] = 1

    with torch.no_grad():
        # Force 3 channels: the in-place per-channel normalization below cannot
        # broadcast a 1- or 4-channel tensor against the (3, 1, 1) mean/std.
        image = preprocess(img.convert('RGB')).unsqueeze(0).to(device)
        image -= image_mean
        image /= image_std

        # Only the spatial feature map is used; the pooled output is ignored
        # because fc_feats below is an empty placeholder.
        att_map, _ = clip_model.encode_image(image)
        # (C, H, W) -> (H, W, C) -> (1, H*W, C) = (1, 196, 2048) for 448x448 input.
        att_feats = att_map[0].permute(1, 2, 0).view(1, 196, 2048).float().to(device)

        fc_feats = torch.zeros((1, 0)).to(device)
        att_masks = None

        seq, _ = model(
            fc_feats, att_feats, att_masks, opt=eval_kwargs, mode='sample')
        sents = utils.decode_sequence(model.vocab, seq.data)

    return sents[0]


def generate_drawing_from_image(img, steps=100, seed=42, guidance_scale=6.0):
    """Caption *img*, then render that caption as "a kid's drawing of ...".

    Args:
        img: input image, forwarded to generate_text_from_image.
        steps, seed, guidance_scale: forwarded to generate_image_from_text.

    Returns:
        The image produced by generate_image_from_text for the prefixed caption.
    """
    print('RUN: generate_drawing_from_image')
    prompt = "a kid's drawing of " + generate_text_from_image(img)
    print('\tcaption: ' + prompt)
    return generate_image_from_text(
        prompt, steps=steps, seed=seed, guidance_scale=guidance_scale)


# Default value for the Seed slider. It is drawn once at startup, so it stays
# the same for every request until the app restarts.
random_seed = random.randint(0, 2147483647)

# Build and launch the web UI: image in, "kid's drawing" image out.
gr.Interface(
    generate_drawing_from_image,
    title='Reimagine the same image but drawn by a kid :)',
    inputs=[
        gr.Image(type="pil"),
        # gr.Slider replaces the deprecated gr.inputs.Slider; its `value`
        # argument replaces the old `default`.
        gr.Slider(minimum=1, maximum=100, value=50, step=1,
                  label='Inference Steps'),
        gr.Slider(minimum=0, maximum=2147483647, value=random_seed, step=1,
                  label='Seed'),
        gr.Slider(minimum=1.0, maximum=20.0, value=6.0, step=0.1,
                  label='Guidance Scale - how much the prompt will influence the results'),
    ],
    outputs=gr.Image(shape=[256, 256], type="pil", elem_id="output_image"),
    css="#output_image{width: 256px}",
).launch()