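# CLIP neuron visualization with Deep Image Prior (DIP), served as a Gradio app.
# A DIP "skip" network is optimized so that CLIP cutouts of its output maximize
# (or minimize) the activation of a chosen neuron in a chosen ViT transformer block.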
import os
import sys

import gradio as gr

# Fetch CLIP and Deep Image Prior at startup and install the extra dependencies
# the Space needs (kornia for augmentations, madgrad for the optimizer).
os.system('git clone https://github.com/openai/CLIP')
os.system('git clone https://github.com/DmitryUlyanov/deep-image-prior')
os.system('pip install -e ./CLIP')
os.system('pip install kornia einops madgrad')

import io
import math
import random
import time
import requests

sys.path.append('./CLIP')
sys.path.append('deep-image-prior')
from einops import rearrange
import gc
import imageio
from IPython import display
import kornia.augmentation as K
from madgrad import MADGRAD
import torch
import torch.optim
import torch.nn as nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import numpy as np

import clip
from models import *          # DIP network builders (get_net) from deep-image-prior
from utils.sr_utils import *  # helper utilities from deep-image-prior
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load both ViT CLIP models once; `clip_model` below selects which one is optimized against.
clip_model_vit_b_32 = clip.load('ViT-B/32', device=device)[0].eval().requires_grad_(False)
clip_model_vit_b_16 = clip.load('ViT-B/16', device=device)[0].eval().requires_grad_(False)
clip_models = {'ViT-B/32': clip_model_vit_b_32, 'ViT-B/16': clip_model_vit_b_16}
clip_normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])

clip_model = 'ViT-B/16'   # which of the loaded CLIP models to optimize against
sideX, sideY = 256, 256   # output resolution of the DIP network
inv_color_scale = 1.6     # saturation control for the decorrelated color parameterization
anneal_lr = True          # decay the learning rate by 1% per iteration (with a floor)
display_augs = False      # also show the grid of augmented cutouts
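# MakeCutouts takes `cutn` random square crops of the image (plus a few full-frame
# copies), resizes them to CLIP's input resolution, and applies random augmentations
# so the optimization cannot latch onto a single fixed view of the image.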
class MakeCutouts(torch.nn.Module):
    def __init__(self, cut_size, cutn):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.augs = T.Compose([
            K.RandomHorizontalFlip(p=0.5),
            K.RandomAffine(degrees=15, translate=0.1, p=0.8, padding_mode='border', resample='bilinear'),
            K.RandomPerspective(0.4, p=0.7, resample='bilinear'),
            K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7),
            K.RandomGrayscale(p=0.15),
        ])

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        if sideY != sideX:
            input = K.RandomAffine(degrees=0, shear=10, p=0.5, padding_mode='border')(input)
        max_size = min(sideX, sideY)
        cutouts = []
        for cn in range(self.cutn):
            if cn > self.cutn - self.cutn // 4:
                # Keep roughly a quarter of the cutouts as full-frame views.
                cutout = input
            else:
                # Random square crop with a size drawn from a clipped normal distribution.
                size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size / max_size), 1.))
                offsetx = torch.randint(0, sideX - size + 1, ())
                offsety = torch.randint(0, sideY - size + 1, ())
                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        cutouts = torch.cat(cutouts)
        cutouts = self.augs(cutouts)
        return cutouts
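# Parameterizing the image in a decorrelated color space (the color matrix below is
# the Lucid-style one, taken from the aphantasia repo) tends to give better-behaved
# colors than optimizing raw RGB values directly.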
class DecorrelatedColorsToRGB(nn.Module):
    """From https://github.com/eps696/aphantasia."""

    def __init__(self, inv_color_scale=1.):
        super().__init__()
        color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]])
        color_correlation_svd_sqrt /= torch.tensor([inv_color_scale, 1., 1.])  # saturate, empirical
        max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max()
        color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
        self.register_buffer('colcorr_t', color_correlation_normalized.T)

    def inverse(self, image):
        colcorr_t_inv = torch.linalg.inv(self.colcorr_t)
        return torch.einsum('nchw,cd->ndhw', image, colcorr_t_inv)

    def forward(self, image):
        return torch.einsum('nchw,cd->ndhw', image, self.colcorr_t)
class CaptureOutput:
    """Captures a layer's output activations using a forward hook."""

    def __init__(self, module):
        self.output = None
        self.handle = module.register_forward_hook(self)

    def __call__(self, module, input, output):
        self.output = output

    def __del__(self):
        self.handle.remove()

    def get_output(self):
        return self.output
class CLIPActivationLoss(nn.Module):
    """Maximizes or minimizes a single neuron's activations."""

    def __init__(self, module, neuron, class_token=False, maximize=True):
        super().__init__()
        self.capture = CaptureOutput(module)
        self.neuron = neuron
        self.class_token = class_token
        self.maximize = maximize

    def forward(self):
        # CLIP's ViT resblocks run in (tokens, batch, features) order, so index 0
        # along the first dimension is the class token and 1: are the patch tokens.
        activations = self.capture.get_output()
        if self.class_token:
            loss = activations[0, :, self.neuron].mean()
        else:
            loss = activations[1:, :, self.neuron].mean()
        return -loss if self.maximize else loss
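# optimize_network builds a fresh DIP network and trains its weights (CLIP stays
# frozen) so that the decoded image drives the selected neuron of the selected CLIP
# transformer block as high (maximize=True) or as low (maximize=False) as possible.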
def optimize_network(
    seed,
    opt_type,
    lr,
    num_iterations,
    cutn,
    layer,
    neuron,
    class_token,
    maximize,
    display_rate,
    video_writer,
):
    global itt
    itt = 0

    # Seeding is currently disabled; uncomment to make runs reproducible.
    # if seed is not None:
    #     np.random.seed(seed)
    #     torch.manual_seed(seed)
    #     random.seed(seed)

    save_progress_video = True

    make_cutouts = MakeCutouts(clip_models[clip_model].visual.input_resolution, cutn)
    loss_fn = CLIPActivationLoss(clip_models[clip_model].visual.transformer.resblocks[layer],
                                 neuron, class_token, maximize)

    # Initialize the DIP skip network.
    input_depth = 32
    net = get_net(
        input_depth, 'skip',
        pad='reflection',
        skip_n33d=128, skip_n33u=128,
        skip_n11=4, num_scales=6,  # num_scales=6 suits the 256x256 output size
        upsample_mode='bilinear',
        downsample_mode='lanczos2',
    )

    # Modify DIP to operate in a decorrelated color space. (`.add` is the
    # Sequential helper that deep-image-prior monkey-patches onto nn.Module.)
    net = net[:-1]  # remove the sigmoid at the end
    net.add(DecorrelatedColorsToRGB(inv_color_scale))
    net.add(nn.Sigmoid())
    net = net.to(device)

    # Initialize input noise.
    net_input = torch.zeros([1, input_depth, sideY, sideX], device=device).normal_().div(10).detach()

    if opt_type == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(), lr)
    elif opt_type == 'MADGRAD':
        optimizer = MADGRAD(net.parameters(), lr, momentum=0.9)
    scaler = torch.cuda.amp.GradScaler()
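    # Main optimization loop: decode the DIP network, take random cutouts, run them
    # through CLIP (which fires the activation hook), and backpropagate the scaled
    # neuron loss into the DIP weights.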
    try:
        for _ in range(num_iterations):
            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast():
                out = net(net_input).float()
                cutouts = make_cutouts(out)
                # encode_image() runs CLIP's forward pass, which populates the
                # forward hook that CLIPActivationLoss reads its activations from.
                clip_models[clip_model].encode_image(clip_normalize(cutouts))
                loss = loss_fn()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            itt += 1

            if itt % display_rate == 0 or save_progress_video:
                with torch.inference_mode():
                    image = TF.to_pil_image(out[0].clamp(0, 1))
                    if itt % display_rate == 0:
                        display.clear_output(wait=True)
                        display.display(image)
                        if display_augs:
                            aug_grid = torchvision.utils.make_grid(cutouts, nrow=math.ceil(math.sqrt(cutn)))
                            display.display(TF.to_pil_image(aug_grid.clamp(0, 1)))
                    if save_progress_video:
                        video_writer.append_data(np.asarray(image))

            if anneal_lr:
                optimizer.param_groups[0]['lr'] = max(0.00001, .99 * optimizer.param_groups[0]['lr'])

            print(f'Iteration {itt} of {num_iterations}, loss: {loss.item():g}')

    except KeyboardInterrupt:
        pass

    with torch.inference_mode():
        return TF.to_pil_image(net(net_input)[0])
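# Gradio callback: coerces the form inputs to the right types, sets the fixed
# options (optimizer, seed, progress video), runs the optimization, and returns
# the final image plus the progress video.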
def inference(
    lr,
    num_iterations,
    cutn,
    layer,
    neuron,
    class_token,
    maximize,
    display_rate=20,
):
    # Gradio Number inputs arrive as floats; cast the integer-valued ones.
    layer = int(layer)
    cutn = int(cutn)
    num_iterations = int(num_iterations)
    neuron = int(neuron)
    display_rate = int(display_rate)

    # seed and opt_type are fixed here rather than exposed in the UI.
    opt_type = 'MADGRAD'
    seed = 20
    save_progress_video = True

    timestring = time.strftime('%Y%m%d%H%M%S')

    if save_progress_video:
        video_writer = imageio.get_writer('video.mp4', fps=10)

    # Begin optimization / generation
    gc.collect()
    # torch.cuda.empty_cache()

    out = optimize_network(
        seed,
        opt_type,
        lr,
        num_iterations,
        cutn,
        layer,
        neuron,
        class_token,
        maximize,
        display_rate,
        video_writer=video_writer,
    )

    # out.save(f'dip_{timestring}.png', quality=100)

    if save_progress_video:
        video_writer.close()

    return out, 'video.mp4'
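# Build the Gradio UI: numeric fields and checkboxes for the optimization settings,
# with the final image and the progress video as outputs.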
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Number(default=1e-3, label="learning rate"),
        gr.inputs.Number(default=50, label="number of iterations (more is better)"),
        gr.inputs.Number(default=32, label="cutn (number of cuts)"),
        gr.inputs.Number(default=10, label="layer"),
        gr.inputs.Number(default=0, label="neuron"),
        gr.inputs.Checkbox(default=False, label="class_token"),
        gr.inputs.Checkbox(default=True, label="maximize"),
        gr.inputs.Slider(minimum=0, maximum=30, default=10, label="display rate"),
    ],
    outputs=["image", "video"],
).launch()