import os
import sys
import gradio as gr
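# Fetch and install the CLIP and Deep Image Prior repos when the app starts.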
os.system('git clone https://github.com/openai/CLIP')
os.system('git clone https://github.com/DmitryUlyanov/deep-image-prior')
os.system('pip install -e ./CLIP')
os.system('pip install kornia einops madgrad')
import io
import math
import random
import time
import requests
sys.path.append('./CLIP')
sys.path.append('deep-image-prior')
from einops import rearrange
import gc
import imageio
from IPython import display
import kornia.augmentation as K
from madgrad import MADGRAD
import torch
import torch.optim
import torch.nn as nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import numpy as np
import clip

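# `models` and `utils.sr_utils` come from the deep-image-prior checkout added
# to sys.path above.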
from models import *
from utils.sr_utils import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

clip_model_vit_b_32 = clip.load('ViT-B/32', device=device)[0].eval().requires_grad_(False)
clip_model_vit_b_16 = clip.load('ViT-B/16', device=device)[0].eval().requires_grad_(False)
clip_models = {'ViT-B/32': clip_model_vit_b_32, 'ViT-B/16': clip_model_vit_b_16}

clip_normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
clip_model = 'ViT-B/16'
sideX, sideY = 256, 256  # Resolution
inv_color_scale = 1.6
anneal_lr = True
display_augs = False

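# CLIP takes fixed-resolution inputs, so the generated image is shown to it as
# a batch of randomly augmented square crops ("cutouts"); averaging the loss
# over many cutouts discourages single-crop adversarial artifacts.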
class MakeCutouts(torch.nn.Module):
    def __init__(self, cut_size, cutn):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.augs = T.Compose([
            K.RandomHorizontalFlip(p=0.5),
            K.RandomAffine(degrees=15, translate=0.1, p=0.8, padding_mode='border', resample='bilinear'),
            K.RandomPerspective(0.4, p=0.7, resample='bilinear'),
            K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7),
            K.RandomGrayscale(p=0.15),
        ])

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        if sideY != sideX:
            input = K.RandomAffine(degrees=0, shear=10, p=0.5, padding_mode='border')(input)

        max_size = min(sideX, sideY)
        cutouts = []
        for cn in range(self.cutn):
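            # Use the full frame for roughly the last quarter of cutouts;
            # the rest are random square crops of varying size.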
            if cn > self.cutn - self.cutn//4:
                cutout = input
            else:
                size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))
                offsetx = torch.randint(0, sideX - size + 1, ())
                offsety = torch.randint(0, sideY - size + 1, ())
                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        cutouts = torch.cat(cutouts)
        cutouts = self.augs(cutouts)
        return cutouts

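# Parameterizing the image in a decorrelated color space (a trick from
# Lucid-style feature visualization) tends to make gradient-based image
# optimization better behaved than working directly in RGB.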
class DecorrelatedColorsToRGB(nn.Module):
    """From https://github.com/eps696/aphantasia."""

    def __init__(self, inv_color_scale=1.):
        super().__init__()
        color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]])
        color_correlation_svd_sqrt /= torch.tensor([inv_color_scale, 1., 1.])  # saturate, empirical
        max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max()
        color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
        self.register_buffer('colcorr_t', color_correlation_normalized.T)

    def inverse(self, image):
        colcorr_t_inv = torch.linalg.inv(self.colcorr_t)
        return torch.einsum('nchw,cd->ndhw', image, colcorr_t_inv)

    def forward(self, image):
        return torch.einsum('nchw,cd->ndhw', image, self.colcorr_t)


class CaptureOutput:
    """Captures a layer's output activations using a forward hook."""

    def __init__(self, module):
        self.output = None
        self.handle = module.register_forward_hook(self)

    def __call__(self, module, input, output):
        self.output = output

    def __del__(self):
        self.handle.remove()

    def get_output(self):
        return self.output


class CLIPActivationLoss(nn.Module):
    """Maximizes or minimizes a single neuron's activations."""

    def __init__(self, module, neuron, class_token=False, maximize=True):
        super().__init__()
        self.capture = CaptureOutput(module)
        self.neuron = neuron
        self.class_token = class_token
        self.maximize = maximize

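    # OpenAI CLIP's ViT resblocks use a (tokens, batch, width) layout;
    # token 0 is the class token and tokens 1: are the image patches.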
    def forward(self):
        activations = self.capture.get_output()
        if self.class_token:
            loss = activations[0, :, self.neuron].mean()
        else:
            loss = activations[1:, :, self.neuron].mean()
        return -loss if self.maximize else loss


def optimize_network(
    seed, 
    opt_type, 
    lr, 
    num_iterations, 
    cutn, 
    layer,
    neuron,
    class_token,
    maximize,
    display_rate,
    video_writer
):
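    """Fit a Deep Image Prior network so its output drives the selected CLIP neuron."""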
    global itt
    itt = 0

    # Seed the RNGs so runs are reproducible (the caller passes a fixed seed).
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)
    save_progress_video = True
    
    make_cutouts = MakeCutouts(clip_models[clip_model].visual.input_resolution, cutn)
    loss_fn = CLIPActivationLoss(clip_models[clip_model].visual.transformer.resblocks[layer],
                                 neuron, class_token, maximize)

    # Initialize DIP skip network
    input_depth = 32
    net = get_net(
        input_depth, 'skip',
        pad='reflection',
        skip_n33d=128, skip_n33u=128,
        skip_n11=4, num_scales=6,  # num_scales=6 suits the 256x256 output; larger outputs may want more scales
        upsample_mode='bilinear',
        downsample_mode='lanczos2',
    )

    # Modify DIP to operate in a decorrelated color space
    net = net[:-1]  # remove the sigmoid at the end
    net.add(DecorrelatedColorsToRGB(inv_color_scale))
    net.add(nn.Sigmoid())

    net = net.to(device)

    # Initialize input noise
    net_input = torch.zeros([1, input_depth, sideY, sideX], device=device).normal_().div(10).detach()

    if opt_type == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(), lr)
    elif opt_type == 'MADGRAD':
        optimizer = MADGRAD(net.parameters(), lr, momentum=0.9)
    scaler = torch.cuda.amp.GradScaler()

    try:
        for _ in range(num_iterations):
            optimizer.zero_grad(set_to_none=True)
    
            with torch.cuda.amp.autocast():
                out = net(net_input).float()
            cutouts = make_cutouts(out)
            # encode_image() is called for its side effect: the forward pass fires
            # the hook in CLIPActivationLoss, capturing the target layer's activations.
            clip_models[clip_model].encode_image(clip_normalize(cutouts))
            loss = loss_fn()

            # Backpropagate through the gradient scaler and step the optimizer
            # (the forward pass ran under autocast, so the loss is scaled to
            # avoid fp16 gradient underflow).
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            itt += 1

            if itt % display_rate == 0 or save_progress_video:
                with torch.inference_mode():
                    image = TF.to_pil_image(out[0].clamp(0, 1))
                    if itt % display_rate == 0:
                        display.clear_output(wait=True)
                        display.display(image)
                        if display_augs:
                            aug_grid = torchvision.utils.make_grid(cutouts, nrow=math.ceil(math.sqrt(cutn)))
                            display.display(TF.to_pil_image(aug_grid.clamp(0, 1)))
                    if save_progress_video:
                        video_writer.append_data(np.asarray(image))

            if anneal_lr:
                optimizer.param_groups[0]['lr'] = max(0.00001, .99 * optimizer.param_groups[0]['lr'])

            print(f'Iteration {itt} of {num_iterations}, loss: {loss.item():g}')
    
    except KeyboardInterrupt:
        pass

    return TF.to_pil_image(net(net_input)[0])

def inference(
    lr, 
    num_iterations, 
    cutn, 
    layer,
    neuron,
    class_token,
    maximize,
    display_rate = 20
):
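    """Gradio handler: coerce the numeric widget values to ints and run the optimization."""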
    layer = int(layer)
    cutn = int(cutn)
    num_iterations = int(num_iterations)
    neuron = int(neuron)
    display_rate = int(display_rate)

    opt_type = 'MADGRAD'
    seed = 20
    save_progress_video = True
    timestring = time.strftime('%Y%m%d%H%M%S')
    if save_progress_video:
        video_writer = imageio.get_writer('video.mp4', fps=10)
    # Begin optimization / generation
    gc.collect()
    # torch.cuda.empty_cache()
    out = optimize_network(
            seed, 
            opt_type, 
            lr, 
            num_iterations, 
            cutn, 
            layer,
            neuron,
            class_token,
            maximize,
            display_rate,
            video_writer = video_writer
        )

    # out.save(f'dip_{timestring}.png', quality=100)
    if save_progress_video:
        video_writer.close()
    return out, 'video.mp4'

iface = gr.Interface(fn=inference, 
    inputs=[
        gr.inputs.Number(default=1e-3, label="learning rate"),
        gr.inputs.Number(default=50, label="number of iterations (more is better)"),
        gr.inputs.Number(default=32, label="cutn (number of cuts)"),
        gr.inputs.Number(default=10, label="layer"),
        gr.inputs.Number(default=1e-3, label="neuron"),
        gr.inputs.Checkbox(default=False, label="class_token"),
        gr.inputs.Checkbox(default=True, label="maximize"),
        gr.inputs.Slider(minimum=0, maximum=30, default=10, label='display rate'),
        ], 
    outputs=["image","video"]).launch()