# Scraped Hugging Face Spaces page header (Space status at capture time: Runtime error).
#@title Gradio demo (used in space: ) | |
from matplotlib import pyplot as plt | |
from huggingface_hub import PyTorchModelHubMixin | |
import numpy as np | |
import gradio as gr | |
#@title Defining Generator and associated code ourselves without the GPU requirements | |
import os | |
import json | |
import multiprocessing | |
from random import random | |
import math | |
from math import log2, floor | |
from functools import partial | |
from contextlib import contextmanager, ExitStack | |
from pathlib import Path | |
from shutil import rmtree | |
import torch | |
from torch.cuda.amp import autocast, GradScaler | |
from torch.optim import Adam | |
from torch import nn, einsum | |
import torch.nn.functional as F | |
from torch.utils.data import Dataset, DataLoader | |
from torch.autograd import grad as torch_grad | |
from torch.utils.data.distributed import DistributedSampler | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
from PIL import Image | |
import torchvision | |
from torchvision import transforms | |
from kornia.filters import filter2d | |
from tqdm import tqdm | |
from einops import rearrange, reduce, repeat | |
from adabelief_pytorch import AdaBelief | |
# helpers | |
def DiffAugment(x, types=()):
    """Apply the augmentations named in ``types`` to ``x``, in order.

    Each entry of ``types`` is used as a key into the module-level
    ``AUGMENT_FNS`` mapping, and every function stored under that key is
    applied to ``x`` in sequence.

    NOTE(review): ``AUGMENT_FNS`` is not defined in the visible part of this
    file; it must exist before this is called with a non-empty ``types``.

    Args:
        x: tensor to augment.
        types: iterable of augmentation policy names (default: none).

    Returns:
        The augmented tensor, made contiguous.
    """
    # The default is an immutable tuple so no state is shared between calls.
    for policy in types:
        for augment in AUGMENT_FNS[policy]:
            x = augment(x)
    return x.contiguous()
def exists(val):
    """Return ``True`` for any value other than ``None``."""
    if val is None:
        return False
    return True
@contextmanager
def null_context():
    """Context manager that does nothing on entry or exit.

    Used as a stand-in wherever a real context (e.g. ``torch.no_grad``) is
    optional. Without the ``@contextmanager`` decorator this would be a bare
    generator function and ``with null_context():`` would fail.
    """
    yield
def combine_contexts(contexts):
    """Merge several context-manager factories into a single factory.

    Args:
        contexts: iterable of zero-argument callables, each returning a
            context manager.

    Returns:
        A zero-argument callable producing a context manager that enters all
        of ``contexts`` in order (exiting in reverse) and yields the list of
        values returned by their ``__enter__`` calls.
    """
    # Materialise once: callers may pass a one-shot iterator such as ``map``,
    # and the returned factory has to be usable more than once.
    factories = list(contexts)

    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(make_ctx()) for make_ctx in factories]

    return multi_contexts
def is_power_of_two(val):
    """Return whether the base-2 logarithm of ``val`` is a whole number."""
    exponent = log2(val)
    return exponent.is_integer()
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return ``d``."""
    if val is None:
        return d
    return val
def set_requires_grad(model, bool):
    """Set ``requires_grad`` to ``bool`` on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = bool
def cycle(iterable):
    """Yield the items of ``iterable`` over and over, re-iterating it each pass."""
    while True:
        yield from iterable
def raise_if_nan(t):
    """Raise ``NanException`` when ``torch.isnan(t)`` is truthy."""
    found_nan = torch.isnan(t)
    if found_nan:
        raise NanException
def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
    """Yield once per gradient-accumulation step, inside the right context.

    With ``is_ddp`` set, every step except the last runs with each module's
    ``no_sync()`` context entered, and the final step runs with no context.
    Without ``is_ddp``, every step runs with no context.

    The previous implementation repeated a single factory built from a
    one-shot ``map`` iterator, so only the first step actually entered the
    ``no_sync`` contexts. The factories are now materialised once and
    re-entered on each step.

    Args:
        gradient_accumulate_every: number of accumulation steps.
        is_ddp: whether the ``no_sync`` contexts should be used.
        ddps: iterable of objects exposing a ``no_sync()`` context manager.

    Yields:
        ``None`` once per step; the caller runs one forward/backward pass per
        iteration.
    """
    if is_ddp:
        no_sync_factories = [ddp.no_sync for ddp in ddps]
        # Matches the original list arithmetic: a non-positive step count
        # still yields one final step with no context.
        num_no_syncs = max(gradient_accumulate_every - 1, 0)
        total_steps = num_no_syncs + 1
    else:
        no_sync_factories = []
        num_no_syncs = 0
        total_steps = gradient_accumulate_every

    for step in range(total_steps):
        with ExitStack() as stack:
            if step < num_no_syncs:
                for no_sync in no_sync_factories:
                    stack.enter_context(no_sync())
            yield
