#@title Gradio demo (used in space: )
from matplotlib import pyplot as plt
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import gradio as gr

### A BIG CHUNK OF THIS IS COPIED FROM LIGHTWEIGHTGAN since the original has an assert requiring GPU

import os
import json
import multiprocessing
from random import random
import math
from math import log2, floor
from functools import partial
from contextlib import contextmanager, ExitStack
from pathlib import Path
from shutil import rmtree

import torch
from torch.cuda.amp import autocast, GradScaler
from torch.optim import Adam
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import grad as torch_grad
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from PIL import Image
import torchvision
from torchvision import transforms
from kornia.filters import filter2d

from tqdm import tqdm
from einops import rearrange, reduce, repeat

from adabelief_pytorch import AdaBelief

# helpers

def DiffAugment(x, types=[]):
    # NOTE: AUGMENT_FNS is defined in the full lightweight_gan source; it was
    # not copied over, but this demo never calls DiffAugment so it is harmless.
    for p in types:
        for f in AUGMENT_FNS[p]:
            x = f(x)
    return x.contiguous()

@contextmanager
def null_context():
    yield

def combine_contexts(contexts):
    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]
    return multi_contexts

def exists(val):
    return val is not None

def is_power_of_two(val):
    return log2(val).is_integer()

def default(val, d):
    return val if exists(val) else d

def set_requires_grad(model, bool):
    for p in model.parameters():
        p.requires_grad = bool

def cycle(iterable):
    while True:
        for i in iterable:
            yield i

def raise_if_nan(t):
    if torch.isnan(t):
        raise NanException

def evaluate_in_chunks(max_batch_size, model, *args):
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)

def slerp(val, low, high):
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res

def safe_div(n, d):
    try:
        res = n / d
    except ZeroDivisionError:
        prefix = '' if int(n >= 0) else '-'
        res = float(f'{prefix}inf')
    return res

# loss functions

def gen_hinge_loss(fake, real):
    return fake.mean()

def hinge_loss(real, fake):
    return (F.relu(1 + real) + F.relu(1 - fake)).mean()
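# Hedged usage sketch (added for illustration, not part of the original
# notebook): `slerp` above spherically interpolates between latent batches,
# which gives smoother walks through GAN latent space than linear mixing.
# Commented out so the demo's behaviour is unchanged:
#
#   z1, z2 = torch.randn(1, 256), torch.randn(1, 256)
#   path = [slerp(t, z1, z2) for t in torch.linspace(0., 1., steps=8)]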
def dual_contrastive_loss(real_logits, fake_logits):
    device = real_logits.device
    real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))

    def loss_half(t1, t2):
        t1 = rearrange(t1, 'i -> i ()')
        t2 = repeat(t2, 'j -> i j', i = t1.shape[0])
        t = torch.cat((t1, t2), dim = -1)
        return F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long))

    return loss_half(real_logits, fake_logits) + loss_half(-fake_logits, -real_logits)

# helper classes

class NanException(Exception):
    pass

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if not exists(old):
            return new
        return old * self.beta + (1 - self.beta) * new

class RandomApply(nn.Module):
    def __init__(self, prob, fn, fn_else = lambda x: x):
        super().__init__()
        self.fn = fn
        self.fn_else = fn_else
        self.prob = prob

    def forward(self, x):
        fn = self.fn if random() < self.prob else self.fn_else
        return fn(x)

class ChanNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

class SumBranches(nn.Module):
    def __init__(self, branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        return sum(map(lambda fn: fn(x), self.branches))

class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2d(x, f, normalized=True)

# attention

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding = 0, stride = 1, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )

    def forward(self, x):
        return self.net(x)

class LinearAttention(nn.Module):
    def __init__(self, dim, dim_head = 64, heads = 8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.nonlin = nn.GELU()
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, fmap):
        h, x, y = self.heads, *fmap.shape[-2:]
        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)
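# Hedged sanity check (added for illustration, not in the original notebook):
# LinearAttention preserves the input feature-map shape, so it can be wrapped
# in PreNorm and dropped into the generator at any resolution listed in
# attn_res_layers. Commented out so nothing extra runs:
#
#   attn = PreNorm(64, LinearAttention(64))
#   fmap = torch.randn(1, 64, 32, 32)
#   assert attn(fmap).shape == fmap.shape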
# global context network
# https://arxiv.org/abs/2012.13375
# similar to squeeze-excite, but with a simplified attention pooling and a subsequent layer norm

class GlobalContext(nn.Module):
    def __init__(
        self,
        *,
        chan_in,
        chan_out
    ):
        super().__init__()
        self.to_k = nn.Conv2d(chan_in, 1, 1)
        chan_intermediate = max(3, chan_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, chan_intermediate, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_intermediate, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        context = self.to_k(x)
        context = context.flatten(2).softmax(dim = -1)
        out = einsum('b i n, b c n -> b c i', context, x.flatten(2))
        out = out.unsqueeze(-1)
        return self.net(out)

# dataset

def convert_image_to(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

class identity(object):
    def __call__(self, tensor):
        return tensor

class expand_greyscale(object):
    def __init__(self, transparent):
        self.transparent = transparent

    def __call__(self, tensor):
        channels = tensor.shape[0]
        num_target_channels = 4 if self.transparent else 3

        if channels == num_target_channels:
            return tensor

        alpha = None
        if channels == 1:
            color = tensor.expand(3, -1, -1)
        elif channels == 2:
            color = tensor[:1].expand(3, -1, -1)
            alpha = tensor[1:]
        else:
            raise Exception(f'image with invalid number of channels given {channels}')

        if not exists(alpha) and self.transparent:
            alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)

        return color if not self.transparent else torch.cat((color, alpha))

class FCANet(nn.Module):
    def __init__(
        self,
        *,
        chan_in,
        chan_out,
        reduction = 4,
        width
    ):
        super().__init__()

        # NOTE: get_dct_weights is defined in the full lightweight_gan source and
        # was not copied over, so FCANet only works if that helper is added; this
        # demo never enables freq_chan_attn, so FCANet is never constructed.
        freq_w, freq_h = ([0] * 8), list(range(8))  # in paper, it seems 16 frequencies was ideal
        dct_weights = get_dct_weights(width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w])
        self.register_buffer('dct_weights', dct_weights)

        chan_intermediate = max(3, chan_out // reduction)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, chan_intermediate, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_intermediate, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = reduce(x * self.dct_weights, 'b c (h h1) (w w1) -> b c h1 w1', 'sum', h1 = 1, w1 = 1)
        return self.net(x)

# modifiable global variables

norm_class = nn.BatchNorm2d

def upsample(scale_factor = 2):
    return nn.Upsample(scale_factor = scale_factor)

# generative adversarial network

class Generator(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        latent_dim = 256,
        fmap_max = 512,
        fmap_inverse_coef = 12,
        transparent = False,
        greyscale = False,
        attn_res_layers = [],
        freq_chan_attn = False
    ):
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        fmap_max = default(fmap_max, latent_dim)

        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim = 1)
        )

        num_layers = int(resolution) - 2
        features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2)))
        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
        features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
        features = [latent_dim, *features]

        in_out_features = list(zip(features[:-1], features[1:]))

        self.res_layers = range(2, num_layers + 2)
        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
        self.sle_map = list(filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map))
        self.sle_map = dict(self.sle_map)

        self.num_layers_spatial_res = 1

        for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
            image_width = 2 ** res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

                if freq_chan_attn:
                    sle = FCANet(
                        chan_in = chan_out,
                        chan_out = sle_chan_out,
                        width = 2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(
                        chan_in = chan_out,
                        chan_out = sle_chan_out
                    )

            layer = nn.ModuleList([
                nn.Sequential(
                    upsample(),
                    Blur(),
                    nn.Conv2d(chan_in, chan_out * 2, 3, padding = 1),
                    norm_class(chan_out * 2),
                    nn.GLU(dim = 1)
                ),
                sle,
                attn
            ])
            self.layers.append(layer)

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding = 1)

    def forward(self, x):
        x = rearrange(x, 'b c -> b c () ()')
        x = self.initial_conv(x)
        x = F.normalize(x, dim = 1)

        residuals = dict()

        for (res, (up, sle, attn)) in zip(self.res_layers, self.layers):
            if exists(attn):
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                out_res = self.sle_map[res]
                residual = sle(x)
                residuals[out_res] = residual

            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)
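# Hedged shape check (added for illustration, not in the original notebook):
# with image_size = 256 the generator stacks log2(256) - 2 = 6 upsampling
# layers (4x4 -> 256x256) and maps a (batch, latent_dim) vector to a
# (batch, 3, 256, 256) image. Commented out so nothing extra runs:
#
#   g = Generator(latent_dim = 256, image_size = 256, attn_res_layers = [32])
#   assert g(torch.randn(2, 256)).shape == (2, 3, 256, 256)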
#### ACTUALLY LOAD THE MODEL AND DEFINE THE INTERFACE

# Initialize a generator model
gan_new = Generator(latent_dim=256, image_size=256, attn_res_layers = [32])

# Load from local saved state dict
# gan_new.load_state_dict(torch.load('/content/orbgan_e3_state_dict.pt'))

# Load from model hub:
class GeneratorWithPyTorchModelHubMixin(gan_new.__class__, PyTorchModelHubMixin):
    pass

gan_new.__class__ = GeneratorWithPyTorchModelHubMixin
gan_new = gan_new.from_pretrained('johnowhitaker/orbgan_e1', latent_dim=256, image_size=256, attn_res_layers = [32])

gan_light = Generator(latent_dim=256, image_size=256, attn_res_layers = [32])
gan_light.__class__ = GeneratorWithPyTorchModelHubMixin
gan_light = gan_light.from_pretrained('johnowhitaker/orbgan_light', latent_dim=256, image_size=256, attn_res_layers = [32])

gan_dark = Generator(latent_dim=256, image_size=256, attn_res_layers = [32])
gan_dark.__class__ = GeneratorWithPyTorchModelHubMixin
gan_dark = gan_dark.from_pretrained('johnowhitaker/orbgan_dark', latent_dim=256, image_size=256, attn_res_layers = [32])

def gen_ims(n_rows, model='both'):
    # Inference only, so skip autograd bookkeeping; the elif/else chain also
    # guards against an unbound `ims` if an unexpected model name slips through.
    with torch.no_grad():
        if model == "both":
            ims = gan_new(torch.randn(int(n_rows)**2, 256)).clamp_(0., 1.)
        elif model == "light":
            ims = gan_light(torch.randn(int(n_rows)**2, 256)).clamp_(0., 1.)
        elif model == "dark":
            ims = gan_dark(torch.randn(int(n_rows)**2, 256)).clamp_(0., 1.)
        else:
            raise ValueError(f'unknown model choice: {model}')
    grid = torchvision.utils.make_grid(ims, nrow=int(n_rows)).permute(1, 2, 0).detach().cpu().numpy()
    return (grid*255).astype(np.uint8)

# NB: this uses the Gradio 2.x `gr.inputs` / `gr.outputs` API
iface = gr.Interface(
    fn=gen_ims,
    inputs=[
        gr.inputs.Slider(minimum=1, maximum=6, step=1, default=3, label="N rows"),
        gr.inputs.Dropdown(["both", "light", "dark"], type="value", default="dark", label="Orb Type (model)", optional=False)
    ],
    outputs=[gr.outputs.Image(type="numpy", label="Generated Images")],
    title='Demo for orbgan models',
    article='Models are lightweight-gans trained on johnowhitaker/glid3_orbs. See https://huggingface.co/johnowhitaker/orbgan_e1 for training and inference scripts'
)
iface.launch()
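# Hedged usage sketch (added for illustration, not in the original notebook):
# the same function can render a grid without the Gradio UI, e.g. for saving
# samples to disk. Commented out:
#
#   img = gen_ims(3, model="dark")          # (H, W, 3) uint8 numpy array
#   Image.fromarray(img).save("orbs.png")   # PIL.Image is already imported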