"""Feature visualization for CLIP: a Deep Image Prior network is optimized so
that a chosen neuron in a CLIP ViT transformer block is maximally (or
minimally) activated. Served as a Gradio demo."""

import os
import sys

import gradio as gr

# Fetch and install the dependencies that are not on PyPI.
os.system('git clone https://github.com/openai/CLIP')
os.system('git clone https://github.com/DmitryUlyanov/deep-image-prior')
os.system('pip install -e ./CLIP')
os.system('pip install kornia einops madgrad')

import gc
import io
import math
import random
import time

import requests

sys.path.append('./CLIP')
sys.path.append('deep-image-prior')

from einops import rearrange
import imageio
from IPython import display
import kornia.augmentation as K
from madgrad import MADGRAD
import torch
import torch.optim
import torch.nn as nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import numpy as np

import clip
from models import *  # from deep-image-prior
from utils.sr_utils import *  # from deep-image-prior

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

clip_model_vit_b_32 = clip.load('ViT-B/32', device=device)[0].eval().requires_grad_(False)
clip_model_vit_b_16 = clip.load('ViT-B/16', device=device)[0].eval().requires_grad_(False)
clip_models = {'ViT-B/32': clip_model_vit_b_32, 'ViT-B/16': clip_model_vit_b_16}
clip_normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711])

clip_model = 'ViT-B/16'
sideX, sideY = 256, 256  # Output resolution
inv_color_scale = 1.6
anneal_lr = True
display_augs = False


class MakeCutouts(torch.nn.Module):
    """Produces a batch of randomly sized, randomly augmented crops of the input."""

    def __init__(self, cut_size, cutn):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.augs = T.Compose([
            K.RandomHorizontalFlip(p=0.5),
            K.RandomAffine(degrees=15, translate=(0.1, 0.1), p=0.8, padding_mode='border', resample='bilinear'),
            K.RandomPerspective(0.4, p=0.7, resample='bilinear'),
            K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7),
            K.RandomGrayscale(p=0.15),
        ])

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        if sideY != sideX:
            input = K.RandomAffine(degrees=0, shear=10, p=0.5, padding_mode='border')(input)
        max_size = min(sideX, sideY)
        cutouts = []
        for cn in range(self.cutn):
            # Keep roughly a quarter of the cutouts at full size and sample
            # the rest at random scales and offsets.
            if cn > self.cutn - self.cutn // 4:
                cutout = input
            else:
                size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size / max_size), 1.))
                offsetx = torch.randint(0, sideX - size + 1, ())
                offsety = torch.randint(0, sideY - size + 1, ())
                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        cutouts = torch.cat(cutouts)
        cutouts = self.augs(cutouts)
        return cutouts


class DecorrelatedColorsToRGB(nn.Module):
    """From https://github.com/eps696/aphantasia."""

    def __init__(self, inv_color_scale=1.):
        super().__init__()
        color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02],
                                                   [0.27, 0.00, -0.05],
                                                   [0.27, -0.09, 0.03]])
        color_correlation_svd_sqrt /= torch.tensor([inv_color_scale, 1., 1.])  # saturate, empirical
        max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max()
        color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
        self.register_buffer('colcorr_t', color_correlation_normalized.T)

    def inverse(self, image):
        colcorr_t_inv = torch.linalg.inv(self.colcorr_t)
        return torch.einsum('nchw,cd->ndhw', image, colcorr_t_inv)

    def forward(self, image):
        return torch.einsum('nchw,cd->ndhw', image, self.colcorr_t)


class CaptureOutput:
    """Captures a layer's output activations using a forward hook."""

    def __init__(self, module):
        self.output = None
        self.handle = module.register_forward_hook(self)

    def __call__(self, module, input, output):
        self.output = output

    def __del__(self):
        self.handle.remove()

    def get_output(self):
        return self.output


class CLIPActivationLoss(nn.Module):
    """Maximizes or minimizes a single neuron's activations."""

    def __init__(self, module, neuron, class_token=False, maximize=True):
        super().__init__()
        self.capture = CaptureOutput(module)
        self.neuron = neuron
        self.class_token = class_token
        self.maximize = maximize

    def forward(self):
        activations = self.capture.get_output()
        # CLIP's transformer runs in (sequence, batch, dim) order: index 0 is
        # the class token, indices 1: are the image tokens.
        if self.class_token:
            loss = activations[0, :, self.neuron].mean()
        else:
            loss = activations[1:, :, self.neuron].mean()
        return -loss if self.maximize else loss


def optimize_network(seed, opt_type, lr, num_iterations, cutn, layer, neuron,
                     class_token, maximize, display_rate, video_writer):
    global itt
    itt = 0

    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)

    save_progress_video = True

    make_cutouts = MakeCutouts(clip_models[clip_model].visual.input_resolution, cutn)
    loss_fn = CLIPActivationLoss(clip_models[clip_model].visual.transformer.resblocks[layer],
                                 neuron, class_token, maximize)

    # Initialize the DIP skip network.
    input_depth = 32
    net = get_net(
        input_depth, 'skip',
        pad='reflection',
        skip_n33d=128, skip_n33u=128, skip_n11=4,
        num_scales=6,  # 6 scales suits the 256x256 output size used here
        upsample_mode='bilinear',
        downsample_mode='lanczos2',
    )

    # Modify DIP so it operates in a decorrelated color space.
    net = net[:-1]  # remove the sigmoid at the end
    net.add(DecorrelatedColorsToRGB(inv_color_scale))
    net.add(nn.Sigmoid())
    net = net.to(device)

    # Initialize the input noise.
    net_input = torch.zeros([1, input_depth, sideY, sideX], device=device).normal_().div(10).detach()

    if opt_type == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(), lr)
    elif opt_type == 'MADGRAD':
        optimizer = MADGRAD(net.parameters(), lr, momentum=0.9)

    scaler = torch.cuda.amp.GradScaler()

    try:
        for _ in range(num_iterations):
            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast():
                out = net(net_input).float()
            cutouts = make_cutouts(out)
            # encode_image() triggers the forward hook that captures the chosen
            # block's activations; the returned embeddings themselves are unused.
            clip_models[clip_model].encode_image(clip_normalize(cutouts))
            loss = loss_fn()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            itt += 1

            if itt % display_rate == 0 or save_progress_video:
                with torch.inference_mode():
                    image = TF.to_pil_image(out[0].clamp(0, 1))
                    if itt % display_rate == 0:
                        display.clear_output(wait=True)
                        display.display(image)
                        if display_augs:
                            aug_grid = torchvision.utils.make_grid(cutouts, nrow=math.ceil(math.sqrt(cutn)))
                            display.display(TF.to_pil_image(aug_grid.clamp(0, 1)))
                    if save_progress_video:
                        video_writer.append_data(np.asarray(image))

            if anneal_lr:
                optimizer.param_groups[0]['lr'] = max(0.00001, .99 * optimizer.param_groups[0]['lr'])

            print(f'Iteration {itt} of {num_iterations}, loss: {loss.item():g}')
    except KeyboardInterrupt:
        pass

    return TF.to_pil_image(net(net_input)[0])


def inference(
    # seed,
    # opt_type,
    lr,
    num_iterations,
    cutn,
    layer,
    neuron,
    class_token,
    maximize,
    display_rate=20,
):
    layer = int(layer)
    cutn = int(cutn)
    num_iterations = int(num_iterations)
    neuron = int(neuron)
    display_rate = int(display_rate)
    opt_type = 'MADGRAD'
    seed = 20
    save_progress_video = True

    timestring = time.strftime('%Y%m%d%H%M%S')
    if save_progress_video:
        video_writer = imageio.get_writer('video.mp4', fps=10)

    # Begin optimization / generation.
    gc.collect()
    # torch.cuda.empty_cache()
    out = optimize_network(seed, opt_type, lr, num_iterations, cutn, layer, neuron,
                           class_token, maximize, display_rate, video_writer=video_writer)
    out.save(f'dip_{timestring}.png', quality=100)
    if save_progress_video:
        video_writer.close()
    return out, 'video.mp4'


iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Number(default=1e-3, label="learning rate"),
        gr.inputs.Number(default=50, label="number of iterations (more is better)"),
        gr.inputs.Number(default=32, label="cutn (number of cuts)"),
        gr.inputs.Number(default=10, label="layer"),
        gr.inputs.Number(default=0, label="neuron"),
        gr.inputs.Checkbox(default=False, label="class_token"),
        gr.inputs.Checkbox(default=True, label="maximize"),
        gr.inputs.Slider(minimum=1, maximum=30, default=10, label='display rate'),
    ],
    outputs=["image", "video"],
).launch()
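# Hypothetical usage sketch (these argument values are illustrative, not part
# of the original demo): inference() can also be called directly, bypassing
# the Gradio UI, e.g. to visualize neuron 100 of transformer block 10:
#
#   image, video_path = inference(lr=1e-3, num_iterations=50, cutn=32,
#                                 layer=10, neuron=100, class_token=False,
#                                 maximize=True, display_rate=10)
#   image.save('layer10_neuron100.png')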