"""Gradio demo: generate an image from a text prompt by optimising an ImStack against CLIP."""

import math
from base64 import b64encode

import gradio as gr
import kornia.augmentation as K
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from CLIP import clip
from fastprogress.fastprogress import master_bar, progress_bar
from IPython.display import HTML
from matplotlib import pyplot as plt
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms

# Definitions
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def sinc(x):
    """Return sin(pi*x) / (pi*x) element-wise, with the value 1 where ``x == 0``."""
    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))


def lanczos(x, a):
    """Return a Lanczos window of radius ``a`` sampled at ``x``, normalised to sum to 1."""
    cond = torch.logical_and(-a < x, x < a)
    out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
    return out / out.sum()


def ramp(ratio, width):
    """Return a symmetric 1-D grid of sample positions spaced ``ratio`` apart.

    The grid is built from ``ceil(width / ratio + 1)`` non-negative steps,
    mirrored around zero, with the two outermost samples dropped.
    """
    n = math.ceil(width / ratio + 1)
    out = torch.empty([n])
    cur = 0
    for i in range(out.shape[0]):
        out[i] = cur
        cur += ratio
    return torch.cat([-out[1:].flip([0]), out])[1:-1]


class Prompt(nn.Module):
    """Loss term measuring the distance between image embeddings and one target embedding.

    Args:
        embed: Target embedding tensor (stored as a buffer).
        weight: Scalar weight; its sign flips the distance and its magnitude
            scales the returned loss.
        stop: Lower bound used on the gradient path (see ``forward``).
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        """Return ``|weight|`` times the mean signed distance between ``input`` and ``embed``.

        Both sides are L2-normalised; the distance is ``2 * arcsin(||a - b|| / 2) ** 2``.
        """
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        # Forward value is `dists`; the gradient is routed through max(dists, stop).
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()


class MakeCutouts(nn.Module):
    """Produce ``cutn`` randomly cropped, resized, augmented and noised views of an image.

    Args:
        cut_size: Side length every cutout is resampled to.
        cutn: Number of cutouts taken per call.
        cut_pow: Exponent applied to the uniform random draw that picks each
            cutout's size.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            K.RandomSharpness(0.3, p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.RandomPerspective(0.2, p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
        )
        self.noise_fac = 0.1

    def forward(self, input):
        """Return the augmented cutouts of ``input`` concatenated along dim 0."""
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        batch = self.augs(torch.cat(cutouts, dim=0))
        if self.noise_fac:
            # One noise amplitude per item. Sized from the actual batch so that
            # inputs with more than one image (batch = cutn * n) still broadcast.
            facs = batch.new_empty([batch.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch


def resample(input, size, align_corners=True):
    """Resize an ``(n, c, h, w)`` tensor to ``size`` = ``(dh, dw)``.

    Along each axis that is being reduced, a Lanczos (radius 2) filter is
    applied first with reflect padding; the final resize is bicubic.
    """
    n, c, h, w = input.shape
    dh, dw = size

    # Treat every channel as its own single-channel image for the 1-D filters.
    input = input.view([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    input = input.view([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)


class ReplaceGrad(torch.autograd.Function):
    """Return ``x_forward`` in the forward pass but send the gradient to ``x_backward``."""

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        # No gradient for x_forward; x_backward receives the incoming gradient
        # reduced to its own shape.
        return None, grad_in.sum_to_size(ctx.shape)


replace_grad = ReplaceGrad.apply

# Set up CLIP
perceptor = clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device)
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
cut_size = perceptor.visual.input_resolution
cutn = 64
cut_pow = 1
make_cutouts = MakeCutouts(cut_size, cutn, cut_pow=cut_pow)


# ImStack class
class ImStack(nn.Module):
    """Represent an image as a stack of arrays of increasing resolution.

    Layer ``i`` has side ``base_size * scale**i``, so each layer is ``1/scale``
    the resolution of the next. All layers are upsampled to a common size,
    summed, and passed through a sigmoid. This is useful e.g. when optimising an
    image to minimise some loss: parameters in the early (small) layers can have
    an effect on the overall structure and shapes, while those in later layers
    act as residuals and fill in fine detail.

    The layers are plain tensors held in the list ``self.layers`` (created on the
    module-level ``device``), not registered parameters, so they must be handed
    to an optimiser explicitly.
    """

    def __init__(self, n_layers=3, base_size=32, scale=2, init_image=None, out_size=256, decay=0.7):
        """Construct the image stack.

        Args:
            n_layers: Number of layers in the stack.
            base_size: Side length of the smallest layer.
            scale: Side-length ratio between consecutive layers.
            init_image: Optional PIL image to decompose into the stack; when
                omitted the layers are random normal noise.
            out_size: Side length of the image produced by ``forward``.
            decay: Layer ``i`` is initialised with noise scaled by ``decay**i``.
        """
        super().__init__()
        self.n_layers = n_layers
        self.base_size = base_size
        self.sig = nn.Sigmoid()

        self.layers = []
        for i in range(n_layers):
            side = base_size * (scale ** i)
            tim = torch.randn((3, side, side)).to(device) * (decay ** i)
            self.layers.append(tim)

        # Explicit target sizes (rather than scale_factor) guarantee every layer
        # lands on exactly the same output shape regardless of float rounding.
        self.scalers = [
            nn.Upsample(size=(out_size, out_size), mode='bilinear', align_corners=False)
            for _ in self.layers
        ]
        self.preview_scalers = [
            nn.Upsample(size=(224, 224), mode='bilinear', align_corners=False)
            for _ in self.layers
        ]

        if init_image is not None:  # Given a PIL image, decompose it into a stack
            downscalers = [
                nn.Upsample(size=(l.shape[1], l.shape[2]), mode='bilinear', align_corners=False)
                for l in self.layers
            ]
            pixels = np.array(init_image.convert('RGB').resize((out_size, out_size))) / 255
            # Keep strictly between 0 and 1 so that logit() stays finite, and use
            # float32 to match the dtype of the randomly initialised layers.
            im = torch.tensor(pixels, dtype=torch.float32).clip(1e-3, 1 - 1e-3)
            im = im.permute(2, 0, 1).unsqueeze(0).to(device)

            for layer in self.layers:
                layer *= 0  # Zero out the layers

            # Fit coarse-to-fine: each layer stores, at its own resolution, what
            # the layers filled in so far fail to explain (in logit space).
            for i in range(n_layers):
                out = self.forward()
                residual = torch.logit(im) - torch.logit(out)
                self.layers[i] = downscalers[i](residual).squeeze(0)

        for layer in self.layers:
            layer.requires_grad = True

    def _compose(self, scalers, n_used):
        """Upsample the first ``n_used`` layers (at least one) with ``scalers`` and blend them."""
        im = scalers[0](self.layers[0].unsqueeze(0))
        for i in range(1, n_used):
            im += scalers[i](self.layers[i].unsqueeze(0))
        return self.sig(im)

    @staticmethod
    def _tensor_to_pil(im):
        """Convert a ``(1, 3, H, W)`` tensor with values in [0, 1] to a PIL image."""
        arr = (im.detach().cpu().squeeze().permute([1, 2, 0]) * 255).numpy().astype(np.uint8)
        return Image.fromarray(arr)

    def forward(self):
        """Return the full-resolution image built from every layer."""
        return self._compose(self.scalers, self.n_layers)

    def preview(self, n_preview=2):
        """Return a 224x224 image built from the first ``n_preview`` layers.

        ``n_preview`` is capped at the number of layers in the stack.
        """
        return self._compose(self.preview_scalers, min(n_preview, self.n_layers))

    def to_pil(self):
        """Return the full-resolution image as a PIL image."""
        return self._tensor_to_pil(self.forward())

    def preview_pil(self):
        """Return the default preview as a PIL image."""
        return self._tensor_to_pil(self.preview())

    def save(self, fn):
        """Save the full-resolution image to the path ``fn``."""
        self.to_pil().save(fn)

    def plot_layers(self):
        """Show each layer (after the sigmoid) side by side in a matplotlib figure."""
        # squeeze=False keeps `axs` two-dimensional even when there is one layer.
        fig, axs = plt.subplots(1, self.n_layers, figsize=(15, 5), squeeze=False)
        for ax, layer in zip(axs[0], self.layers):
            im = (self.sig(layer.unsqueeze(0)).detach().cpu().squeeze().permute([1, 2, 0]) * 255)
            ax.imshow(im.numpy().astype(np.uint8))


def generate(text, n_iter):
    """Optimise an ImStack so that CLIP matches it to ``text``; return the image.

    Args:
        text: The text prompt.
        n_iter: Number of optimisation steps. Converted with ``int()`` because
            the Gradio number input may deliver a float, which ``range`` rejects.

    Returns:
        The generated image as a numpy array.
    """
    n_iter = int(n_iter)

    lr = 0.25
    weight_decay = 1e-5
    out_size = 540
    base_size = 20
    n_layers = 4
    scale = 3
    warmup_steps = 15

    p_prompts = []
    embed = perceptor.encode_text(clip.tokenize(text).to(device)).float()
    p_prompts.append(Prompt(embed, 1, float('-inf')).to(device))  # 1 is the weight

    # Some negative prompts
    n_prompts = []
    for pr in ["Random noise", 'saturated rainbow RGB deep dream']:
        embed = perceptor.encode_text(clip.tokenize(pr).to(device)).float()
        n_prompts.append(Prompt(embed, 0.5, float('-inf')).to(device))  # 0.5 is the weight

    # The ImageStack - trying a different scale and n_layers
    ims = ImStack(base_size=base_size, scale=scale, n_layers=n_layers,
                  out_size=out_size, decay=0.4)

    optimizer = optim.Adam(ims.layers, lr=lr, weight_decay=weight_decay)

    for i in range(n_iter):
        optimizer.zero_grad()

        if i < warmup_steps:
            # Save time by skipping the cutouts and focusing on the lowest layer.
            im = ims.preview(n_preview=1)
            iii = perceptor.encode_image(normalize(im)).float()
        else:
            im = ims()
            iii = perceptor.encode_image(normalize(make_cutouts(im))).float()

        l = 0
        for prompt in p_prompts:
            l += prompt(iii)
        for prompt in n_prompts:
            l -= prompt(iii)

        l.backward()  # Backprop
        optimizer.step()  # Update

    return np.array(ims.to_pil())


iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.inputs.Textbox(label="Text Input"),
        gr.inputs.Number(default=42, label="N Steps"),
    ],
    outputs=[
        gr.outputs.Image(type="numpy", label="Output Image"),
    ],
)
iface.launch()