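# Gradio demo: visualize what a single neuron inside a CLIP ViT "wants to see".
# A Deep Image Prior (DIP) skip network is optimized so that its output image
# maximizes (or minimizes) the activation of one chosen neuron in one chosen
# transformer block of the CLIP image encoder; the app returns the final image
# and a progress video.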
import os
import sys

import gradio as gr

os.system('git clone https://github.com/openai/CLIP')
os.system('git clone https://github.com/DmitryUlyanov/deep-image-prior')
os.system('pip install -e ./CLIP')
os.system('pip install kornia einops madgrad')

import io
import math
import random
import time
import requests

sys.path.append('./CLIP')
sys.path.append('deep-image-prior')

from einops import rearrange
import gc
import imageio
from IPython import display
import kornia.augmentation as K
from madgrad import MADGRAD
import torch
import torch.optim
import torch.nn as nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import numpy as np

import clip
from models import *
from utils.sr_utils import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

clip_model_vit_b_32 = clip.load('ViT-B/32', device=device)[0].eval().requires_grad_(False)
clip_model_vit_b_16 = clip.load('ViT-B/16', device=device)[0].eval().requires_grad_(False)
clip_models = {'ViT-B/32': clip_model_vit_b_32, 'ViT-B/16': clip_model_vit_b_16}
clip_normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])

clip_model = 'ViT-B/16'
sideX, sideY = 256, 256  # Resolution
inv_color_scale = 1.6
anneal_lr = True
display_augs = False
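# CLIP scores fixed-resolution inputs (224 px for the ViT-B models used here), so the
# generated image is sliced into many random square crops ("cutouts") plus a few full-frame
# copies, all pooled down to the CLIP input resolution and lightly augmented. Averaging the
# neuron activation over these views helps keep the optimization from latching onto a single
# adversarial view of the image.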
class MakeCutouts(torch.nn.Module):
    def __init__(self, cut_size, cutn):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.augs = T.Compose([
            K.RandomHorizontalFlip(p=0.5),
            K.RandomAffine(degrees=15, translate=0.1, p=0.8, padding_mode='border', resample='bilinear'),
            K.RandomPerspective(0.4, p=0.7, resample='bilinear'),
            K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7),
            K.RandomGrayscale(p=0.15),
        ])

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        if sideY != sideX:
            input = K.RandomAffine(degrees=0, shear=10, p=0.5, padding_mode='border')(input)
        max_size = min(sideX, sideY)
        cutouts = []
        for cn in range(self.cutn):
            if cn > self.cutn - self.cutn // 4:
                cutout = input
            else:
                size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size / max_size), 1.))
                offsetx = torch.randint(0, sideX - size + 1, ())
                offsety = torch.randint(0, sideY - size + 1, ())
                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        cutouts = torch.cat(cutouts)
        cutouts = self.augs(cutouts)
        return cutouts
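# Instead of producing RGB directly, the DIP network outputs a decorrelated color
# representation that is mapped to RGB through a fixed 3x3 matrix (the constants come from
# Lucid's color decorrelation, via the aphantasia repo credited below). Optimizing in this
# whitened basis tends to reduce high-frequency color artifacts; inv_color_scale rescales
# the first color component as an empirical saturation tweak.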
class DecorrelatedColorsToRGB(nn.Module):
    """From https://github.com/eps696/aphantasia."""

    def __init__(self, inv_color_scale=1.):
        super().__init__()
        color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]])
        color_correlation_svd_sqrt /= torch.tensor([inv_color_scale, 1., 1.])  # saturate, empirical
        max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max()
        color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
        self.register_buffer('colcorr_t', color_correlation_normalized.T)

    def inverse(self, image):
        colcorr_t_inv = torch.linalg.inv(self.colcorr_t)
        return torch.einsum('nchw,cd->ndhw', image, colcorr_t_inv)

    def forward(self, image):
        return torch.einsum('nchw,cd->ndhw', image, self.colcorr_t)
class CaptureOutput:
    """Captures a layer's output activations using a forward hook."""

    def __init__(self, module):
        self.output = None
        self.handle = module.register_forward_hook(self)

    def __call__(self, module, input, output):
        self.output = output

    def __del__(self):
        self.handle.remove()

    def get_output(self):
        return self.output
class CLIPActivationLoss(nn.Module):
    """Maximizes or minimizes a single neuron's activations."""

    def __init__(self, module, neuron, class_token=False, maximize=True):
        super().__init__()
        self.capture = CaptureOutput(module)
        self.neuron = neuron
        self.class_token = class_token
        self.maximize = maximize

    def forward(self):
        activations = self.capture.get_output()
        if self.class_token:
            loss = activations[0, :, self.neuron].mean()
        else:
            loss = activations[1:, :, self.neuron].mean()
        return -loss if self.maximize else loss
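# Illustrative sketch (not executed here): the loss object hooks one residual block of the
# CLIP visual transformer; a normal forward pass through encode_image fills the hook, and
# calling the loss afterwards reads the captured activation. The layer/neuron values and
# `some_images` below are placeholders.
#
#   model = clip_models['ViT-B/16']
#   loss_fn = CLIPActivationLoss(model.visual.transformer.resblocks[10], neuron=100,
#                                class_token=False, maximize=True)
#   _ = model.encode_image(clip_normalize(some_images))  # triggers the forward hook
#   loss = loss_fn()  # with maximize=True: negative mean activation of neuron 100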
def optimize_network(
    seed,
    opt_type,
    lr,
    num_iterations,
    cutn,
    layer,
    neuron,
    class_token,
    maximize,
    display_rate,
    video_writer
):
    global itt
    itt = 0

    # if seed is not None:
    #     np.random.seed(seed)
    #     torch.manual_seed(seed)
    #     random.seed(seed)

    save_progress_video = True

    make_cutouts = MakeCutouts(clip_models[clip_model].visual.input_resolution, cutn)
    loss_fn = CLIPActivationLoss(clip_models[clip_model].visual.transformer.resblocks[layer],
                                 neuron, class_token, maximize)

    # Initialize DIP skip network
    input_depth = 32
    net = get_net(
        input_depth, 'skip',
        pad='reflection',
        skip_n33d=128, skip_n33u=128,
        skip_n11=4, num_scales=6,  # If you decrease the output size to 256x256 you might want to use num_scales=6
        upsample_mode='bilinear',
        downsample_mode='lanczos2',
    )

    # Modify DIP to operate in a decorrelated color space
    net = net[:-1]  # remove the sigmoid at the end
    net.add(DecorrelatedColorsToRGB(inv_color_scale))
    net.add(nn.Sigmoid())
    net = net.to(device)

    # Initialize input noise
    net_input = torch.zeros([1, input_depth, sideY, sideX], device=device).normal_().div(10).detach()

    if opt_type == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(), lr)
    elif opt_type == 'MADGRAD':
        optimizer = MADGRAD(net.parameters(), lr, momentum=0.9)

    scaler = torch.cuda.amp.GradScaler()
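    # Mixed-precision loop: forward passes run under autocast, and GradScaler scales the
    # loss before backward so small fp16 gradients do not underflow; scaler.step() unscales
    # the gradients before the optimizer step and scaler.update() adjusts the scale factor.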
    try:
        for _ in range(num_iterations):
            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast():
                out = net(net_input).float()
                cutouts = make_cutouts(out)
                # The forward pass through CLIP triggers the activation hook; the returned
                # embeddings themselves are not used by the loss.
                image_embeds = clip_models[clip_model].encode_image(clip_normalize(cutouts))
                loss = loss_fn()

            # Scaled backward pass and optimizer step.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            itt += 1

            if itt % display_rate == 0 or save_progress_video:
                with torch.inference_mode():
                    image = TF.to_pil_image(out[0].clamp(0, 1))
                    if itt % display_rate == 0:
                        display.clear_output(wait=True)
                        display.display(image)
                        if display_augs:
                            aug_grid = torchvision.utils.make_grid(cutouts, nrow=math.ceil(math.sqrt(cutn)))
                            display.display(TF.to_pil_image(aug_grid.clamp(0, 1)))
                    if save_progress_video:
                        video_writer.append_data(np.asarray(image))

            if anneal_lr:
                optimizer.param_groups[0]['lr'] = max(0.00001, .99 * optimizer.param_groups[0]['lr'])

            print(f'Iteration {itt} of {num_iterations}, loss: {loss.item():g}')
    except KeyboardInterrupt:
        pass

    return TF.to_pil_image(net(net_input)[0])
def inference(
    # seed,
    # opt_type,
    lr,
    num_iterations,
    cutn,
    layer,
    neuron,
    class_token,
    maximize,
    display_rate=20
):
    layer = int(layer)
    cutn = int(cutn)
    num_iterations = int(num_iterations)
    neuron = int(neuron)
    display_rate = int(display_rate)
    opt_type = 'MADGRAD'
    seed = 20
    save_progress_video = True
    timestring = time.strftime('%Y%m%d%H%M%S')

    if save_progress_video:
        video_writer = imageio.get_writer('video.mp4', fps=10)

    # Begin optimization / generation
    gc.collect()
    # torch.cuda.empty_cache()
    out = optimize_network(
        seed,
        opt_type,
        lr,
        num_iterations,
        cutn,
        layer,
        neuron,
        class_token,
        maximize,
        display_rate,
        video_writer=video_writer
    )
    # out.save(f'dip_{timestring}.png', quality=100)

    if save_progress_video:
        video_writer.close()

    return out, 'video.mp4'
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Number(default=1e-3, label="learning rate"),
        gr.inputs.Number(default=50, label="number of iterations (more is better)"),
        gr.inputs.Number(default=32, label="cutn (number of cuts)"),
        gr.inputs.Number(default=10, label="layer"),
        gr.inputs.Number(default=0, label="neuron"),
        gr.inputs.Checkbox(default=False, label="class_token"),
        gr.inputs.Checkbox(default=True, label="maximise"),
        gr.inputs.Slider(minimum=0, maximum=30, default=10, label='display rate'),
    ],
    outputs=["image", "video"],
).launch()
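# Sketch of calling the pipeline without the Gradio UI (hypothetical values for layer/neuron):
# inference() returns the final PIL image and the path of the progress video.
#
#   image, video_path = inference(lr=1e-3, num_iterations=200, cutn=32, layer=10,
#                                 neuron=100, class_token=False, maximize=True)
#   image.save('neuron_vis.png')
#
# Note that the script launches the Gradio interface at import time, so a direct call like
# this would normally live behind a separate entry point.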