jcmc committed
Commit 9c80c48
1 Parent(s): a4921fd

Adjusting application file and adding dependencies

Files changed (2)
  1. app.py +260 -4
  2. requirements.txt +11 -0
app.py CHANGED
@@ -1,7 +1,263 @@
- def greet(name):
-     return "Hello " + name + "!!"
- 
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import os
+ import sys
  import gradio as gr
+ # Fetch CLIP and Deep Image Prior at startup; their modules are imported below.
+ os.system('git clone https://github.com/openai/CLIP')
+ os.system('git clone https://github.com/DmitryUlyanov/deep-image-prior')
+ os.system('pip install -e ./CLIP')
+ os.system('pip install kornia einops madgrad')
+ import io
+ import math
+ import sys
+ import random
+ import time
+ import requests
+ sys.path.append('./CLIP')
+ sys.path.append('deep-image-prior')
+ import cv2
+ from einops import rearrange
+ import gc
+ import imageio
+ from IPython import display
+ import kornia.augmentation as K
+ from madgrad import MADGRAD
+ import torch
+ import torch.optim
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import torchvision  # needed for torchvision.utils.make_grid below
+ import torchvision.transforms.functional as TF
+ import torchvision.transforms as T
+ import numpy as np
+ import clip
+
+ from models import *          # from the deep-image-prior repo (provides get_net)
+ from utils.sr_utils import *  # from the deep-image-prior repo
+
+ device = torch.device('cuda')
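+
+ # Module-level settings read by optimize_network() but not exposed in the Gradio UI.
+ # The values below are assumed defaults, not values taken from this commit.
+ sideX, sideY = 512, 512        # output resolution of the generated image
+ inv_color_scale = 1.6          # saturation of the decorrelated color space
+ anneal_lr = True               # decay the learning rate each iteration
+ display_augs = False           # also display the augmented cutouts during optimization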
+
+ # torch.hub.download_url_to_file('https://images.pexels.com/photos/68767/divers-underwater-ocean-swim-68767.jpeg', 'coralreef.jpeg')
+
+ # def fetch(url_or_path):
+ #     if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
+ #         r = requests.get(url_or_path)
+ #         r.raise_for_status()
+ #         fd = io.BytesIO()
+ #         fd.write(r.content)
+ #         fd.seek(0)
+ #         return fd
+ #     return open(url_or_path, 'rb')
+
+ # def parse_prompt(prompt):
+ #     if prompt.startswith('http://') or prompt.startswith('https://'):
+ #         vals = prompt.rsplit(':', 2)
+ #         vals = [vals[0] + ':' + vals[1], *vals[2:]]
+ #     else:
+ #         vals = prompt.rsplit(':', 1)
+ #     vals = vals + ['', '1'][len(vals):]
+ #     return vals[0], float(vals[1])
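+ # Load two frozen CLIP image encoders; the UI's clip_model field selects between them.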
+ clip_model_vit_b_32 = clip.load('ViT-B/32', device=device)[0].eval().requires_grad_(False)
+ clip_model_vit_b_16 = clip.load('ViT-B/16', device=device)[0].eval().requires_grad_(False)
+ clip_models = {'ViT-B/32': clip_model_vit_b_32, 'ViT-B/16': clip_model_vit_b_16}
+
+ clip_normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
+
+
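+ # MakeCutouts: sample random square crops of the synthesized image and pass them through
+ # Kornia augmentations; CLIP scores these cutouts rather than the full frame.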
+ class MakeCutouts(torch.nn.Module):
+     def __init__(self, cut_size, cutn):
+         super().__init__()
+         self.cut_size = cut_size
+         self.cutn = cutn
+         self.augs = T.Compose([
+             K.RandomHorizontalFlip(p=0.5),
+             K.RandomAffine(degrees=15, translate=0.1, p=0.8, padding_mode='border', resample='bilinear'),
+             K.RandomPerspective(0.4, p=0.7, resample='bilinear'),
+             K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7),
+             K.RandomGrayscale(p=0.15),
+         ])
+
+     def forward(self, input):
+         sideY, sideX = input.shape[2:4]
+         if sideY != sideX:
+             input = K.RandomAffine(degrees=0, shear=10, p=0.5, padding_mode='border')(input)
+
+         max_size = min(sideX, sideY)
+         cutouts = []
+         for cn in range(self.cutn):
+             if cn > self.cutn - self.cutn//4:
+                 cutout = input  # keep the last few cutouts as the full (pooled) frame
+             else:
+                 size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))
+                 offsetx = torch.randint(0, sideX - size + 1, ())
+                 offsety = torch.randint(0, sideY - size + 1, ())
+                 cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+             cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+         cutouts = torch.cat(cutouts)
+         cutouts = self.augs(cutouts)
+         return cutouts
+
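+ # DecorrelatedColorsToRGB: maps the generator's decorrelated color channels back to RGB
+ # with a fixed color-correlation matrix (a standard feature-visualization trick).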
+ class DecorrelatedColorsToRGB(nn.Module):
+     """From https://github.com/eps696/aphantasia."""
+
+     def __init__(self, inv_color_scale=1.):
+         super().__init__()
+         color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]])
+         color_correlation_svd_sqrt /= torch.tensor([inv_color_scale, 1., 1.])  # saturate, empirical
+         max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max()
+         color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
+         self.register_buffer('colcorr_t', color_correlation_normalized.T)
+
+     def inverse(self, image):
+         colcorr_t_inv = torch.linalg.inv(self.colcorr_t)
+         return torch.einsum('nchw,cd->ndhw', image, colcorr_t_inv)
+
+     def forward(self, image):
+         return torch.einsum('nchw,cd->ndhw', image, self.colcorr_t)
+
+
+ class CaptureOutput:
+     """Captures a layer's output activations using a forward hook."""
+
+     def __init__(self, module):
+         self.output = None
+         self.handle = module.register_forward_hook(self)
+
+     def __call__(self, module, input, output):
+         self.output = output
+
+     def __del__(self):
+         self.handle.remove()
+
+     def get_output(self):
+         return self.output
+
+
+ class CLIPActivationLoss(nn.Module):
+     """Maximizes or minimizes a single neuron's activations."""
+
+     def __init__(self, module, neuron, class_token=False, maximize=True):
+         super().__init__()
+         self.capture = CaptureOutput(module)
+         self.neuron = neuron
+         self.class_token = class_token
+         self.maximize = maximize
+
+     def forward(self):
+         # Activations from the hooked resblock are shaped [tokens, batch, dim];
+         # token 0 is the class token, the rest are patch tokens.
+         activations = self.capture.get_output()
+         if self.class_token:
+             loss = activations[0, :, self.neuron].mean()
+         else:
+             loss = activations[1:, :, self.neuron].mean()
+         return -loss if self.maximize else loss
+
+
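+ # optimize_network: fit a Deep Image Prior "skip" generator so that the chosen CLIP
+ # neuron (via CLIPActivationLoss) responds strongly to random cutouts of its output.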
+ def optimize_network(seed, num_iterations, optimizer_type, lr):
+     global itt
+     itt = 0
+
+     if seed is not None:
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         random.seed(seed)
+
+     make_cutouts = MakeCutouts(clip_models[clip_model].visual.input_resolution, cutn)
+     loss_fn = CLIPActivationLoss(clip_models[clip_model].visual.transformer.resblocks[layer],
+                                  neuron, class_token, maximize)
+
+     # Initialize DIP skip network
+     input_depth = 32
+     net = get_net(
+         input_depth, 'skip',
+         pad='reflection',
+         skip_n33d=128, skip_n33u=128,
+         skip_n11=4, num_scales=7,  # If you decrease the output size to 256x256 you might want to use num_scales=6
+         upsample_mode='bilinear',
+         downsample_mode='lanczos2',
+     )
+
+     # Modify DIP to operate in a decorrelated color space
+     net = net[:-1]  # remove the sigmoid at the end
+     net.add(DecorrelatedColorsToRGB(inv_color_scale))
+     net.add(nn.Sigmoid())
+
+     net = net.to(device)
+
+     # Initialize input noise
+     net_input = torch.zeros([1, input_depth, sideY, sideX], device=device).normal_().div(10).detach()
+
+     if optimizer_type == 'Adam':
+         optimizer = torch.optim.Adam(net.parameters(), lr)
+     elif optimizer_type == 'MADGRAD':
+         optimizer = MADGRAD(net.parameters(), lr, momentum=0.9)
+     else:
+         raise ValueError(f"optimizer_type must be 'Adam' or 'MADGRAD', got {optimizer_type!r}")
+     scaler = torch.cuda.amp.GradScaler()
+
+     try:
+         for _ in range(num_iterations):
+             optimizer.zero_grad(set_to_none=True)
+
+             with torch.cuda.amp.autocast():
+                 out = net(net_input).float()
+                 cutouts = make_cutouts(out)
+                 # The forward pass populates the activation hook; the embeddings themselves are unused.
+                 image_embeds = clip_models[clip_model].encode_image(clip_normalize(cutouts))
+                 loss = loss_fn()
+
+             scaler.scale(loss).backward()
+             scaler.step(optimizer)
+             scaler.update()
+
+             itt += 1
+
+             if itt % display_rate == 0 or save_progress_video:
+                 with torch.inference_mode():
+                     image = TF.to_pil_image(out[0].clamp(0, 1))
+                     if itt % display_rate == 0:
+                         display.clear_output(wait=True)
+                         display.display(image)
+                         if display_augs:
+                             aug_grid = torchvision.utils.make_grid(cutouts, nrow=math.ceil(math.sqrt(cutn)))
+                             display.display(TF.to_pil_image(aug_grid.clamp(0, 1)))
+                     if save_progress_video and itt > 15:
+                         video_writer.append_data(np.asarray(image))
+
+             if anneal_lr:
+                 optimizer.param_groups[0]['lr'] = max(0.00001, .99 * optimizer.param_groups[0]['lr'])
+
+             print(f'Iteration {itt} of {num_iterations}, loss: {loss.item():g}')
+
+     except KeyboardInterrupt:
+         pass
+
+     return TF.to_pil_image(net(net_input)[0])
+
+
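+ # Gradio entry point: receives the UI values, publishes them for optimize_network(),
+ # runs the optimization, and returns the final image (also saved as PNG and MP4).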
+ def inference(
+     seed,
+     opt_type,
+     lr,
+     num_iterations,
+     cutn,
+     clip_model,
+     layer,
+     neuron,
+     class_token,
+     maximize,
+     display_rate=20
+ ):
+     global save_progress_video, video_writer
+     # optimize_network() reads these settings from module scope, so publish the UI values
+     # there; Gradio "number" inputs arrive as floats, hence the int casts.
+     globals().update(cutn=int(cutn), clip_model=clip_model, layer=int(layer), neuron=int(neuron),
+                      class_token=class_token, maximize=maximize, display_rate=int(display_rate))
+     save_progress_video = True
+     timestring = time.strftime('%Y%m%d%H%M%S')
+     if save_progress_video:
+         video_writer = imageio.get_writer(f'dip_{timestring}.mp4', mode='I', fps=30, codec='libx264', quality=7, pixelformat='yuv420p')
+
+     # Begin optimization / generation
+     gc.collect()
+     torch.cuda.empty_cache()
+     out = optimize_network(None if seed is None else int(seed), int(num_iterations), opt_type, lr)
+     out.save(f'dip_{timestring}.png', quality=100)
+     if save_progress_video:
+         video_writer.close()
+     return out
+
+ iface = gr.Interface(fn=inference,
+                      inputs=["number", "text", "number", "number", "number", "text", "number", "number",
+                              gr.inputs.Checkbox(default=False, label="class_token"),
+                              gr.inputs.Checkbox(default=True, label="maximise"),
+                              "number"],
+                      outputs="image").launch()
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch
+ torchvision
+ kornia
+ tqdm
+ clip-anytorch
+ requests
+ lpips
+ numpy
+ imageio
+ einops
+ madgrad
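+ # Likely also needed at runtime (app.py imports cv2 and IPython, and writes MP4s via imageio's ffmpeg plugin):
+ opencv-python-headless
+ ipython
+ imageio-ffmpeg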