def evaluate_in_chunks(max_batch_size, model, *args):
    """Run ``model`` over ``args`` in batch-dimension chunks.

    Every argument is split along dim 0 into pieces of at most
    ``max_batch_size``; ``model`` is called once per group of aligned pieces
    and the outputs are concatenated along dim 0. A single chunk's output is
    returned as-is, without concatenation.
    """
    pieces_per_arg = [arg.split(max_batch_size, dim=0) for arg in args]
    outputs = [model(*group) for group in zip(*pieces_per_arg)]
    if len(outputs) == 1:
        return outputs[0]
    return torch.cat(outputs, dim=0)
def slerp(val, low, high):
    """Spherically interpolate between ``low`` and ``high`` by ``val``.

    The angle between corresponding rows is measured on their unit-normalised
    versions (norm taken over dim 1); the interpolation weights are then
    applied to the original, un-normalised rows.
    """
    low_unit = low / torch.norm(low, dim=1, keepdim=True)
    high_unit = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_unit * high_unit).sum(1))
    sin_omega = torch.sin(omega)
    low_weight = (torch.sin((1.0 - val) * omega) / sin_omega).unsqueeze(1)
    high_weight = (torch.sin(val * omega) / sin_omega).unsqueeze(1)
    return low_weight * low + high_weight * high
def safe_div(n, d):
    """Return ``n / d``, mapping ``ZeroDivisionError`` to a signed infinity.

    On division by zero the result is ``inf`` when ``n >= 0`` and ``-inf``
    otherwise.
    """
    try:
        return n / d
    except ZeroDivisionError:
        return float('inf') if n >= 0 else float('-inf')
# loss functions | |
def gen_hinge_loss(fake, real):
    """Generator loss: the mean of ``fake``. ``real`` is accepted but unused."""
    loss = fake.mean()
    return loss
def hinge_loss(real, fake):
    """Return the mean of ``relu(1 + real) + relu(1 - fake)``."""
    real_term = F.relu(1 + real)
    fake_term = F.relu(1 - fake)
    return (real_term + fake_term).mean()
def dual_contrastive_loss(real_logits, fake_logits):
    """Sum of two symmetric cross-entropy terms over flattened logits.

    Both inputs are flattened to 1-D. Each half builds, for every element of
    its first argument, a row holding that element followed by all elements
    of the second argument, and takes the cross-entropy with target class 0.
    The second half uses the negated fake logits against the negated real
    logits.
    """
    device = real_logits.device
    real_flat = real_logits.reshape(-1)
    fake_flat = fake_logits.reshape(-1)

    def loss_half(anchors, others):
        count = anchors.shape[0]
        anchor_col = anchors.unsqueeze(-1)
        other_rows = others.unsqueeze(0).expand(count, -1)
        logits = torch.cat((anchor_col, other_rows), dim = -1)
        targets = torch.zeros(count, device = device, dtype = torch.long)
        return F.cross_entropy(logits, targets)

    return loss_half(real_flat, fake_flat) + loss_half(-fake_flat, -real_flat)
# helper classes | |
class NanException(Exception):
    """Raised by ``raise_if_nan`` when a NaN value is detected."""
class EMA():
    """Exponential moving average helper with decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Blend ``old`` and ``new``; return ``new`` when ``old`` is ``None``."""
        if old is None:
            return new
        decay = self.beta
        return old * decay + (1 - decay) * new
class RandomApply(nn.Module):
    """Apply ``fn`` when ``random() < prob``, otherwise ``fn_else``.

    ``fn_else`` defaults to the identity function.
    """

    def __init__(self, prob, fn, fn_else = lambda x: x):
        super().__init__()
        self.fn = fn
        self.fn_else = fn_else
        self.prob = prob

    def forward(self, x):
        if random() < self.prob:
            return self.fn(x)
        return self.fn_else(x)
