import os
import sys
import gradio as gr
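# Fetch CLIP and deep-image-prior at startup and install the remaining
# dependencies; a Hugging Face Space runs this file directly, so setup
# happens at import time.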
os.system('git clone https://github.com/openai/CLIP')
os.system('git clone https://github.com/DmitryUlyanov/deep-image-prior')
os.system('pip install -e ./CLIP')
os.system('pip install kornia einops madgrad')
import io
import math
import random
import time
import requests
sys.path.append('./CLIP')
sys.path.append('deep-image-prior')
from einops import rearrange
import gc
import imageio
from IPython import display
import kornia.augmentation as K
from madgrad import MADGRAD
import torch
import torch.optim
import torch.nn as nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import numpy as np
import clip
from models import *
from utils.sr_utils import *
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clip_model_vit_b_32 = clip.load('ViT-B/32', device=device)[0].eval().requires_grad_(False)
clip_model_vit_b_16 = clip.load('ViT-B/16', device=device)[0].eval().requires_grad_(False)
clip_models = {'ViT-B/32': clip_model_vit_b_32, 'ViT-B/16': clip_model_vit_b_16}
clip_normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
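# Generation settings. inv_color_scale rescales the decorrelated color basis
# (the inline comment below calls it an empirical saturation tweak);
# anneal_lr decays the learning rate each iteration during optimization.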
clip_model = 'ViT-B/16'
sideX, sideY = 256, 256 # Resolution
inv_color_scale = 1.6
anneal_lr = True
display_augs = False
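# CLIP is evaluated on a batch of random crops ("cutouts") of the generated
# image rather than on the full image; averaging activations over many
# augmented views gives a smoother, less adversarial optimization signal.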
class MakeCutouts(torch.nn.Module):
def __init__(self, cut_size, cutn):
super().__init__()
self.cut_size = cut_size
self.cutn = cutn
self.augs = T.Compose([
K.RandomHorizontalFlip(p=0.5),
K.RandomAffine(degrees=15, translate=0.1, p=0.8, padding_mode='border', resample='bilinear'),
K.RandomPerspective(0.4, p=0.7, resample='bilinear'),
K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7),
K.RandomGrayscale(p=0.15),
])
def forward(self, input):
sideY, sideX = input.shape[2:4]
if sideY != sideX:
input = K.RandomAffine(degrees=0, shear=10, p=0.5, padding_mode='border')(input)
max_size = min(sideX, sideY)
cutouts = []
for cn in range(self.cutn):
if cn > self.cutn - self.cutn//4:
cutout = input
else:
size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))
offsetx = torch.randint(0, sideX - size + 1, ())
offsety = torch.randint(0, sideY - size + 1, ())
cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
cutouts = torch.cat(cutouts)
cutouts = self.augs(cutouts)
return cutouts
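# The DIP network produces images in a decorrelated color space; this module
# maps them back to RGB, a trick from feature-visualization work that reduces
# high-frequency color artifacts.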
class DecorrelatedColorsToRGB(nn.Module):
"""From https://github.com/eps696/aphantasia."""
def __init__(self, inv_color_scale=1.):
super().__init__()
color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]])
color_correlation_svd_sqrt /= torch.tensor([inv_color_scale, 1., 1.]) # saturate, empirical
max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max()
color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
self.register_buffer('colcorr_t', color_correlation_normalized.T)
def inverse(self, image):
colcorr_t_inv = torch.linalg.inv(self.colcorr_t)
return torch.einsum('nchw,cd->ndhw', image, colcorr_t_inv)
def forward(self, image):
return torch.einsum('nchw,cd->ndhw', image, self.colcorr_t)
class CaptureOutput:
"""Captures a layer's output activations using a forward hook."""
def __init__(self, module):
self.output = None
self.handle = module.register_forward_hook(self)
def __call__(self, module, input, output):
self.output = output
def __del__(self):
self.handle.remove()
def get_output(self):
return self.output
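# CLIP's ViT transformer runs with activations shaped [tokens, batch, width],
# where token 0 is the class token and the remaining tokens are image
# patches; the loss below indexes that layout.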
class CLIPActivationLoss(nn.Module):
"""Maximizes or minimizes a single neuron's activations."""
def __init__(self, module, neuron, class_token=False, maximize=True):
super().__init__()
self.capture = CaptureOutput(module)
self.neuron = neuron
self.class_token = class_token
self.maximize = maximize
def forward(self):
activations = self.capture.get_output()
if self.class_token:
loss = activations[0, :, self.neuron].mean()
else:
loss = activations[1:, :, self.neuron].mean()
return -loss if self.maximize else loss
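# Core loop: fit the DIP generator so the image it produces drives the chosen
# CLIP neuron's activation up (or down), i.e. activation maximization through
# a deep image prior instead of direct pixel optimization.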
def optimize_network(
seed,
opt_type,
lr,
num_iterations,
cutn,
layer,
neuron,
class_token,
maximize,
display_rate,
video_writer
):
global itt
itt = 0
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)
save_progress_video = True
make_cutouts = MakeCutouts(clip_models[clip_model].visual.input_resolution, cutn)
loss_fn = CLIPActivationLoss(clip_models[clip_model].visual.transformer.resblocks[layer],
neuron, class_token, maximize)
# Initialize DIP skip network
input_depth = 32
net = get_net(
input_depth, 'skip',
pad='reflection',
skip_n33d=128, skip_n33u=128,
        skip_n11=4, num_scales=6,  # num_scales=6 suits the 256x256 output used here; larger outputs may want 7
upsample_mode='bilinear',
downsample_mode='lanczos2',
)
# Modify DIP to operate in a decorrelated color space
net = net[:-1] # remove the sigmoid at the end
net.add(DecorrelatedColorsToRGB(inv_color_scale))
net.add(nn.Sigmoid())
net = net.to(device)
# Initialize input noise
net_input = torch.zeros([1, input_depth, sideY, sideX], device=device).normal_().div(10).detach()
if opt_type == 'Adam':
optimizer = torch.optim.Adam(net.parameters(), lr)
elif opt_type == 'MADGRAD':
optimizer = MADGRAD(net.parameters(), lr, momentum=0.9)
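    # Mixed-precision training: run the forward pass under autocast and use a
    # GradScaler so fp16 gradients don't underflow.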
scaler = torch.cuda.amp.GradScaler()
try:
for _ in range(num_iterations):
optimizer.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast():
out = net(net_input).float()
cutouts = make_cutouts(out)
                # encode_image drives the forward pass through the ViT, which
                # fires the hook registered by CLIPActivationLoss; the returned
                # embeddings themselves are not used.
                clip_models[clip_model].encode_image(clip_normalize(cutouts))
                loss = loss_fn()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
itt += 1
if itt % display_rate == 0 or save_progress_video:
with torch.inference_mode():
image = TF.to_pil_image(out[0].clamp(0, 1))
if itt % display_rate == 0:
display.clear_output(wait=True)
display.display(image)
if display_augs:
aug_grid = torchvision.utils.make_grid(cutouts, nrow=math.ceil(math.sqrt(cutn)))
display.display(TF.to_pil_image(aug_grid.clamp(0, 1)))
if save_progress_video:
video_writer.append_data(np.asarray(image))
if anneal_lr:
optimizer.param_groups[0]['lr'] = max(0.00001, .99 * optimizer.param_groups[0]['lr'])
print(f'Iteration {itt} of {num_iterations}, loss: {loss.item():g}')
except KeyboardInterrupt:
pass
return TF.to_pil_image(net(net_input)[0])
# seed and opt_type are fixed inside inference() rather than exposed in the UI
def inference(
lr,
num_iterations,
cutn,
layer,
neuron,
class_token,
maximize,
display_rate = 20
):
layer = int(layer)
cutn = int(cutn)
num_iterations = int(num_iterations)
neuron = int(neuron)
display_rate = int(display_rate)
opt_type = 'MADGRAD'
seed = 20
save_progress_video = True
timestring = time.strftime('%Y%m%d%H%M%S')
if save_progress_video:
video_writer = imageio.get_writer('video.mp4', fps=10)
# Begin optimization / generation
gc.collect()
# torch.cuda.empty_cache()
out = optimize_network(
seed,
opt_type,
lr,
num_iterations,
cutn,
layer,
neuron,
class_token,
maximize,
display_rate,
video_writer = video_writer
)
# out.save(f'dip_{timestring}.png', quality=100)
if save_progress_video:
video_writer.close()
return out, 'video.mp4'
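# Gradio UI: the inputs below map one-to-one, in order, onto inference()'s
# parameters.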
iface = gr.Interface(fn=inference,
inputs=[
gr.inputs.Number(default=1e-3, label="learning rate"),
gr.inputs.Number(default=50, label="number of iterations (more is better)"),
gr.inputs.Number(default=32, label="cutn (number of cuts)"),
                gr.inputs.Number(default=10, label="layer (transformer resblock index)"),
                gr.inputs.Number(default=0, label="neuron (channel index within the layer)"),
                gr.inputs.Checkbox(default=False, label="class_token"),
                gr.inputs.Checkbox(default=True, label="maximize"),
                gr.inputs.Slider(minimum=1, maximum=30, default=10, step=1, label='display rate'),
],
                outputs=["image", "video"]).launch()