#@title Gradio demo (used in space: )
from matplotlib import pyplot as plt
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import gradio as gr

#@title Defining Generator and associated code ourselves without the GPU requirements
import os
import json
import multiprocessing
from random import random
import math
from math import log2, floor
from functools import partial
from contextlib import contextmanager, ExitStack
from pathlib import Path
from shutil import rmtree

import torch
from torch.cuda.amp import autocast, GradScaler
from torch.optim import Adam
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import grad as torch_grad
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from PIL import Image
import torchvision
from torchvision import transforms
from kornia.filters import filter2d

from tqdm import tqdm
from einops import rearrange, reduce, repeat

from adabelief_pytorch import AdaBelief

# helpers

def DiffAugment(x, types=[]):
    # AUGMENT_FNS is defined in the full lightweight_gan package; it is only needed
    # during training, which this inference-only demo never performs.
    for p in types:
        for f in AUGMENT_FNS[p]:
            x = f(x)
    return x.contiguous()

def exists(val):
    return val is not None

@contextmanager
def null_context():
    yield

def combine_contexts(contexts):
    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]
    return multi_contexts

def is_power_of_two(val):
    return log2(val).is_integer()

def default(val, d):
    return val if exists(val) else d

def set_requires_grad(model, bool):
    for p in model.parameters():
        p.requires_grad = bool

def cycle(iterable):
    while True:
        for i in iterable:
            yield i

def raise_if_nan(t):
    if torch.isnan(t):
        raise NanException

def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
    if is_ddp:
        num_no_syncs = gradient_accumulate_every - 1
        head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
        tail = [null_context]
        contexts = head + tail
    else:
        contexts = [null_context] * gradient_accumulate_every

    for context in contexts:
        with context():
            yield

def evaluate_in_chunks(max_batch_size, model, *args):
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)

def slerp(val, low, high):
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res

def safe_div(n, d):
    try:
        res = n / d
    except ZeroDivisionError:
        prefix = '' if int(n >= 0) else '-'
        res = float(f'{prefix}inf')
    return res

# loss functions

def gen_hinge_loss(fake, real):
    return fake.mean()

def hinge_loss(real, fake):
    return (F.relu(1 + real) + F.relu(1 - fake)).mean()
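# Illustrative sketch (not part of the original demo): how the hinge losses above would be
# applied to discriminator logits during training. Per the formulas, the discriminator is
# pushed to score real images below -1 and fakes above +1, while the generator simply
# minimises the fake logits. The tensors below are dummy values.
def _hinge_loss_example():
    real_logits = torch.randn(8)
    fake_logits = torch.randn(8)
    d_loss = hinge_loss(real_logits, fake_logits)      # discriminator objective
    g_loss = gen_hinge_loss(fake_logits, real_logits)  # generator objective
    return d_loss, g_loss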
def dual_contrastive_loss(real_logits, fake_logits):
    device = real_logits.device
    real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))

    def loss_half(t1, t2):
        t1 = rearrange(t1, 'i -> i ()')
        t2 = repeat(t2, 'j -> i j', i = t1.shape[0])
        t = torch.cat((t1, t2), dim = -1)
        return F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long))

    return loss_half(real_logits, fake_logits) + loss_half(-fake_logits, -real_logits)

# helper classes

class NanException(Exception):
    pass

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if not exists(old):
            return new
        return old * self.beta + (1 - self.beta) * new

class RandomApply(nn.Module):
    def __init__(self, prob, fn, fn_else = lambda x: x):
        super().__init__()
        self.fn = fn
        self.fn_else = fn_else
        self.prob = prob

    def forward(self, x):
        fn = self.fn if random() < self.prob else self.fn_else
        return fn(x)

class ChanNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

class SumBranches(nn.Module):
    def __init__(self, branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        return sum(map(lambda fn: fn(x), self.branches))

class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2d(x, f, normalized=True)

class Noise(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise = None):
        b, _, h, w, device = *x.shape, x.device

        if not exists(noise):
            noise = torch.randn(b, 1, h, w, device = device)

        return x + self.weight * noise

def Conv2dSame(dim_in, dim_out, kernel_size, bias = True):
    pad_left = kernel_size // 2
    pad_right = (pad_left - 1) if (kernel_size % 2) == 0 else pad_left

    return nn.Sequential(
        nn.ZeroPad2d((pad_left, pad_right, pad_left, pad_right)),
        nn.Conv2d(dim_in, dim_out, kernel_size, bias = bias)
    )
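# Illustrative check (assumption, not part of the original script): Conv2dSame pads
# asymmetrically when the kernel size is even, so spatial dimensions are preserved.
# The generator below relies on this to use a kernel size of 4 without shrinking
# its feature maps.
def _conv2d_same_shape_check():
    conv = Conv2dSame(8, 16, kernel_size = 4)
    fmap = torch.randn(1, 8, 32, 32)
    out = conv(fmap)
    assert out.shape == (1, 16, 32, 32)
    return out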
# attention

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding = 0, stride = 1, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )

    def forward(self, x):
        return self.net(x)

class LinearAttention(nn.Module):
    def __init__(self, dim, dim_head = 64, heads = 8, kernel_size = 3):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads

        self.kernel_size = kernel_size
        self.nonlin = nn.GELU()

        self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)

        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)

        self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)

    def forward(self, fmap):
        h, x, y = self.heads, *fmap.shape[-2:]

        # linear attention

        lin_q, lin_k, lin_v = (self.to_lin_q(fmap), *self.to_lin_kv(fmap).chunk(2, dim = 1))
        lin_q, lin_k, lin_v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (lin_q, lin_k, lin_v))

        lin_q = lin_q.softmax(dim = -1)
        lin_k = lin_k.softmax(dim = -2)

        lin_q = lin_q * self.scale

        context = einsum('b n d, b n e -> b d e', lin_k, lin_v)
        lin_out = einsum('b n d, b d e -> b n e', lin_q, context)
        lin_out = rearrange(lin_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        # conv-like full attention

        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) c x y', h = h), (q, k, v))

        k = F.unfold(k, kernel_size = self.kernel_size, padding = self.kernel_size // 2)
        v = F.unfold(v, kernel_size = self.kernel_size, padding = self.kernel_size // 2)

        k, v = map(lambda t: rearrange(t, 'b (d j) n -> b n j d', d = self.dim_head), (k, v))

        q = rearrange(q, 'b c ... -> b (...) c') * self.scale

        sim = einsum('b i d, b i j d -> b i j', q, k)
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()

        attn = sim.softmax(dim = -1)

        full_out = einsum('b i j, b i j d -> b i d', attn, v)
        full_out = rearrange(full_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        # add outputs of linear attention + conv like full attention

        lin_out = self.nonlin(lin_out)
        out = torch.cat((lin_out, full_out), dim = 1)
        return self.to_out(out)

# dataset

def convert_image_to(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

class identity(object):
    def __call__(self, tensor):
        return tensor

class expand_greyscale(object):
    def __init__(self, transparent):
        self.transparent = transparent

    def __call__(self, tensor):
        channels = tensor.shape[0]
        num_target_channels = 4 if self.transparent else 3

        if channels == num_target_channels:
            return tensor

        alpha = None
        if channels == 1:
            color = tensor.expand(3, -1, -1)
        elif channels == 2:
            color = tensor[:1].expand(3, -1, -1)
            alpha = tensor[1:]
        else:
            raise Exception(f'image with invalid number of channels given {channels}')

        if not exists(alpha) and self.transparent:
            alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)

        return color if not self.transparent else torch.cat((color, alpha))

def resize_to_minimum_size(min_size, image):
    if max(*image.size) < min_size:
        return torchvision.transforms.functional.resize(image, min_size)
    return image
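# Illustrative shape check (assumption, not in the original script): LinearAttention keeps
# the (b, dim, h, w) shape of its input, combining a global linear-attention branch with a
# local 3x3 full-attention branch, so the generator can wrap it in PreNorm and add it residually.
def _linear_attention_shape_check():
    attn = PreNorm(32, LinearAttention(32))
    fmap = torch.randn(1, 32, 16, 16)
    out = attn(fmap) + fmap
    assert out.shape == fmap.shape
    return out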
# Supported image extensions. EXTS is referenced below but was not defined in this snippet;
# these are the defaults from the lightweight_gan package.
EXTS = ['jpg', 'jpeg', 'png']

class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        transparent = False,
        greyscale = False,
        aug_prob = 0.
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')]
        assert len(self.paths) > 0, f'No images were found in {folder} for training'

        if transparent:
            num_channels = 4
            pillow_mode = 'RGBA'
            expand_fn = expand_greyscale(transparent)
        elif greyscale:
            num_channels = 1
            pillow_mode = 'L'
            expand_fn = identity()
        else:
            num_channels = 3
            pillow_mode = 'RGB'
            expand_fn = expand_greyscale(transparent)

        convert_image_fn = partial(convert_image_to, pillow_mode)

        self.transform = transforms.Compose([
            transforms.Lambda(convert_image_fn),
            transforms.Lambda(partial(resize_to_minimum_size, image_size)),
            transforms.Resize(image_size),
            RandomApply(aug_prob, transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)), transforms.CenterCrop(image_size)),
            transforms.ToTensor(),
            transforms.Lambda(expand_fn)
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# augmentations

def random_hflip(tensor, prob):
    if prob > random():
        return tensor
    return torch.flip(tensor, dims=(3,))

class AugWrapper(nn.Module):
    def __init__(self, D, image_size):
        super().__init__()
        self.D = D

    def forward(self, images, prob = 0., types = [], detach = False, **kwargs):
        context = torch.no_grad if detach else null_context

        with context():
            if random() < prob:
                images = random_hflip(images, prob=0.5)
                images = DiffAugment(images, types=types)

        return self.D(images, **kwargs)

# modifiable global variables

norm_class = nn.BatchNorm2d

def upsample(scale_factor = 2):
    return nn.Upsample(scale_factor = scale_factor)

# squeeze excitation classes

# global context network
# https://arxiv.org/abs/2012.13375
# similar to squeeze-excite, but with a simplified attention pooling and a subsequent layer norm

class GlobalContext(nn.Module):
    def __init__(
        self,
        *,
        chan_in,
        chan_out
    ):
        super().__init__()
        self.to_k = nn.Conv2d(chan_in, 1, 1)
        chan_intermediate = max(3, chan_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, chan_intermediate, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_intermediate, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        context = self.to_k(x)
        context = context.flatten(2).softmax(dim = -1)
        out = einsum('b i n, b c n -> b c i', context, x.flatten(2))
        out = out.unsqueeze(-1)
        return self.net(out)

# frequency channel attention
# https://arxiv.org/abs/2012.11879

def get_1d_dct(i, freq, L):
    result = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
    return result * (1 if freq == 0 else math.sqrt(2))

def get_dct_weights(width, channel, fidx_u, fidx_v):
    dct_weights = torch.zeros(1, channel, width, width)
    c_part = channel // len(fidx_u)

    for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
        for x in range(width):
            for y in range(width):
                coor_value = get_1d_dct(x, u_x, width) * get_1d_dct(y, v_y, width)
                dct_weights[:, i * c_part: (i + 1) * c_part, x, y] = coor_value

    return dct_weights
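# Illustrative shape check (assumption, not part of the original script): the attention
# pooling above collapses the spatial dimensions into a single context vector, so
# GlobalContext maps a (b, chan_in, h, w) feature map to (b, chan_out, 1, 1) gates in (0, 1),
# which the generator later multiplies into a higher-resolution feature map.
def _global_context_shape_check():
    excite = GlobalContext(chan_in = 64, chan_out = 128)
    fmap = torch.randn(2, 64, 16, 16)
    gates = excite(fmap)
    assert gates.shape == (2, 128, 1, 1)
    return gates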
class FCANet(nn.Module):
    def __init__(
        self,
        *,
        chan_in,
        chan_out,
        reduction = 4,
        width
    ):
        super().__init__()

        freq_w, freq_h = ([0] * 8), list(range(8))  # in paper, it seems 16 frequencies was ideal
        dct_weights = get_dct_weights(width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w])
        self.register_buffer('dct_weights', dct_weights)

        chan_intermediate = max(3, chan_out // reduction)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, chan_intermediate, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_intermediate, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = reduce(x * self.dct_weights, 'b c (h h1) (w w1) -> b c h1 w1', 'sum', h1 = 1, w1 = 1)
        return self.net(x)

# generative adversarial network

class Generator(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        latent_dim = 256,
        fmap_max = 512,
        fmap_inverse_coef = 12,
        transparent = False,
        greyscale = False,
        attn_res_layers = [],
        freq_chan_attn = False
    ):
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        fmap_max = default(fmap_max, latent_dim)

        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim = 1)
        )

        num_layers = int(resolution) - 2
        features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2)))
        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
        features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
        features = [latent_dim, *features]

        in_out_features = list(zip(features[:-1], features[1:]))

        self.res_layers = range(2, num_layers + 2)
        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
        self.sle_map = list(filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map))
        self.sle_map = dict(self.sle_map)

        self.num_layers_spatial_res = 1

        for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
            image_width = 2 ** res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

                if freq_chan_attn:
                    sle = FCANet(
                        chan_in = chan_out,
                        chan_out = sle_chan_out,
                        width = 2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(
                        chan_in = chan_out,
                        chan_out = sle_chan_out
                    )

            layer = nn.ModuleList([
                nn.Sequential(
                    upsample(),
                    Blur(),
                    Conv2dSame(chan_in, chan_out * 2, 4),
                    Noise(),
                    norm_class(chan_out * 2),
                    nn.GLU(dim = 1)
                ),
                sle,
                attn
            ])
            self.layers.append(layer)

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding = 1)

    def forward(self, x):
        x = rearrange(x, 'b c -> b c () ()')
        x = self.initial_conv(x)
        x = F.normalize(x, dim = 1)

        residuals = dict()

        for (res, (up, sle, attn)) in zip(self.res_layers, self.layers):
            if exists(attn):
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                out_res = self.sle_map[res]
                residual = sle(x)
                residuals[out_res] = residual

            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)

# Initialize a generator model
gan_new = Generator(latent_dim=256, image_size=256, attn_res_layers = [32])

# Load from local saved state dict
# gan_new.load_state_dict(torch.load('/content/orbgan_e3_state_dict.pt'))

# Load from model hub:
class GeneratorWithPyTorchModelHubMixin(gan_new.__class__, PyTorchModelHubMixin):
    pass

gan_new.__class__ = GeneratorWithPyTorchModelHubMixin
gan_new = gan_new.from_pretrained('johnowhitaker/colorb_gan', latent_dim=256, image_size=256, attn_res_layers = [32])
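# Optional illustration (not used by the Gradio app): the slerp helper defined above can
# interpolate between two latent vectors of the loaded generator. The 256-dim latent size
# matches the model configured above; n_steps is an arbitrary choice.
def interpolate_latents(n_steps = 8):
    z1, z2 = torch.randn(1, 256), torch.randn(1, 256)
    with torch.no_grad():
        frames = [gan_new(slerp(t, z1, z2)).clamp_(0., 1.) for t in torch.linspace(0., 1., n_steps)]
    return torch.cat(frames, dim = 0)  # (n_steps, 3, 256, 256)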
def gen_ims(n_rows):
    ims = gan_new(torch.randn(int(n_rows)**2, 256)).clamp_(0., 1.)
    grid = torchvision.utils.make_grid(ims, nrow=int(n_rows)).permute(1, 2, 0).detach().cpu().numpy()
    return (grid*255).astype(np.uint8)

iface = gr.Interface(
    fn=gen_ims,
    inputs=[gr.inputs.Slider(minimum=1, maximum=6, step=1, default=3, label="N rows")],
    outputs=[gr.outputs.Image(type="numpy", label="Generated Images")],
    title='Demo for Colorbgan model',
    article='A lightweight GAN trained on johnowhitaker/colorbs. See https://huggingface.co/johnowhitaker/orbgan_e1 for training and inference scripts.'
)
iface.launch()
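# Optional local usage (assumption, not part of the Space): gen_ims can also be called
# directly, returning an (H, W, 3) uint8 numpy array that PIL can save. Commented out so
# nothing runs after the blocking launch() call above.
# Image.fromarray(gen_ims(2)).save('samples.png')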