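# Gradio demo: visualize what a single neuron in CLIP's ViT responds to by
# optimizing a Deep Image Prior network so the generated image maximizes
# (or minimizes) that neuron's activation.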
import os
import sys
import gradio as gr
os.system('git clone https://github.com/openai/CLIP')
os.system('git clone https://github.com/DmitryUlyanov/deep-image-prior')
os.system('pip install -e ./CLIP')
os.system('pip install kornia einops madgrad')
import io
import math
import random
import time
import requests
sys.path.append('./CLIP')
sys.path.append('deep-image-prior')
from einops import rearrange
import gc
import imageio
from IPython import display
import kornia.augmentation as K
from madgrad import MADGRAD
import torch
import torch.optim
import torch.nn as nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import numpy as np
import clip
from models import *
from utils.sr_utils import *
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clip_model_vit_b_32 = clip.load('ViT-B/32', device=device)[0].eval().requires_grad_(False)
clip_model_vit_b_16 = clip.load('ViT-B/16', device=device)[0].eval().requires_grad_(False)
clip_models = {'ViT-B/32': clip_model_vit_b_32, 'ViT-B/16': clip_model_vit_b_16}
clip_normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
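# Global settings: which CLIP model to optimize against, output resolution,
# color decorrelation strength, learning-rate annealing, and whether to show augmented cutouts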
clip_model = 'ViT-B/16'
sideX, sideY = 256, 256 # Resolution
inv_color_scale = 1.6
anneal_lr = True
display_augs = False
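# MakeCutouts produces cutn cutouts of the generated image (random crops, with the
# last few being the full image) and applies Kornia augmentations before the batch
# is passed to CLIP.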
class MakeCutouts(torch.nn.Module):
def __init__(self, cut_size, cutn):
super().__init__()
self.cut_size = cut_size
self.cutn = cutn
self.augs = T.Compose([
K.RandomHorizontalFlip(p=0.5),
K.RandomAffine(degrees=15, translate=0.1, p=0.8, padding_mode='border', resample='bilinear'),
K.RandomPerspective(0.4, p=0.7, resample='bilinear'),
K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7),
K.RandomGrayscale(p=0.15),
])
def forward(self, input):
sideY, sideX = input.shape[2:4]
if sideY != sideX:
input = K.RandomAffine(degrees=0, shear=10, p=0.5, padding_mode='border')(input)
max_size = min(sideX, sideY)
cutouts = []
for cn in range(self.cutn):
if cn > self.cutn - self.cutn//4:
cutout = input
else:
size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))
offsetx = torch.randint(0, sideX - size + 1, ())
offsety = torch.randint(0, sideY - size + 1, ())
cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
cutouts = torch.cat(cutouts)
cutouts = self.augs(cutouts)
return cutouts
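# Maps the network's decorrelated color channels back to RGB; optimizing in a
# decorrelated color space is a common feature-visualization trick for reducing
# color artifacts.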
class DecorrelatedColorsToRGB(nn.Module):
"""From https://github.com/eps696/aphantasia."""
def __init__(self, inv_color_scale=1.):
super().__init__()
color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]])
color_correlation_svd_sqrt /= torch.tensor([inv_color_scale, 1., 1.]) # saturate, empirical
max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max()
color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
self.register_buffer('colcorr_t', color_correlation_normalized.T)
def inverse(self, image):
colcorr_t_inv = torch.linalg.inv(self.colcorr_t)
return torch.einsum('nchw,cd->ndhw', image, colcorr_t_inv)
def forward(self, image):
return torch.einsum('nchw,cd->ndhw', image, self.colcorr_t)
class CaptureOutput:
"""Captures a layer's output activations using a forward hook."""
def __init__(self, module):
self.output = None
self.handle = module.register_forward_hook(self)
def __call__(self, module, input, output):
self.output = output
def __del__(self):
self.handle.remove()
def get_output(self):
return self.output
class CLIPActivationLoss(nn.Module):
"""Maximizes or minimizes a single neuron's activations."""
def __init__(self, module, neuron, class_token=False, maximize=True):
super().__init__()
self.capture = CaptureOutput(module)
self.neuron = neuron
self.class_token = class_token
self.maximize = maximize
def forward(self):
activations = self.capture.get_output()
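        # Resblock outputs from CLIP's visual transformer are (tokens, batch, dim);
        # token 0 is the class token, tokens 1: are the image patches.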
if self.class_token:
loss = activations[0, :, self.neuron].mean()
else:
loss = activations[1:, :, self.neuron].mean()
return -loss if self.maximize else loss
def optimize_network(
seed,
opt_type,
lr,
num_iterations,
cutn,
layer,
neuron,
class_token,
maximize,
display_rate,
video_writer
):
global itt
itt = 0
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)
save_progress_video = True
make_cutouts = MakeCutouts(clip_models[clip_model].visual.input_resolution, cutn)
loss_fn = CLIPActivationLoss(clip_models[clip_model].visual.transformer.resblocks[layer],
neuron, class_token, maximize)
# Initialize DIP skip network
input_depth = 32
net = get_net(
input_depth, 'skip',
pad='reflection',
skip_n33d=128, skip_n33u=128,
        skip_n11=4, num_scales=6,  # num_scales=6 suits the 256x256 output resolution set above; larger outputs may want more scales
upsample_mode='bilinear',
downsample_mode='lanczos2',
)
# Modify DIP to operate in a decorrelated color space
net = net[:-1] # remove the sigmoid at the end
net.add(DecorrelatedColorsToRGB(inv_color_scale))
net.add(nn.Sigmoid())
net = net.to(device)
# Initialize input noise
net_input = torch.zeros([1, input_depth, sideY, sideX], device=device).normal_().div(10).detach()
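    # Deep Image Prior: the noise input stays fixed; only the network weights are optimized.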
if opt_type == 'Adam':
optimizer = torch.optim.Adam(net.parameters(), lr)
elif opt_type == 'MADGRAD':
optimizer = MADGRAD(net.parameters(), lr, momentum=0.9)
scaler = torch.cuda.amp.GradScaler()
try:
for _ in range(num_iterations):
optimizer.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast():
out = net(net_input).float()
cutouts = make_cutouts(out)
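                # encode_image runs CLIP's visual transformer; the forward hook inside
                # loss_fn captures the target resblock's activations (the returned
                # embeddings themselves are not used).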
image_embeds = clip_models[clip_model].encode_image(clip_normalize(cutouts))
                loss = loss_fn()

            # Backpropagate and update the DIP weights through the AMP grad scaler
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
itt += 1
if itt % display_rate == 0 or save_progress_video:
with torch.inference_mode():
image = TF.to_pil_image(out[0].clamp(0, 1))
if itt % display_rate == 0:
display.clear_output(wait=True)
display.display(image)
if display_augs:
aug_grid = torchvision.utils.make_grid(cutouts, nrow=math.ceil(math.sqrt(cutn)))
display.display(TF.to_pil_image(aug_grid.clamp(0, 1)))
if save_progress_video:
video_writer.append_data(np.asarray(image))
if anneal_lr:
optimizer.param_groups[0]['lr'] = max(0.00001, .99 * optimizer.param_groups[0]['lr'])
print(f'Iteration {itt} of {num_iterations}, loss: {loss.item():g}')
except KeyboardInterrupt:
pass
    with torch.inference_mode():
        return TF.to_pil_image(net(net_input)[0].clamp(0, 1))
# seed and opt_type were removed from the Gradio inputs and are fixed inside inference()
def inference(
lr,
num_iterations,
cutn,
layer,
neuron,
class_token,
maximize,
display_rate = 20
):
layer = int(layer)
cutn = int(cutn)
num_iterations = int(num_iterations)
neuron = int(neuron)
display_rate = int(display_rate)
opt_type = 'MADGRAD'
seed = 20
save_progress_video = True
timestring = time.strftime('%Y%m%d%H%M%S')
if save_progress_video:
video_writer = imageio.get_writer('video.mp4', fps=10)
# Begin optimization / generation
gc.collect()
# torch.cuda.empty_cache()
out = optimize_network(
seed,
opt_type,
lr,
num_iterations,
cutn,
layer,
neuron,
class_token,
maximize,
display_rate,
video_writer = video_writer
)
# out.save(f'dip_{timestring}.png', quality=100)
if save_progress_video:
video_writer.close()
return out, 'video.mp4'
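# Gradio UI: numeric fields arrive as floats and are cast to int inside inference().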
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Number(default=1e-3, label="learning rate"),
        gr.inputs.Number(default=50, label="number of iterations (more is better)"),
        gr.inputs.Number(default=32, label="cutn (number of cuts)"),
        gr.inputs.Number(default=10, label="layer"),
        gr.inputs.Number(default=0, label="neuron"),
        gr.inputs.Checkbox(default=False, label="class_token"),
        gr.inputs.Checkbox(default=True, label="maximize"),
        gr.inputs.Slider(minimum=1, maximum=30, default=10, label="display rate"),
    ],
    outputs=["image", "video"],
)
iface.launch()
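# Example of calling inference() directly, bypassing the UI (values mirror the UI defaults):
#   image, video_path = inference(lr=1e-3, num_iterations=50, cutn=32,
#                                 layer=10, neuron=0, class_token=False, maximize=True)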