class ChanNorm(nn.Module):
    """Normalise over dim 1 with a learned per-channel gain ``g`` and bias ``b``."""

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        # Biased variance and mean, both taken across dim 1.
        variance = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        centre = torch.mean(x, dim = 1, keepdim = True)
        normalised = (x - centre) / (variance + self.eps).sqrt()
        return normalised * self.g + self.b
class PreNorm(nn.Module):
    """Run ``ChanNorm(dim)`` on the input before passing it to ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        normed = self.norm(x)
        return self.fn(normed)
class Residual(nn.Module):
    """Wrap ``fn`` so the output is ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x
class SumBranches(nn.Module):
    """Feed the same input to every branch and add the outputs together."""

    def __init__(self, branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        return sum(branch(x) for branch in self.branches)
class Blur(nn.Module):
    """Filter the input with the 3x3 kernel formed from ``[1, 2, 1]``.

    The 1-D taps are stored as the buffer ``f``; the 2-D kernel is rebuilt on
    each call as their outer product and passed to ``filter2d`` with
    ``normalized=True``.
    """

    def __init__(self):
        super().__init__()
        taps = torch.Tensor([1, 2, 1])
        self.register_buffer('f', taps)

    def forward(self, x):
        taps = self.f
        row = taps[None, None, :]
        col = taps[None, :, None]
        kernel = row * col
        return filter2d(x, kernel, normalized=True)
class Noise(nn.Module):
    """Add noise scaled by a single learned ``weight`` (initialised to zero)."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise = None):
        """Return ``x + weight * noise``.

        When ``noise`` is ``None`` a fresh standard-normal tensor of shape
        ``(b, 1, h, w)`` is drawn on ``x``'s device. ``x`` must be 4-D.
        """
        batch, _, height, width = x.shape
        if noise is None:
            noise = torch.randn(batch, 1, height, width, device = x.device)
        return x + self.weight * noise
def Conv2dSame(dim_in, dim_out, kernel_size, bias = True):
    """Build a zero-pad + ``nn.Conv2d`` pair.

    The leading pad is ``kernel_size // 2``; for even kernel sizes the
    trailing pad is one smaller, so an even kernel still preserves the
    spatial size. The same two pads are used for both spatial axes.
    """
    pad_before = kernel_size // 2
    pad_after = pad_before if kernel_size % 2 != 0 else pad_before - 1
    padding = nn.ZeroPad2d((pad_before, pad_after, pad_before, pad_after))
    conv = nn.Conv2d(dim_in, dim_out, kernel_size, bias = bias)
    return nn.Sequential(padding, conv)
# attention | |
class DepthWiseConv2d(nn.Module):
    """Per-channel convolution (``groups = dim_in``) followed by a 1x1 convolution."""

    def __init__(self, dim_in, dim_out, kernel_size, padding = 0, stride = 1, bias = True):
        super().__init__()
        depthwise = nn.Conv2d(
            dim_in,
            dim_in,
            kernel_size = kernel_size,
            padding = padding,
            groups = dim_in,
            stride = stride,
            bias = bias
        )
        pointwise = nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        self.net = nn.Sequential(depthwise, pointwise)

    def forward(self, x):
        return self.net(x)
class LinearAttention(nn.Module):
    """Attention block combining a linear-attention branch with a local
    (unfolded-neighbourhood) attention branch.

    Both branches produce ``heads * dim_head`` channels; they are concatenated
    along the channel dimension and projected back to ``dim`` channels by a
    1x1 convolution.
    """
    def __init__(self, dim, dim_head = 64, heads = 8, kernel_size = 3):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads

        # kernel_size is the neighbourhood size used by F.unfold in forward().
        self.kernel_size = kernel_size
        self.nonlin = nn.GELU()

        # Projections for the linear-attention branch.
        self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)

        # Projections for the local-attention branch.
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)

        # Takes inner_dim * 2 channels because forward() concatenates both branches.
        self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)

    def forward(self, fmap):
        # h = number of heads, (x, y) = the last two dims of the feature map.
        h, x, y = self.heads, *fmap.shape[-2:]

        # linear attention
        # The key/value projection is split in two along the channel dim.
        lin_q, lin_k, lin_v = (self.to_lin_q(fmap), *self.to_lin_kv(fmap).chunk(2, dim = 1))
        # Fold heads into the batch dim and flatten the spatial dims: (b h) (x y) c.
        lin_q, lin_k, lin_v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (lin_q, lin_k, lin_v))

        # Queries are softmaxed over the feature dim, keys over the position dim.
        lin_q = lin_q.softmax(dim = -1)
        lin_k = lin_k.softmax(dim = -2)

        lin_q = lin_q * self.scale

        # Contract keys with values over positions first, then apply the queries.
        context = einsum('b n d, b n e -> b d e', lin_k, lin_v)
        lin_out = einsum('b n d, b d e -> b n e', lin_q, context)
        lin_out = rearrange(lin_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        # conv-like full attention
        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) c x y', h = h), (q, k, v))

        # Gather a kernel_size x kernel_size neighbourhood for each position.
        # NOTE(review): the reshape below assumes unfold yields x * y positions,
        # which holds for odd kernel_size with this padding - confirm for even sizes.
        k = F.unfold(k, kernel_size = self.kernel_size, padding = self.kernel_size // 2)
        v = F.unfold(v, kernel_size = self.kernel_size, padding = self.kernel_size // 2)

        # Split the unfolded channel dim into (d, j): d = dim_head features, j = neighbours.
        k, v = map(lambda t: rearrange(t, 'b (d j) n -> b n j d', d = self.dim_head), (k, v))

        q = rearrange(q, 'b c ... -> b (...) c') * self.scale

        # Similarity of each query position i with its j neighbours.
        sim = einsum('b i d, b i j d -> b i j', q, k)
        # Subtract the per-row max before softmax; detached so it carries no gradient.
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()

        attn = sim.softmax(dim = -1)

        full_out = einsum('b i j, b i j d -> b i d', attn, v)
        full_out = rearrange(full_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        # add outputs of linear attention + conv like full attention
        # Only the linear branch goes through the GELU; the two are then
        # concatenated on the channel dim rather than summed.
        lin_out = self.nonlin(lin_out)
        out = torch.cat((lin_out, full_out), dim = 1)
        return self.to_out(out)
# dataset | |
def convert_image_to(img_type, image):
    """Return ``image`` in mode ``img_type``, converting only if its mode differs."""
    already_matches = image.mode == img_type
    return image if already_matches else image.convert(img_type)
class identity(object):
    """Callable that returns its argument unchanged."""

    def __call__(self, tensor):
        unchanged = tensor
        return unchanged
class expand_greyscale(object):
    """Expand a 1- or 2-channel image tensor to 3 (or 4, if transparent) channels.

    The channel count is read from ``tensor.shape[0]``.
    """

    def __init__(self, transparent):
        self.transparent = transparent

    def __call__(self, tensor):
        channels = tensor.shape[0]
        wanted = 4 if self.transparent else 3

        # Already has the target channel count: pass through untouched.
        if channels == wanted:
            return tensor

        if channels == 1:
            color, alpha = tensor.expand(3, -1, -1), None
        elif channels == 2:
            # First channel is the grey value, second is used as alpha.
            color, alpha = tensor[:1].expand(3, -1, -1), tensor[1:]
        else:
            raise Exception(f'image with invalid number of channels given {channels}')

        if not self.transparent:
            return color

        if alpha is None:
            # No alpha was supplied: use a channel of ones.
            alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)
        return torch.cat((color, alpha))
def resize_to_minimum_size(min_size, image):
    """Resize ``image`` only when its largest side is below ``min_size``.

    Otherwise the image is returned unchanged. The resize is delegated to
    ``torchvision.transforms.functional.resize(image, min_size)``.
    """
    largest_side = max(*image.size)
    if largest_side >= min_size:
        return image
    return torchvision.transforms.functional.resize(image, min_size)
class ImageDataset(Dataset):
    """Dataset of images found recursively under ``folder``.

    Files are matched by the extensions listed in the module-level ``EXTS``.
    NOTE(review): ``EXTS`` is not defined in the visible part of this file;
    it must exist before this class is instantiated.

    Each item is opened with PIL and run through a pipeline that converts the
    colour mode, resizes, crops (random-resized with probability ``aug_prob``,
    centre crop otherwise), converts to a tensor, and expands the channels.
    """

    def __init__(
        self,
        folder,
        image_size,
        transparent = False,
        greyscale = False,
        aug_prob = 0.
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size

        root = Path(f'{folder}')
        self.paths = [path for ext in EXTS for path in root.glob(f'**/*.{ext}')]
        assert len(self.paths) > 0, f'No images were found in {folder} for training'

        # Pick the PIL mode and the channel-expansion step; transparent takes
        # precedence over greyscale.
        if transparent:
            pillow_mode, expand_fn = 'RGBA', expand_greyscale(transparent)
        elif greyscale:
            pillow_mode, expand_fn = 'L', identity()
        else:
            pillow_mode, expand_fn = 'RGB', expand_greyscale(transparent)

        crop = RandomApply(
            aug_prob,
            transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)),
            transforms.CenterCrop(image_size)
        )

        self.transform = transforms.Compose([
            transforms.Lambda(partial(convert_image_to, pillow_mode)),
            transforms.Lambda(partial(resize_to_minimum_size, image_size)),
            transforms.Resize(image_size),
            crop,
            transforms.ToTensor(),
            transforms.Lambda(expand_fn)
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        image = Image.open(self.paths[index])
        return self.transform(image)
# augmentations | |
def random_hflip(tensor, prob):
    """Randomly flip ``tensor`` along dim 3.

    NOTE(review): as written, ``prob`` is the probability of returning the
    tensor unchanged (the flip happens when ``prob > random()`` is false);
    confirm this is the intended meaning.
    """
    keep_unflipped = prob > random()
    if keep_unflipped:
        return tensor
    return torch.flip(tensor, dims=(3,))
class AugWrapper(nn.Module):
    """Wrap a module ``D`` so inputs can be randomly augmented before the call.

    ``image_size`` is accepted by the constructor but not used.
    """

    def __init__(self, D, image_size):
        super().__init__()
        self.D = D

    def forward(self, images, prob = 0., types = [], detach = False, **kwargs):
        """Augment ``images`` with probability ``prob``, then call ``D``.

        When augmentation fires, a horizontal flip step and ``DiffAugment``
        (with ``types``) are applied. With ``detach`` set, the augmentation
        runs under ``torch.no_grad``. Extra keyword arguments go to ``D``.
        """
        make_context = torch.no_grad if detach else null_context

        with make_context():
            should_augment = random() < prob
            if should_augment:
                images = random_hflip(images, prob=0.5)
                images = DiffAugment(images, types=types)

        return self.D(images, **kwargs)
# modifiable global variables | |
# Normalisation layer class used when building the generator's conv blocks;
# kept as a module-level name so it can be swapped before constructing a model.
norm_class = nn.BatchNorm2d
def upsample(scale_factor = 2):
    """Return an ``nn.Upsample`` layer with the given ``scale_factor``."""
    layer = nn.Upsample(scale_factor = scale_factor)
    return layer
# squeeze excitation classes | |
# global context network | |
# https://arxiv.org/abs/2012.13375 | |
# similar to squeeze-excite, but with a simplified attention pooling and a subsequent layer norm | |
class GlobalContext(nn.Module):
    """Global-context gating block.

    A 1x1 convolution scores every spatial position; the softmax of those
    scores pools the input into one vector per channel, which a small
    two-layer 1x1-conv network ending in a sigmoid maps to ``chan_out`` gates
    of spatial size 1x1.
    """

    def __init__(
        self,
        *,
        chan_in,
        chan_out
    ):
        super().__init__()
        self.to_k = nn.Conv2d(chan_in, 1, 1)
        hidden = max(3, chan_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, hidden, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(hidden, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Attention weights over the flattened spatial positions.
        scores = self.to_k(x).flatten(2)
        weights = scores.softmax(dim = -1)
        # Weighted sum of the features over positions.
        pooled = einsum('b i n, b c n -> b c i', weights, x.flatten(2))
        pooled = pooled.unsqueeze(-1)
        return self.net(pooled)
# frequency channel attention | |
# https://arxiv.org/abs/2012.11879 | |
def get_1d_dct(i, freq, L):
    """Value of the 1-D DCT basis of frequency ``freq`` at index ``i`` for length ``L``.

    Non-zero frequencies are scaled by ``sqrt(2)``.
    """
    basis = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
    scale = 1 if freq == 0 else math.sqrt(2)
    return basis * scale
def get_dct_weights(width, channel, fidx_u, fidx_v):
    """Build a ``(1, channel, width, width)`` tensor of 2-D DCT basis values.

    The channels are divided into ``len(fidx_u)`` equal groups of
    ``channel // len(fidx_u)``; group ``i`` holds the product of the 1-D
    bases for the frequency pair ``(fidx_u[i], fidx_v[i])``. Channels left
    over by the integer division stay zero.
    """
    dct_weights = torch.zeros(1, channel, width, width)
    group_size = channel // len(fidx_u)

    for group, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
        start, stop = group * group_size, (group + 1) * group_size
        for x in range(width):
            # The row factor does not depend on y, so compute it once per row.
            row_value = get_1d_dct(x, u_x, width)
            for y in range(width):
                dct_weights[:, start:stop, x, y] = row_value * get_1d_dct(y, v_y, width)

    return dct_weights
class FCANet(nn.Module):
    """Frequency channel attention block.

    The input is multiplied by a fixed buffer of DCT weights, summed over its
    spatial dims, and mapped by a small 1x1-conv network ending in a sigmoid
    to ``chan_out`` gates.
    """

    def __init__(
        self,
        *,
        chan_in,
        chan_out,
        reduction = 4,
        width
    ):
        super().__init__()

        # in paper, it seems 16 frequencies was ideal
        zero_freqs = [0] * 8
        ramp_freqs = list(range(8))
        fidx_u = [*zero_freqs, *ramp_freqs]
        fidx_v = [*ramp_freqs, *zero_freqs]
        self.register_buffer('dct_weights', get_dct_weights(width, chan_in, fidx_u, fidx_v))

        hidden = max(3, chan_out // reduction)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, hidden, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(hidden, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        weighted = x * self.dct_weights
        # Sum over the full spatial extent, keeping 1x1 spatial dims.
        pooled = reduce(weighted, 'b c (h h1) (w w1) -> b c h1 w1', 'sum', h1 = 1, w1 = 1)
        return self.net(pooled)
# generative adversarial network | |
class Generator(nn.Module):
    """Image generator: maps a latent vector to an image through a stack of
    upsampling blocks.

    The latent is expanded to a 4x4 feature map by ``initial_conv``; each
    entry of ``self.layers`` then doubles the spatial size. Selected blocks
    also produce channel gates (``sle``) that multiply the output of a later,
    higher-resolution block, and may be preceded by an attention layer.

    Args (keyword-only):
        image_size: output width/height; must be a power of two.
        latent_dim: number of latent channels.
        fmap_max: cap on the feature-map channel count (falls back to
            ``latent_dim`` when ``None``).
        fmap_inverse_coef: channels at block ``n`` start at
            ``2 ** (fmap_inverse_coef - n)`` before capping.
        transparent: emit 4 output channels.
        greyscale: emit 1 output channel (ignored when ``transparent``).
        attn_res_layers: feature-map widths at which attention is inserted.
        freq_chan_attn: use ``FCANet`` instead of ``GlobalContext`` for gates.
    """
    def __init__(
        self,
        *,
        image_size,
        latent_dim = 256,
        fmap_max = 512,
        fmap_inverse_coef = 12,
        transparent = False,
        greyscale = False,
        attn_res_layers = [],
        freq_chan_attn = False
    ):
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'

        # Number of output image channels.
        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        fmap_max = default(fmap_max, latent_dim)

        # Kernel-4 transposed conv turns the 1x1 latent (see forward) into a
        # 4x4 map; the GLU halves the doubled channel count back to latent_dim.
        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim = 1)
        )

        # One upsampling block per doubling from 4x4 up to image_size.
        num_layers = int(resolution) - 2
        # Channel count per block: 2 ** (coef - n), capped at fmap_max, and
        # forced to 3 for block indices n >= 8.
        features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2)))
        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
        features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
        features = [latent_dim, *features]

        # Consecutive (chan_in, chan_out) pairs, one per block.
        in_out_features = list(zip(features[:-1], features[1:]))

        self.res_layers = range(2, num_layers + 2)
        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        # source block index -> target index for the gating connections;
        # pairs reaching past the requested resolution are dropped.
        self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
        self.sle_map = list(filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map))
        self.sle_map = dict(self.sle_map)

        self.num_layers_spatial_res = 1

        for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
            # Width of this block's input feature map (before upsampling).
            image_width = 2 ** res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                # The gate multiplies the output of block (residual_layer - 1)
                # in forward(), so it must match that block's chan_out.
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

                if freq_chan_attn:
                    # width is this block's output width (after 2x upsampling).
                    sle = FCANet(
                        chan_in = chan_out,
                        chan_out = sle_chan_out,
                        width = 2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(
                        chan_in = chan_out,
                        chan_out = sle_chan_out
                    )

            # Each entry is (upsampling block, optional gate, optional attention);
            # the conv emits 2 * chan_out channels which the GLU halves.
            layer = nn.ModuleList([
                nn.Sequential(
                    upsample(),
                    Blur(),
                    Conv2dSame(chan_in, chan_out * 2, 4),
                    Noise(),
                    norm_class(chan_out * 2),
                    nn.GLU(dim = 1)
                ),
                sle,
                attn
            ])
            self.layers.append(layer)

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding = 1)

    def forward(self, x):
        # x is a (batch, latent) tensor; add two singleton spatial dims.
        x = rearrange(x, 'b c -> b c () ()')
        x = self.initial_conv(x)
        x = F.normalize(x, dim = 1)

        # Gates computed so far, keyed by the index at which they are applied.
        residuals = dict()

        for (res, (up, sle, attn)) in zip(self.res_layers, self.layers):
            # Attention (residual form) runs on the block's input, before upsampling.
            if exists(attn):
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                out_res = self.sle_map[res]
                residual = sle(x)
                residuals[out_res] = residual

            # Apply a gate stored earlier for this position, if any.
            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)
# Initialize a generator model | |
# Build a generator with the same hyperparameters that are passed to
# from_pretrained below.
gan_new = Generator(latent_dim=256, image_size=256, attn_res_layers = [32])
# Load from local saved state dict
# gan_new.load_state_dict(torch.load('/content/orbgan_e3_state_dict.pt'))
# Load from model hub:
# Subclass the generator's own class together with the hub mixin so that
# from_pretrained becomes available.
class GeneratorWithPyTorchModelHubMixin(gan_new.__class__, PyTorchModelHubMixin):
    pass
# Re-class the existing instance, then replace it with the instance returned
# by from_pretrained for the 'johnowhitaker/colorb_gan' repository.
gan_new.__class__ = GeneratorWithPyTorchModelHubMixin
gan_new = gan_new.from_pretrained('johnowhitaker/colorb_gan', latent_dim=256, image_size=256, attn_res_layers = [32])
def gen_ims(n_rows):
    """Generate an ``n_rows`` x ``n_rows`` grid of images from ``gan_new``.

    Args:
        n_rows: grid side length; converted with ``int()`` (the slider may
            pass a float).

    Returns:
        A ``uint8`` NumPy array of shape (height, width, channels) holding
        the image grid, with values scaled from [0, 1] to [0, 255].
    """
    n_rows = int(n_rows)
    # Inference only: skip autograd bookkeeping so no graph is built or kept.
    with torch.no_grad():
        latents = torch.randn(n_rows ** 2, 256)
        ims = gan_new(latents).clamp_(0., 1.)
    grid = torchvision.utils.make_grid(ims, nrow=n_rows)
    # make_grid returns channels-first; move channels last for display.
    grid = grid.permute(1, 2, 0).cpu().numpy()
    return (grid * 255).astype(np.uint8)
# Gradio UI: one slider choosing the grid size, one image output.
# NOTE(review): gr.inputs / gr.outputs and the `default=` keyword look like the
# legacy Gradio component API - confirm the pinned Gradio version still
# provides them (the page this file was captured from reports a runtime error).
iface = gr.Interface(fn=gen_ims,
    inputs=[gr.inputs.Slider(minimum=1, maximum=6, step=1, default=3,label="N rows")],
    outputs=[gr.outputs.Image(type="numpy", label="Generated Images")],
    title='Demo for Colorbgan model',
    article = 'A lightweight-gans trained on johnowhitaker/colorbs. See https://huggingface.co/johnowhitaker/orbgan_e1 for training and inference scripts'
)
# Start the web server as soon as the module is executed.
iface.launch()