import os
import json
import tempfile
from random import random
import math
from math import log2, floor
from pathlib import Path
from functools import partial
from contextlib import contextmanager, ExitStack
from shutil import rmtree
import torch
from torch.optim import Adam
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import grad as torch_grad
from PIL import Image
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from kornia.filters import filter2d
from huggan.pytorch.lightweight_gan.diff_augment import DiffAugment
from tqdm import tqdm
from einops import rearrange, reduce, repeat
from datasets import load_dataset
from accelerate import Accelerator, DistributedDataParallelKwargs
from huggingface_hub import hf_hub_download, create_repo
from huggan.pytorch.huggan_mixin import HugGANModelHubMixin
from huggan.utils.hub import get_full_repo_name
# constants
# NUM_CORES = multiprocessing.cpu_count()
EXTS = ['jpg', 'jpeg', 'png']
PYTORCH_WEIGHTS_NAME = 'model.pt'
# helpers
def exists(val):
return val is not None
@contextmanager
def null_context():
yield
def is_power_of_two(val):
return log2(val).is_integer()
def default(val, d):
return val if exists(val) else d
def set_requires_grad(model, flag):
    for p in model.parameters():
        p.requires_grad = flag
def cycle(iterable):
while True:
for i in iterable:
yield i
def raise_if_nan(t):
if torch.isnan(t):
raise NanException
def evaluate_in_chunks(max_batch_size, model, *args):
split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
chunked_outputs = [model(*i) for i in split_args]
if len(chunked_outputs) == 1:
return chunked_outputs[0]
return torch.cat(chunked_outputs, dim=0)
def slerp(val, low, high):
low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high / torch.norm(high, dim=1, keepdim=True)
omega = torch.acos((low_norm * high_norm).sum(1))
so = torch.sin(omega)
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
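# A minimal usage sketch for slerp (shapes are illustrative, not fixed by this file):
#   z0 = torch.randn(4, 256)
#   z1 = torch.randn(4, 256)
#   z_half = slerp(0.5, z0, z1)  # spherical midpoint, same shape as the inputs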
def safe_div(n, d):
try:
res = n / d
except ZeroDivisionError:
        prefix = '' if n >= 0 else '-'
res = float(f'{prefix}inf')
return res
# loss functions
def gen_hinge_loss(fake, real):
return fake.mean()
def hinge_loss(real, fake):
return (F.relu(1 + real) + F.relu(1 - fake)).mean()
def dual_contrastive_loss(real_logits, fake_logits):
device = real_logits.device
real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))
def loss_half(t1, t2):
t1 = rearrange(t1, 'i -> i ()')
t2 = repeat(t2, 'j -> i j', i=t1.shape[0])
t = torch.cat((t1, t2), dim=-1)
return F.cross_entropy(t, torch.zeros(t1.shape[0], device=device, dtype=torch.long))
return loss_half(real_logits, fake_logits) + loss_half(-fake_logits, -real_logits)
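# dual_contrastive_loss frames real-vs-fake as classification: each real logit
# is concatenated with all fake logits and class 0 (the real one) must win the
# softmax; the same is then done with the signs flipped. A hypothetical call,
# assuming one logit per sample:
#   loss = dual_contrastive_loss(torch.randn(8), torch.randn(8))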
# helper classes
class NanException(Exception):
pass
class EMA():
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_average(self, old, new):
if not exists(old):
return new
return old * self.beta + (1 - self.beta) * new
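# With beta = 0.995, update_average(old, new) == 0.995 * old + 0.005 * new,
# so the averaged generator (GE below) trails the live weights slowly.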
class RandomApply(nn.Module):
def __init__(self, prob, fn, fn_else=lambda x: x):
super().__init__()
self.fn = fn
self.fn_else = fn_else
self.prob = prob
def forward(self, x):
fn = self.fn if random() < self.prob else self.fn_else
return fn(x)
class ChanNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim=1, unbiased=False, keepdim=True)
mean = torch.mean(x, dim=1, keepdim=True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = ChanNorm(dim)
def forward(self, x):
return self.fn(self.norm(x))
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
class SumBranches(nn.Module):
def __init__(self, branches):
super().__init__()
self.branches = nn.ModuleList(branches)
def forward(self, x):
return sum(map(lambda fn: fn(x), self.branches))
class Fuzziness(nn.Module):
def __init__(self):
super().__init__()
f = torch.Tensor([1, 2, 1])
self.register_buffer('f', f)
def forward(self, x):
f = self.f
f = f[None, None, :] * f[None, :, None]
return filter2d(x, f, normalized=True)
Blur = nn.Identity
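# Blur is a no-op by default; Trainer.init_GAN() rebinds it to Fuzziness
# (the [1, 2, 1] binomial filter above) when antialias=True.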
# attention
class DepthWiseConv2d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, groups=dim_in, stride=stride,
bias=bias),
nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
)
def forward(self, x):
return self.net(x)
class LinearAttention(nn.Module):
def __init__(self, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
self.nonlin = nn.GELU()
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding=1, bias=False)
self.to_out = nn.Conv2d(inner_dim, dim, 1)
def forward(self, fmap):
h, x, y = self.heads, *fmap.shape[-2:]
q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim=1))
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h=h), (q, k, v))
q = q.softmax(dim=-1)
k = k.softmax(dim=-2)
q = q * self.scale
context = einsum('b n d, b n e -> b d e', k, v)
out = einsum('b n d, b d e -> b n e', q, context)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h=h, x=x, y=y)
out = self.nonlin(out)
return self.to_out(out)
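# Linear attention sketch: q is softmaxed over its feature dim and k over the
# sequence dim, so the (d x d) context k^T v is built first and the cost grows
# as O(n * d^2) in the number of pixels n, rather than O(n^2 * d).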
# dataset
def convert_image_to(img_type, image):
if image.mode != img_type:
return image.convert(img_type)
return image
class identity(object):
def __call__(self, tensor):
return tensor
class expand_greyscale(object):
def __init__(self, transparent):
self.transparent = transparent
def __call__(self, tensor):
channels = tensor.shape[0]
num_target_channels = 4 if self.transparent else 3
if channels == num_target_channels:
return tensor
alpha = None
if channels == 1:
color = tensor.expand(3, -1, -1)
elif channels == 2:
color = tensor[:1].expand(3, -1, -1)
alpha = tensor[1:]
else:
            raise ValueError(f'image with invalid number of channels: {channels}')
if not exists(alpha) and self.transparent:
alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)
return color if not self.transparent else torch.cat((color, alpha))
def resize_to_minimum_size(min_size, image):
if max(*image.size) < min_size:
return torchvision.transforms.functional.resize(image, min_size)
return image
# augmentations
def random_hflip(tensor, prob):
if prob > random():
return tensor
return torch.flip(tensor, dims=(3,))
class AugWrapper(nn.Module):
def __init__(self, D, image_size):
super().__init__()
self.D = D
def forward(self, images, prob=0., types=[], detach=False, **kwargs):
context = torch.no_grad if detach else null_context
with context():
if random() < prob:
images = random_hflip(images, prob=0.5)
images = DiffAugment(images, types=types)
return self.D(images, **kwargs)
# modifiable global variables
norm_class = nn.BatchNorm2d
def upsample(scale_factor=2):
return nn.Upsample(scale_factor=scale_factor)
# squeeze excitation classes
# global context network
# https://arxiv.org/abs/2012.13375
# similar to squeeze-excite, but with a simplified attention pooling and a subsequent layer norm
class GlobalContext(nn.Module):
def __init__(
self,
*,
chan_in,
chan_out
):
super().__init__()
self.to_k = nn.Conv2d(chan_in, 1, 1)
chan_intermediate = max(3, chan_out // 2)
self.net = nn.Sequential(
nn.Conv2d(chan_in, chan_intermediate, 1),
nn.LeakyReLU(0.1),
nn.Conv2d(chan_intermediate, chan_out, 1),
nn.Sigmoid()
)
def forward(self, x):
context = self.to_k(x)
context = context.flatten(2).softmax(dim=-1)
out = einsum('b i n, b c n -> b c i', context, x.flatten(2))
out = out.unsqueeze(-1)
return self.net(out)
# frequency channel attention
# https://arxiv.org/abs/2012.11879
def get_1d_dct(i, freq, L):
result = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
return result * (1 if freq == 0 else math.sqrt(2))
def get_dct_weights(width, channel, fidx_u, fidx_v):
dct_weights = torch.zeros(1, channel, width, width)
c_part = channel // len(fidx_u)
for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
for x in range(width):
for y in range(width):
coor_value = get_1d_dct(x, u_x, width) * get_1d_dct(y, v_y, width)
dct_weights[:, i * c_part: (i + 1) * c_part, x, y] = coor_value
return dct_weights
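# Hypothetical example: get_dct_weights(width=8, channel=64, fidx_u=[0, 1],
# fidx_v=[1, 0]) returns a (1, 64, 8, 8) buffer in which each block of
# 64 // 2 = 32 channels holds a single 2D DCT basis function.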
class FCANet(nn.Module):
def __init__(
self,
*,
chan_in,
chan_out,
reduction=4,
width
):
super().__init__()
freq_w, freq_h = ([0] * 8), list(range(8)) # in paper, it seems 16 frequencies was ideal
dct_weights = get_dct_weights(width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w])
self.register_buffer('dct_weights', dct_weights)
chan_intermediate = max(3, chan_out // reduction)
self.net = nn.Sequential(
nn.Conv2d(chan_in, chan_intermediate, 1),
nn.LeakyReLU(0.1),
nn.Conv2d(chan_intermediate, chan_out, 1),
nn.Sigmoid()
)
def forward(self, x):
x = reduce(x * self.dct_weights, 'b c (h h1) (w w1) -> b c h1 w1', 'sum', h1=1, w1=1)
return self.net(x)
# generative adversarial network
class Generator(nn.Module):
def __init__(
self,
*,
image_size,
latent_dim=256,
fmap_max=512,
fmap_inverse_coef=12,
transparent=False,
greyscale=False,
attn_res_layers=[],
freq_chan_attn=False
):
super().__init__()
resolution = log2(image_size)
assert is_power_of_two(image_size), 'image size must be a power of 2'
if transparent:
init_channel = 4
elif greyscale:
init_channel = 1
else:
init_channel = 3
fmap_max = default(fmap_max, latent_dim)
self.initial_conv = nn.Sequential(
nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
norm_class(latent_dim * 2),
nn.GLU(dim=1)
)
num_layers = int(resolution) - 2
features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2)))
features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
features = [latent_dim, *features]
in_out_features = list(zip(features[:-1], features[1:]))
self.res_layers = range(2, num_layers + 2)
self.layers = nn.ModuleList([])
self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))
self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
self.sle_map = list(filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map))
self.sle_map = dict(self.sle_map)
self.num_layers_spatial_res = 1
for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
image_width = 2 ** res
attn = None
if image_width in attn_res_layers:
attn = PreNorm(chan_in, LinearAttention(chan_in))
sle = None
if res in self.sle_map:
residual_layer = self.sle_map[res]
sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]
if freq_chan_attn:
sle = FCANet(
chan_in=chan_out,
chan_out=sle_chan_out,
width=2 ** (res + 1)
)
else:
sle = GlobalContext(
chan_in=chan_out,
chan_out=sle_chan_out
)
layer = nn.ModuleList([
nn.Sequential(
upsample(),
Blur(),
nn.Conv2d(chan_in, chan_out * 2, 3, padding=1),
norm_class(chan_out * 2),
nn.GLU(dim=1)
),
sle,
attn
])
self.layers.append(layer)
self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)
def forward(self, x):
x = rearrange(x, 'b c -> b c () ()')
x = self.initial_conv(x)
x = F.normalize(x, dim=1)
residuals = dict()
for (res, (up, sle, attn)) in zip(self.res_layers, self.layers):
if exists(attn):
x = attn(x) + x
x = up(x)
if exists(sle):
out_res = self.sle_map[res]
residual = sle(x)
residuals[out_res] = residual
next_res = res + 1
if next_res in residuals:
x = x * residuals[next_res]
return self.out_conv(x)
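# Shape sketch for the generator (values illustrative):
#   g = Generator(image_size=256, latent_dim=256)
#   imgs = g(torch.randn(4, 256))  # -> (4, 3, 256, 256), unclamped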
class SimpleDecoder(nn.Module):
def __init__(
self,
*,
chan_in,
chan_out=3,
num_upsamples=4,
):
super().__init__()
self.layers = nn.ModuleList([])
final_chan = chan_out
chans = chan_in
for ind in range(num_upsamples):
last_layer = ind == (num_upsamples - 1)
chan_out = chans if not last_layer else final_chan * 2
layer = nn.Sequential(
upsample(),
nn.Conv2d(chans, chan_out, 3, padding=1),
nn.GLU(dim=1)
)
self.layers.append(layer)
chans //= 2
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class Discriminator(nn.Module):
def __init__(
self,
*,
image_size,
fmap_max=512,
fmap_inverse_coef=12,
transparent=False,
greyscale=False,
disc_output_size=5,
attn_res_layers=[]
):
super().__init__()
resolution = log2(image_size)
assert is_power_of_two(image_size), 'image size must be a power of 2'
assert disc_output_size in {1, 5}, 'discriminator output dimensions can only be 5x5 or 1x1'
resolution = int(resolution)
if transparent:
init_channel = 4
elif greyscale:
init_channel = 1
else:
init_channel = 3
num_non_residual_layers = max(0, int(resolution) - 8)
num_residual_layers = 8 - 3
non_residual_resolutions = range(min(8, resolution), 2, -1)
features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), non_residual_resolutions))
features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
if num_non_residual_layers == 0:
res, _ = features[0]
features[0] = (res, init_channel)
chan_in_out = list(zip(features[:-1], features[1:]))
self.non_residual_layers = nn.ModuleList([])
for ind in range(num_non_residual_layers):
first_layer = ind == 0
last_layer = ind == (num_non_residual_layers - 1)
chan_out = features[0][-1] if last_layer else init_channel
self.non_residual_layers.append(nn.Sequential(
Blur(),
nn.Conv2d(init_channel, chan_out, 4, stride=2, padding=1),
nn.LeakyReLU(0.1)
))
self.residual_layers = nn.ModuleList([])
for (res, ((_, chan_in), (_, chan_out))) in zip(non_residual_resolutions, chan_in_out):
image_width = 2 ** res
attn = None
if image_width in attn_res_layers:
attn = PreNorm(chan_in, LinearAttention(chan_in))
self.residual_layers.append(nn.ModuleList([
SumBranches([
nn.Sequential(
Blur(),
nn.Conv2d(chan_in, chan_out, 4, stride=2, padding=1),
nn.LeakyReLU(0.1),
nn.Conv2d(chan_out, chan_out, 3, padding=1),
nn.LeakyReLU(0.1)
),
nn.Sequential(
Blur(),
nn.AvgPool2d(2),
nn.Conv2d(chan_in, chan_out, 1),
nn.LeakyReLU(0.1),
)
]),
attn
]))
last_chan = features[-1][-1]
if disc_output_size == 5:
self.to_logits = nn.Sequential(
nn.Conv2d(last_chan, last_chan, 1),
nn.LeakyReLU(0.1),
nn.Conv2d(last_chan, 1, 4)
)
elif disc_output_size == 1:
self.to_logits = nn.Sequential(
Blur(),
nn.Conv2d(last_chan, last_chan, 3, stride=2, padding=1),
nn.LeakyReLU(0.1),
nn.Conv2d(last_chan, 1, 4)
)
self.to_shape_disc_out = nn.Sequential(
nn.Conv2d(init_channel, 64, 3, padding=1),
Residual(PreNorm(64, LinearAttention(64))),
SumBranches([
nn.Sequential(
Blur(),
nn.Conv2d(64, 32, 4, stride=2, padding=1),
nn.LeakyReLU(0.1),
nn.Conv2d(32, 32, 3, padding=1),
nn.LeakyReLU(0.1)
),
nn.Sequential(
Blur(),
nn.AvgPool2d(2),
nn.Conv2d(64, 32, 1),
nn.LeakyReLU(0.1),
)
]),
Residual(PreNorm(32, LinearAttention(32))),
nn.AdaptiveAvgPool2d((4, 4)),
nn.Conv2d(32, 1, 4)
)
self.decoder1 = SimpleDecoder(chan_in=last_chan, chan_out=init_channel)
self.decoder2 = SimpleDecoder(chan_in=features[-2][-1], chan_out=init_channel) if resolution >= 9 else None
def forward(self, x, calc_aux_loss=False):
orig_img = x
for layer in self.non_residual_layers:
x = layer(x)
layer_outputs = []
for (net, attn) in self.residual_layers:
if exists(attn):
x = attn(x) + x
x = net(x)
layer_outputs.append(x)
out = self.to_logits(x).flatten(1)
img_32x32 = F.interpolate(orig_img, size=(32, 32))
out_32x32 = self.to_shape_disc_out(img_32x32)
if not calc_aux_loss:
return out, out_32x32, None
# self-supervised auto-encoding loss
layer_8x8 = layer_outputs[-1]
layer_16x16 = layer_outputs[-2]
recon_img_8x8 = self.decoder1(layer_8x8)
aux_loss = F.mse_loss(
recon_img_8x8,
F.interpolate(orig_img, size=recon_img_8x8.shape[2:])
)
if exists(self.decoder2):
select_random_quadrant = lambda rand_quadrant, img: \
rearrange(img, 'b c (m h) (n w) -> (m n) b c h w', m=2, n=2)[rand_quadrant]
crop_image_fn = partial(select_random_quadrant, floor(random() * 4))
img_part, layer_16x16_part = map(crop_image_fn, (orig_img, layer_16x16))
recon_img_16x16 = self.decoder2(layer_16x16_part)
aux_loss_16x16 = F.mse_loss(
recon_img_16x16,
F.interpolate(img_part, size=recon_img_16x16.shape[2:])
)
aux_loss = aux_loss + aux_loss_16x16
return out, out_32x32, aux_loss
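# Shape sketch for the discriminator (values illustrative):
#   d = Discriminator(image_size=256)
#   out, out_32x32, aux = d(torch.randn(4, 3, 256, 256), calc_aux_loss=True)
#   # out: (4, 25) patch logits (5x5), out_32x32: logits from the downsampled
#   # 32x32 branch, aux: scalar self-supervised reconstruction loss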
class LightweightGAN(nn.Module, HugGANModelHubMixin):
def __init__(
self,
*,
latent_dim,
image_size,
optimizer="adam",
fmap_max=512,
fmap_inverse_coef=12,
transparent=False,
greyscale=False,
disc_output_size=5,
attn_res_layers=[],
freq_chan_attn=False,
ttur_mult=1.,
lr=2e-4,
):
super().__init__()
self.config = {
'latent_dim': latent_dim,
'image_size': image_size,
'optimizer': optimizer,
'fmap_max': fmap_max,
'fmap_inverse_coef': fmap_inverse_coef,
'transparent': transparent,
'greyscale': greyscale,
'disc_output_size': disc_output_size,
'attn_res_layers': attn_res_layers,
'freq_chan_attn': freq_chan_attn,
'ttur_mult': ttur_mult,
'lr': lr
}
self.latent_dim = latent_dim
self.image_size = image_size
G_kwargs = dict(
image_size=image_size,
latent_dim=latent_dim,
fmap_max=fmap_max,
fmap_inverse_coef=fmap_inverse_coef,
transparent=transparent,
greyscale=greyscale,
attn_res_layers=attn_res_layers,
freq_chan_attn=freq_chan_attn
)
self.G = Generator(**G_kwargs)
self.D = Discriminator(
image_size=image_size,
fmap_max=fmap_max,
fmap_inverse_coef=fmap_inverse_coef,
transparent=transparent,
greyscale=greyscale,
attn_res_layers=attn_res_layers,
disc_output_size=disc_output_size
)
self.ema_updater = EMA(0.995)
self.GE = Generator(**G_kwargs)
set_requires_grad(self.GE, False)
if optimizer == "adam":
self.G_opt = Adam(self.G.parameters(), lr=lr, betas=(0.5, 0.9))
self.D_opt = Adam(self.D.parameters(), lr=lr * ttur_mult, betas=(0.5, 0.9))
elif optimizer == "adabelief":
from adabelief_pytorch import AdaBelief
self.G_opt = AdaBelief(self.G.parameters(), lr=lr, betas=(0.5, 0.9))
self.D_opt = AdaBelief(self.D.parameters(), lr=lr * ttur_mult, betas=(0.5, 0.9))
        else:
            raise ValueError(f'optimizer "{optimizer}" is not supported; use "adam" or "adabelief"')
self.apply(self._init_weights)
self.reset_parameter_averaging()
self.D_aug = AugWrapper(self.D, image_size)
def _init_weights(self, m):
if type(m) in {nn.Conv2d, nn.Linear}:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
def EMA(self):
def update_moving_average(ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = self.ema_updater.update_average(old_weight, up_weight)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = self.ema_updater.update_average(ma_buffer, current_buffer)
ma_buffer.copy_(new_buffer_value)
update_moving_average(self.GE, self.G)
def reset_parameter_averaging(self):
self.GE.load_state_dict(self.G.state_dict())
def forward(self, x):
        raise NotImplementedError
def _save_pretrained(self, save_directory):
"""
Overwrite this method in case you don't want to save complete model,
rather some specific layers
"""
path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
model_to_save = self.module if hasattr(self, "module") else self
# We update this to be a dict containing 'GAN', as that's what is expected
torch.save({'GAN': model_to_save.state_dict()}, path)
@classmethod
def _from_pretrained(
cls,
model_id,
revision,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
token,
map_location="cpu",
strict=False,
**model_kwargs,
):
"""
Overwrite this method in case you wish to initialize your model in a
different way.
"""
map_location = torch.device(map_location)
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
else:
model_file = hf_hub_download(
repo_id=model_id,
filename=PYTORCH_WEIGHTS_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
# We update here to directly unpack config
model = cls(**model_kwargs['config'])
state_dict = torch.load(model_file, map_location=map_location)
model.load_state_dict(state_dict["GAN"], strict=strict)
model.eval()
return model
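# Typical hub round trip via the mixin (hypothetical repo id; the extra
# `config` kwarg is unpacked by _from_pretrained above):
#   gan = LightweightGAN.from_pretrained(
#       'user/my-lightweight-gan',
#       config={'latent_dim': 256, 'image_size': 256},
#   )
#   gan.save_pretrained('./local-checkpoint')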
# trainer
class Trainer():
def __init__(
self,
dataset_name="huggan/CelebA-faces",
name='default',
results_dir='results',
models_dir='models',
base_dir='./',
optimizer='adam',
latent_dim=256,
image_size=128,
num_image_tiles=8,
fmap_max=512,
transparent=False,
greyscale=False,
batch_size=4,
gp_weight=10,
gradient_accumulate_every=1,
attn_res_layers=[],
freq_chan_attn=False,
disc_output_size=5,
dual_contrast_loss=False,
antialias=False,
lr=2e-4,
lr_mlp=1.,
ttur_mult=1.,
save_every=10000,
evaluate_every=1000,
aug_prob=None,
aug_types=['translation', 'cutout'],
dataset_aug_prob=0.,
calculate_fid_every=None,
calculate_fid_num_images=12800,
clear_fid_cache=False,
log=False,
cpu=False,
mixed_precision="no",
wandb=False,
push_to_hub=False,
organization_name=None,
*args,
**kwargs
):
self.GAN_params = [args, kwargs]
self.GAN = None
self.dataset_name = dataset_name
self.name = name
base_dir = Path(base_dir)
self.base_dir = base_dir
self.results_dir = base_dir / results_dir
self.models_dir = base_dir / models_dir
self.fid_dir = base_dir / 'fid' / name
        # Note: the original repo hides the config as ".config.json"; here we keep it public
self.config_path = self.models_dir / name / 'config.json'
assert is_power_of_two(image_size), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
assert all(map(is_power_of_two,
attn_res_layers)), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)'
assert not (
dual_contrast_loss and disc_output_size > 1), 'discriminator output size cannot be greater than 1 if using dual contrastive loss'
self.image_size = image_size
self.num_image_tiles = num_image_tiles
self.latent_dim = latent_dim
self.fmap_max = fmap_max
self.transparent = transparent
self.greyscale = greyscale
assert (int(self.transparent) + int(self.greyscale)) < 2, 'you can only set either transparency or greyscale'
self.aug_prob = aug_prob
self.aug_types = aug_types
self.lr = lr
self.optimizer = optimizer
self.ttur_mult = ttur_mult
self.batch_size = batch_size
self.gradient_accumulate_every = gradient_accumulate_every
self.gp_weight = gp_weight
self.evaluate_every = evaluate_every
self.save_every = save_every
self.steps = 0
self.attn_res_layers = attn_res_layers
self.freq_chan_attn = freq_chan_attn
self.disc_output_size = disc_output_size
self.antialias = antialias
self.dual_contrast_loss = dual_contrast_loss
self.d_loss = 0
self.g_loss = 0
self.last_gp_loss = None
self.last_recon_loss = None
self.last_fid = None
self.init_folders()
self.loader = None
self.dataset_aug_prob = dataset_aug_prob
self.calculate_fid_every = calculate_fid_every
self.calculate_fid_num_images = calculate_fid_num_images
self.clear_fid_cache = clear_fid_cache
self.syncbatchnorm = torch.cuda.device_count() > 1 and not cpu
self.cpu = cpu
self.mixed_precision = mixed_precision
self.wandb = wandb
self.push_to_hub = push_to_hub
self.organization_name = organization_name
self.repo_name = get_full_repo_name(self.name, self.organization_name)
if self.push_to_hub:
self.repo_url = create_repo(self.repo_name, exist_ok=True)
@property
def image_extension(self):
return 'jpg' if not self.transparent else 'png'
@property
def checkpoint_num(self):
        return self.steps // self.save_every
def init_GAN(self):
args, kwargs = self.GAN_params
# set some global variables before instantiating GAN
global norm_class
global Blur
norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d
Blur = nn.Identity if not self.antialias else Fuzziness
# instantiate GAN
self.GAN = LightweightGAN(
optimizer=self.optimizer,
lr=self.lr,
latent_dim=self.latent_dim,
attn_res_layers=self.attn_res_layers,
freq_chan_attn=self.freq_chan_attn,
image_size=self.image_size,
ttur_mult=self.ttur_mult,
fmap_max=self.fmap_max,
disc_output_size=self.disc_output_size,
transparent=self.transparent,
greyscale=self.greyscale,
*args,
**kwargs
)
def write_config(self):
self.config_path.write_text(json.dumps(self.config()))
def load_config(self):
config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text())
self.image_size = config['image_size']
self.transparent = config['transparent']
self.syncbatchnorm = config['syncbatchnorm']
self.disc_output_size = config['disc_output_size']
self.greyscale = config.pop('greyscale', False)
self.attn_res_layers = config.pop('attn_res_layers', [])
self.freq_chan_attn = config.pop('freq_chan_attn', False)
self.optimizer = config.pop('optimizer', 'adam')
self.fmap_max = config.pop('fmap_max', 512)
del self.GAN
self.init_GAN()
def config(self):
return {
'image_size': self.image_size,
'transparent': self.transparent,
'greyscale': self.greyscale,
'syncbatchnorm': self.syncbatchnorm,
'disc_output_size': self.disc_output_size,
'optimizer': self.optimizer,
'attn_res_layers': self.attn_res_layers,
'freq_chan_attn': self.freq_chan_attn
}
def set_data_src(self):
# start of using HuggingFace dataset
dataset = load_dataset(self.dataset_name)
if self.transparent:
num_channels = 4
pillow_mode = 'RGBA'
expand_fn = expand_greyscale(self.transparent)
elif self.greyscale:
num_channels = 1
pillow_mode = 'L'
expand_fn = identity()
else:
num_channels = 3
pillow_mode = 'RGB'
expand_fn = expand_greyscale(self.transparent)
convert_image_fn = partial(convert_image_to, pillow_mode)
transform = transforms.Compose([
transforms.Lambda(convert_image_fn),
transforms.Lambda(partial(resize_to_minimum_size, self.image_size)),
transforms.Resize(self.image_size),
RandomApply(0., transforms.RandomResizedCrop(self.image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)),
transforms.CenterCrop(self.image_size)),
transforms.ToTensor(),
transforms.Lambda(expand_fn)
])
def transform_images(examples):
transformed_images = [transform(image.convert("RGB")) for image in examples["image"]]
examples["image"] = torch.stack(transformed_images)
return examples
transformed_dataset = dataset.with_transform(transform_images)
per_device_batch_size = math.ceil(self.batch_size / self.accelerator.num_processes)
dataloader = DataLoader(transformed_dataset["train"], per_device_batch_size, sampler=None, shuffle=False,
drop_last=True, pin_memory=True)
        num_samples = len(transformed_dataset["train"])
## end of HuggingFace dataset
# Note - in original repo, this is wrapped with cycle, but we will do that after accelerator prepares
self.loader = dataloader
# auto set augmentation prob for user if dataset is detected to be low
# num_samples = len(self.dataset)
if not exists(self.aug_prob) and num_samples < 1e5:
self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
print(f'autosetting augmentation probability to {round(self.aug_prob * 100)}%')
def init_accelerator(self):
# Initialize the accelerator. We will let the accelerator handle device placement.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs], mixed_precision=self.mixed_precision, cpu=self.cpu)
if self.accelerator.is_local_main_process:
# set up Weights and Biases if requested
if self.wandb:
import wandb
wandb.init(project=str(self.results_dir).split("/")[-1])
if not exists(self.GAN):
self.init_GAN()
G = self.GAN.G
D = self.GAN.D
D_aug = self.GAN.D_aug
# discriminator loss fn
self.set_data_src()
# prepare
G, D, D_aug, self.GAN.D_opt, self.GAN.G_opt, self.loader = self.accelerator.prepare(G, D, D_aug, self.GAN.D_opt,
self.GAN.G_opt, self.loader)
self.loader = cycle(self.loader)
return G, D, D_aug
def train(self, G, D, D_aug):
        assert exists(self.loader), 'You must first initialize the data source by calling `.set_data_src()`'
self.GAN.train()
total_disc_loss = torch.zeros([], device=self.accelerator.device)
total_gen_loss = torch.zeros([], device=self.accelerator.device)
batch_size = math.ceil(self.batch_size / self.accelerator.num_processes)
image_size = self.GAN.image_size
latent_dim = self.GAN.latent_dim
aug_prob = default(self.aug_prob, 0)
aug_types = self.aug_types
aug_kwargs = {'prob': aug_prob, 'types': aug_types}
apply_gradient_penalty = self.steps % 4 == 0
# discriminator loss fn
if self.dual_contrast_loss:
D_loss_fn = dual_contrastive_loss
else:
D_loss_fn = hinge_loss
# train discriminator
self.GAN.D_opt.zero_grad()
for i in range(self.gradient_accumulate_every):
latents = torch.randn(batch_size, latent_dim, device=self.accelerator.device)
image_batch = next(self.loader)["image"]
image_batch.requires_grad_()
with torch.no_grad():
generated_images = G(latents)
fake_output, fake_output_32x32, _ = D_aug(generated_images, detach=True, **aug_kwargs)
real_output, real_output_32x32, real_aux_loss = D_aug(image_batch, calc_aux_loss=True, **aug_kwargs)
real_output_loss = real_output
fake_output_loss = fake_output
divergence = D_loss_fn(real_output_loss, fake_output_loss)
divergence_32x32 = D_loss_fn(real_output_32x32, fake_output_32x32)
disc_loss = divergence + divergence_32x32
aux_loss = real_aux_loss
disc_loss = disc_loss + aux_loss
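            # gradient penalty: push the norm of D's gradients w.r.t. the real
            # images toward 1, first undoing any AMP loss scaling so the
            # penalty sees unscaled gradients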
if apply_gradient_penalty:
outputs = [real_output, real_output_32x32]
if self.accelerator.scaler is not None:
outputs = list(map(self.accelerator.scaler.scale, outputs))
scaled_gradients = torch_grad(outputs=outputs, inputs=image_batch,
grad_outputs=list(
map(lambda t: torch.ones(t.size(), device=self.accelerator.device),
outputs)),
create_graph=True, retain_graph=True, only_inputs=True)[0]
inv_scale = 1.
if self.accelerator.scaler is not None:
inv_scale = safe_div(1., self.accelerator.scaler.get_scale())
if inv_scale != float('inf'):
gradients = scaled_gradients * inv_scale
gradients = gradients.reshape(batch_size, -1)
gp = self.gp_weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
if not torch.isnan(gp):
disc_loss = disc_loss + gp
self.last_gp_loss = gp.clone().detach().item()
# divide loss by gradient accumulation steps since gradients
# are accumulated for multiple backward passes in PyTorch
disc_loss = disc_loss / self.gradient_accumulate_every
disc_loss.register_hook(raise_if_nan)
self.accelerator.backward(disc_loss)
total_disc_loss += divergence
self.last_recon_loss = aux_loss.item()
self.d_loss = float(total_disc_loss.item() / self.gradient_accumulate_every)
self.GAN.D_opt.step()
# generator loss fn
if self.dual_contrast_loss:
G_loss_fn = dual_contrastive_loss
G_requires_calc_real = True
else:
G_loss_fn = gen_hinge_loss
G_requires_calc_real = False
# train generator
self.GAN.G_opt.zero_grad()
for i in range(self.gradient_accumulate_every):
latents = torch.randn(batch_size, latent_dim, device=self.accelerator.device)
if G_requires_calc_real:
image_batch = next(self.loader)["image"]
image_batch.requires_grad_()
generated_images = G(latents)
fake_output, fake_output_32x32, _ = D_aug(generated_images, **aug_kwargs)
real_output, real_output_32x32, _ = D_aug(image_batch, **aug_kwargs) if G_requires_calc_real else (
None, None, None)
loss = G_loss_fn(fake_output, real_output)
loss_32x32 = G_loss_fn(fake_output_32x32, real_output_32x32)
gen_loss = loss + loss_32x32
gen_loss = gen_loss / self.gradient_accumulate_every
gen_loss.register_hook(raise_if_nan)
self.accelerator.backward(gen_loss)
total_gen_loss += loss
# divide loss by gradient accumulation steps since gradients
# are accumulated for multiple backward passes in PyTorch
self.g_loss = float(total_gen_loss.item() / self.gradient_accumulate_every)
self.GAN.G_opt.step()
# calculate moving averages
if self.accelerator.is_main_process and self.steps % 10 == 0 and self.steps > 20000:
self.GAN.EMA()
if self.accelerator.is_main_process and self.steps <= 25000 and self.steps % 1000 == 2:
self.GAN.reset_parameter_averaging()
# save from NaN errors
        if any(torch.isnan(loss) for loss in (total_gen_loss, total_disc_loss)):
print(f'NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}')
self.load(self.checkpoint_num)
raise NanException
del total_disc_loss
del total_gen_loss
# periodically save results
if self.accelerator.is_main_process:
if self.steps % self.save_every == 0:
self.save(self.checkpoint_num)
if self.push_to_hub:
with tempfile.TemporaryDirectory() as temp_dir:
self.GAN.push_to_hub(temp_dir, self.repo_url, config=self.GAN.config, skip_lfs_files=True)
if self.steps % self.evaluate_every == 0 or (self.steps % 100 == 0 and self.steps < 20000):
self.evaluate(floor(self.steps / self.evaluate_every), num_image_tiles=self.num_image_tiles)
if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
num_batches = math.ceil(self.calculate_fid_num_images / self.batch_size)
fid = self.calculate_fid(num_batches)
self.last_fid = fid
                with open(str(self.results_dir / self.name / 'fid_scores.txt'), 'a') as f:
f.write(f'{self.steps},{fid}\n')
self.steps += 1
@torch.no_grad()
def evaluate(self, num=0, num_image_tiles=4):
self.GAN.eval()
ext = self.image_extension
num_rows = num_image_tiles
latent_dim = self.GAN.latent_dim
image_size = self.GAN.image_size
# latents and noise
latents = torch.randn(num_rows ** 2, latent_dim, device=self.accelerator.device)
# regular
generated_images = self.generate_(self.GAN.G, latents)
file_name = str(self.results_dir / self.name / f'{str(num)}.{ext}')
save_image(generated_images, file_name, nrow=num_rows)
# moving averages
generated_images = self.generate_(self.GAN.GE.to(self.accelerator.device), latents)
file_name_ema = str(self.results_dir / self.name / f'{str(num)}-ema.{ext}')
save_image(generated_images, file_name_ema, nrow=num_rows)
if self.accelerator.is_local_main_process and self.wandb:
import wandb
wandb.log({'generated_examples': wandb.Image(str(file_name))})
wandb.log({'generated_examples_ema': wandb.Image(str(file_name_ema))})
@torch.no_grad()
def generate(self, num=0, num_image_tiles=4, checkpoint=None, types=['default', 'ema']):
self.GAN.eval()
latent_dim = self.GAN.latent_dim
        dir_name = self.name + '-generated-' + str(checkpoint)
dir_full = Path().absolute() / self.results_dir / dir_name
ext = self.image_extension
if not dir_full.exists():
os.mkdir(dir_full)
# regular
if 'default' in types:
for i in tqdm(range(num_image_tiles), desc='Saving generated default images'):
latents = torch.randn(1, latent_dim, device=self.accelerator.device)
generated_image = self.generate_(self.GAN.G, latents)
path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}.{ext}')
save_image(generated_image[0], path, nrow=1)
# moving averages
if 'ema' in types:
for i in tqdm(range(num_image_tiles), desc='Saving generated EMA images'):
latents = torch.randn(1, latent_dim, device=self.accelerator.device)
generated_image = self.generate_(self.GAN.GE, latents)
path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}-ema.{ext}')
save_image(generated_image[0], path, nrow=1)
return dir_full
@torch.no_grad()
def show_progress(self, num_images=4, types=['default', 'ema']):
checkpoints = self.get_checkpoints()
assert exists(checkpoints), 'cannot find any checkpoints to create a training progress video for'
        dir_name = self.name + '-progress'
dir_full = Path().absolute() / self.results_dir / dir_name
ext = self.image_extension
latents = None
zfill_length = math.ceil(math.log10(len(checkpoints)))
if not dir_full.exists():
os.mkdir(dir_full)
for checkpoint in tqdm(checkpoints, desc='Generating progress images'):
            self.load(checkpoint)
self.GAN.eval()
if checkpoint == 0:
                latents = torch.randn(num_images, self.GAN.latent_dim, device=self.accelerator.device)
# regular
if 'default' in types:
generated_image = self.generate_(self.GAN.G, latents)
path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}.{ext}')
save_image(generated_image, path, nrow=num_images)
# moving averages
if 'ema' in types:
generated_image = self.generate_(self.GAN.GE, latents)
path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}-ema.{ext}')
save_image(generated_image, path, nrow=num_images)
@torch.no_grad()
def calculate_fid(self, num_batches):
from pytorch_fid import fid_score
real_path = self.fid_dir / 'real'
fake_path = self.fid_dir / 'fake'
# remove any existing files used for fid calculation and recreate directories
if not real_path.exists() or self.clear_fid_cache:
rmtree(real_path, ignore_errors=True)
os.makedirs(real_path)
for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'):
real_batch = next(self.loader)["image"]
for k, image in enumerate(real_batch.unbind(0)):
ind = k + batch_num * self.batch_size
save_image(image, real_path / f'{ind}.png')
# generate a bunch of fake images in results / name / fid_fake
rmtree(fake_path, ignore_errors=True)
os.makedirs(fake_path)
self.GAN.eval()
ext = self.image_extension
latent_dim = self.GAN.latent_dim
image_size = self.GAN.image_size
for batch_num in tqdm(range(num_batches), desc='calculating FID - saving generated'):
# latents and noise
latents = torch.randn(self.batch_size, latent_dim, device=self.accelerator.device)
# moving averages
generated_images = self.generate_(self.GAN.GE, latents)
for j, image in enumerate(generated_images.unbind(0)):
ind = j + batch_num * self.batch_size
save_image(image, str(fake_path / f'{str(ind)}-ema.{ext}'))
return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, latents.device, 2048)
@torch.no_grad()
def generate_(self, G, style, num_image_tiles=8):
generated_images = evaluate_in_chunks(self.batch_size, G, style)
return generated_images.clamp_(0., 1.)
@torch.no_grad()
def generate_interpolation(self, num=0, num_image_tiles=8, num_steps=100, save_frames=False):
self.GAN.eval()
ext = self.image_extension
num_rows = num_image_tiles
latent_dim = self.GAN.latent_dim
image_size = self.GAN.image_size
# latents and noise
latents_low = torch.randn(num_rows ** 2, latent_dim, device=self.accelerator.device)
latents_high = torch.randn(num_rows ** 2, latent_dim, device=self.accelerator.device)
ratios = torch.linspace(0., 8., num_steps)
frames = []
for ratio in tqdm(ratios):
interp_latents = slerp(ratio, latents_low, latents_high)
generated_images = self.generate_(self.GAN.GE, interp_latents)
images_grid = torchvision.utils.make_grid(generated_images, nrow=num_rows)
pil_image = transforms.ToPILImage()(images_grid.cpu())
if self.transparent:
background = Image.new('RGBA', pil_image.size, (255, 255, 255))
pil_image = Image.alpha_composite(background, pil_image)
frames.append(pil_image)
frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:],
duration=80, loop=0, optimize=True)
if save_frames:
folder_path = (self.results_dir / self.name / f'{str(num)}')
folder_path.mkdir(parents=True, exist_ok=True)
for ind, frame in enumerate(frames):
frame.save(str(folder_path / f'{str(ind)}.{ext}'))
def print_log(self):
data = [
('G', self.g_loss),
('D', self.d_loss),
('GP', self.last_gp_loss),
('SS', self.last_recon_loss),
('FID', self.last_fid)
]
data = [d for d in data if exists(d[1])]
log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data))
print(log)
if self.accelerator.is_local_main_process:
log_dict = {v[0]: v[1] for v in data}
if self.wandb:
import wandb
wandb.log(log_dict)
def model_name(self, num):
return str(self.models_dir / self.name / f'model_{num}.pt')
def init_folders(self):
(self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
(self.models_dir / self.name).mkdir(parents=True, exist_ok=True)
def clear(self):
rmtree(str(self.models_dir / self.name), True)
rmtree(str(self.results_dir / self.name), True)
rmtree(str(self.fid_dir), True)
        self.config_path.unlink(missing_ok=True)  # config_path is a file, not a directory
self.init_folders()
def save(self, num):
save_data = {
'GAN': self.GAN.state_dict(),
}
torch.save(save_data, self.model_name(num))
self.write_config()
def load(self, num=-1):
self.load_config()
name = num
if num == -1:
checkpoints = self.get_checkpoints()
if not exists(checkpoints):
return
name = checkpoints[-1]
print(f'continuing from previous epoch - {name}')
self.steps = name * self.save_every
load_data = torch.load(self.model_name(name))
try:
self.GAN.load_state_dict(load_data['GAN'])
except Exception as e:
print(
'unable to load save model. please try downgrading the package to the version specified by the saved model')
raise e
def get_checkpoints(self):
        file_paths = list((self.models_dir / self.name).glob('model_*.pt'))
saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths))
if len(saved_nums) == 0:
return None
return saved_nums
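# End-to-end usage sketch (argument values are illustrative):
#   trainer = Trainer(dataset_name='huggan/CelebA-faces', image_size=128)
#   G, D, D_aug = trainer.init_accelerator()  # builds GAN, dataloader, optimizers
#   while trainer.steps < 100_000:
#       trainer.train(G, D, D_aug)
#       if trainer.steps % 50 == 0:
#           trainer.print_log()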