import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import flax.linen as nn
from typing import Any, Tuple, List, Callable
import h5py
from . import ops
from stylegan2 import utils


URLS = {'afhqcat': 'https://www.dropbox.com/s/qygbjkefyqyu9k9/stylegan2_discriminator_afhqcat.h5?dl=1',
        'afhqdog': 'https://www.dropbox.com/s/kmoxbp33qswz64p/stylegan2_discriminator_afhqdog.h5?dl=1',
        'afhqwild': 'https://www.dropbox.com/s/jz1hpsyt3isj6e7/stylegan2_discriminator_afhqwild.h5?dl=1',
        'brecahad': 'https://www.dropbox.com/s/h0cb89hruo6pmyj/stylegan2_discriminator_brecahad.h5?dl=1',
        'car': 'https://www.dropbox.com/s/2ghjrmxih7cic76/stylegan2_discriminator_car.h5?dl=1',
        'cat': 'https://www.dropbox.com/s/zfhjsvlsny5qixd/stylegan2_discriminator_cat.h5?dl=1',
        'church': 'https://www.dropbox.com/s/jlno7zeivkjtk8g/stylegan2_discriminator_church.h5?dl=1',
        'cifar10': 'https://www.dropbox.com/s/eldpubfkl4c6rur/stylegan2_discriminator_cifar10.h5?dl=1',
        'ffhq': 'https://www.dropbox.com/s/m42qy9951b7lq1s/stylegan2_discriminator_ffhq.h5?dl=1',
        'horse': 'https://www.dropbox.com/s/19f5pxrcdh2g8cw/stylegan2_discriminator_horse.h5?dl=1',
        'metfaces': 'https://www.dropbox.com/s/xnokaunql12glkd/stylegan2_discriminator_metfaces.h5?dl=1'}

RESOLUTION = {'metfaces': 1024,
              'ffhq': 1024,
              'church': 256,
              'cat': 256,
              'horse': 256,
              'car': 512,
              'brecahad': 512,
              'afhqwild': 512,
              'afhqdog': 512,
              'afhqcat': 512,
              'cifar10': 32}

C_DIM = {'metfaces': 0,
         'ffhq': 0,
         'church': 0,
         'cat': 0,
         'horse': 0,
         'car': 0,
         'brecahad': 0,
         'afhqwild': 0,
         'afhqdog': 0,
         'afhqcat': 0,
         'cifar10': 10}

ARCHITECTURE = {'metfaces': 'resnet',
                'ffhq': 'resnet',
                'church': 'resnet',
                'cat': 'resnet',
                'horse': 'resnet',
                'car': 'resnet',
                'brecahad': 'resnet',
                'afhqwild': 'resnet',
                'afhqdog': 'resnet',
                'afhqcat': 'resnet',
                'cifar10': 'orig'}

MBSTD_GROUP_SIZE = {'metfaces': None,
                    'ffhq': None,
                    'church': None,
                    'cat': None,
                    'horse': None,
                    'car': None,
                    'brecahad': None,
                    'afhqwild': None,
                    'afhqdog': None,
                    'afhqcat': None,
                    'cifar10': 32}


class FromRGBLayer(nn.Module):
    """
    From RGB Layer: projects the input image to 'fmaps' feature maps.

    Attributes:
        fmaps (int): Number of output channels of the convolution.
        kernel (int): Kernel size of the convolution.
        lr_multiplier (float): Learning rate multiplier.
        activation (str): Activation function: 'relu', 'leaky_relu', etc.
        param_dict (h5py.Group): Parameter dict with pretrained parameters, None if not using pretrained weights.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    fmaps: int
    kernel: int=1
    lr_multiplier: float=1
    activation: str='leaky_relu'
    param_dict: h5py.Group=None
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    @nn.compact
    def __call__(self, x, y):
        """
        Run From RGB Layer.

        Args:
            x (tensor): Input image of shape [N, H, W, num_channels].
            y (tensor): Optional tensor of shape [N, H, W, fmaps] that is added to the output, None = nothing added.

        Returns:
            (tensor): Output tensor of shape [N, H, W, fmaps].
        """
        w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
        w, b = ops.get_weight(w_shape, self.lr_multiplier, True, self.param_dict, 'fromrgb', self.rng)

        w = self.param(name='weight', init_fn=lambda *_ : w)
        b = self.param(name='bias', init_fn=lambda *_ : b)
        w = ops.equalize_lr_weight(w, self.lr_multiplier)
        b = ops.equalize_lr_bias(b, self.lr_multiplier)

        x = x.astype(self.dtype)
        x = ops.conv2d(x, w.astype(x.dtype))
        x += b.astype(x.dtype)
        x = ops.apply_activation(x, activation=self.activation)
        if self.clip_conv is not None:
            x = jnp.clip(x, -self.clip_conv, self.clip_conv)
        if y is not None:
            x += y
        return x


class DiscriminatorLayer(nn.Module):
    """
    Discriminator Layer: a (optionally downsampling) convolution followed by bias, activation and optional clipping.

    Attributes:
        fmaps (int): Number of output channels of the convolution.
        kernel (int): Kernel size of the convolution.
        use_bias (bool): If True, use bias.
        down (bool): If True, downsample the spatial resolution.
        resample_kernel (Tuple): Kernel that is used for FIR filter.
        activation (str): Activation function: 'relu', 'leaky_relu', etc.
        layer_name (str): Layer name, passed to ops.get_weight together with 'param_dict'.
        param_dict (h5py.Group): Parameter dict with pretrained parameters, None if not using pretrained weights.
        lr_multiplier (float): Learning rate multiplier.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    fmaps: int
    kernel: int=3
    use_bias: bool=True
    down: bool=False
    resample_kernel: Tuple=None
    activation: str='leaky_relu'
    layer_name: str=None
    param_dict: h5py.Group=None
    lr_multiplier: float=1
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    @nn.compact
    def __call__(self, x):
        """
        Run Discriminator Layer.

        Args:
            x (tensor): Input tensor of shape [N, H, W, C].

        Returns:
            (tensor): Output tensor of shape [N, H, W, fmaps] (spatially downsampled if 'down' is True).
        """
        w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
        if self.use_bias:
            w, b = ops.get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)
        else:
            w = ops.get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)

        w = self.param(name='weight', init_fn=lambda *_ : w)
        w = ops.equalize_lr_weight(w, self.lr_multiplier)

        if self.use_bias:
            b = self.param(name='bias', init_fn=lambda *_ : b)
            b = ops.equalize_lr_bias(b, self.lr_multiplier)

        x = x.astype(self.dtype)
        # Cast the weight to the activation dtype (as FromRGBLayer does) so that
        # reduced-precision layers do not mix float16 activations with float32 weights.
        x = ops.conv2d(x, w.astype(x.dtype), down=self.down, resample_kernel=self.resample_kernel)

        if self.use_bias:
            x += b.astype(x.dtype)
        x = ops.apply_activation(x, activation=self.activation)
        if self.clip_conv is not None:
            x = jnp.clip(x, -self.clip_conv, self.clip_conv)
        return x


class DiscriminatorBlock(nn.Module):
    """
    Discriminator Block: two convolutions (the second one downsampling), plus a
    downsampling 1x1 skip connection when architecture == 'resnet'.

    Attributes:
        res (int): Log2 of the block's input resolution. The block outputs nf(res - 2) feature maps.
        kernel (int): Kernel size of the convolution.
        resample_kernel (Tuple): Kernel that is used for FIR filter.
        activation (str): Activation function: 'relu', 'leaky_relu', etc.
        param_dict (h5py.Group): Parameter dict with pretrained parameters, None if not using pretrained weights.
        lr_multiplier (float): Learning rate multiplier.
        architecture (str): Architecture: 'orig', 'resnet'.
        nf (Callable): Callable that returns the number of feature maps for a given layer.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): Random seed for initialization.
    """
    res: int
    kernel: int=3
    resample_kernel: Tuple=(1, 3, 3, 1)
    activation: str='leaky_relu'
    param_dict: Any=None
    lr_multiplier: float=1
    architecture: str='resnet'
    nf: Callable=None
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    @nn.compact
    def __call__(self, x):
        """
        Run Discriminator Block.

        Args:
            x (tensor): Input tensor of shape [N, H, W, C].

        Returns:
            (tensor): Output tensor of shape [N, H', W', nf(res - 2)], where H', W' are the downsampled sizes.
        """
        init_rng = self.rng
        x = x.astype(self.dtype)
        residual = x
        for i in range(2):
            init_rng, init_key = random.split(init_rng)
            # conv0 keeps the resolution; conv1 (i == 1) downsamples with the FIR kernel.
            x = DiscriminatorLayer(fmaps=self.nf(self.res - (i + 1)),
                                   kernel=self.kernel,
                                   down=i == 1,
                                   resample_kernel=self.resample_kernel if i == 1 else None,
                                   activation=self.activation,
                                   layer_name=f'conv{i}',
                                   param_dict=self.param_dict,
                                   lr_multiplier=self.lr_multiplier,
                                   clip_conv=self.clip_conv,
                                   dtype=self.dtype,
                                   rng=init_key)(x)

        if self.architecture == 'resnet':
            init_rng, init_key = random.split(init_rng)
            # Linear, bias-free, unclipped 1x1 projection of the block input.
            residual = DiscriminatorLayer(fmaps=self.nf(self.res - 2),
                                          kernel=1,
                                          use_bias=False,
                                          down=True,
                                          resample_kernel=self.resample_kernel,
                                          activation='linear',
                                          layer_name='skip',
                                          param_dict=self.param_dict,
                                          lr_multiplier=self.lr_multiplier,
                                          dtype=self.dtype,
                                          rng=init_key)(residual)

            # Scale the sum by sqrt(0.5) to keep the output variance comparable to the inputs'.
            x = (x + residual) * np.sqrt(0.5, dtype=x.dtype)
        return x


class Discriminator(nn.Module):
    """
    Discriminator.

    Attributes:
        resolution (int): Input resolution. Overridden when 'pretrained' is set.
        num_channels (int): Number of input color channels. Not read by this module; the channel count is taken from the input's last axis.
        c_dim (int): Dimensionality of the labels (c), 0 if no labels. Overridden when 'pretrained' is set.
        fmap_base (int): Overall multiplier for the number of feature maps.
        fmap_decay (int): Log2 feature map reduction when doubling the resolution.
        fmap_min (int): Minimum number of feature maps in any layer.
        fmap_max (int): Maximum number of feature maps in any layer.
        mapping_layers (int): Number of additional mapping layers for the conditioning labels.
        mapping_fmaps (int): Number of activations in the mapping layers, None = default (nf(0)).
        mapping_lr_multiplier (float): Learning rate multiplier for the mapping layers.
        architecture (str): Architecture: 'orig', 'resnet'. Overridden when 'pretrained' is set.
        activation (str): Activation function: 'relu', 'leaky_relu', etc.
        mbstd_group_size (int): Group size for the minibatch standard deviation layer, None = entire minibatch.
                                Overridden when 'pretrained' is set.
        mbstd_num_features (int): Number of features for the minibatch standard deviation layer, 0 = disable.
        resample_kernel (Tuple): Low-pass filter to apply when resampling activations, None = box filter.
        num_fp16_res (int): Use float16 for the 'num_fp16_res' highest resolutions.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        pretrained (str): Use pretrained model, None for random initialization.
        ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
        dtype (str): Data type.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    # Input dimensions.
    resolution: int=1024
    num_channels: int=3
    c_dim: int=0

    # Capacity.
    fmap_base: int=16384
    fmap_decay: int=1
    fmap_min: int=1
    fmap_max: int=512

    # Internal details.
    mapping_layers: int=0
    mapping_fmaps: int=None
    mapping_lr_multiplier: float=0.1
    architecture: str='resnet'
    activation: str='leaky_relu'
    mbstd_group_size: int=None
    mbstd_num_features: int=1
    resample_kernel: Tuple=(1, 3, 3, 1)
    num_fp16_res: int=0
    clip_conv: float=None

    # Pretraining
    pretrained: str=None
    ckpt_dir: str=None

    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    def setup(self):
        # Effective configuration; the trailing-underscore copies may be replaced by the pretrained model's settings.
        self.resolution_ = self.resolution
        self.c_dim_ = self.c_dim
        self.architecture_ = self.architecture
        self.mbstd_group_size_ = self.mbstd_group_size
        self.param_dict = None
        if self.pretrained is not None:
            assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
            # The HDF5 file is not closed here: the layers built in __call__ still read their weights from this group.
            self.param_dict = h5py.File(ckpt_file, 'r')['discriminator']
            self.resolution_ = RESOLUTION[self.pretrained]
            self.architecture_ = ARCHITECTURE[self.pretrained]
            self.mbstd_group_size_ = MBSTD_GROUP_SIZE[self.pretrained]
            self.c_dim_ = C_DIM[self.pretrained]
        # Validate the architecture that is actually used, which may come from the pretrained config.
        assert self.architecture_ in ['orig', 'resnet']

    @nn.compact
    def __call__(self, x, c=None):
        """
        Run Discriminator.

        Args:
            x (tensor): Input image of shape [N, H, W, num_channels].
            c (tensor): Input labels, shape [N, c_dim]. Required when the effective c_dim > 0.

        Returns:
            (tensor): Output tensor of shape [N, 1].
        """
        resolution_log2 = int(np.log2(self.resolution_))
        assert self.resolution_ == 2**resolution_log2 and self.resolution_ >= 4

        def nf(stage):
            # Number of feature maps for a given stage, clipped to [fmap_min, fmap_max].
            return np.clip(int(self.fmap_base / (2.0 ** (stage * self.fmap_decay))), self.fmap_min, self.fmap_max)

        def sub_params(name):
            # Pretrained parameter group 'name', or None when not using pretrained weights.
            return self.param_dict[name] if self.param_dict is not None else None

        def block_dtype(res):
            # Only the 'num_fp16_res' highest resolutions use self.dtype; all others run in float32.
            return self.dtype if res >= resolution_log2 + 1 - self.num_fp16_res else 'float32'

        mapping_fmaps = nf(0) if self.mapping_fmaps is None else self.mapping_fmaps

        init_rng = self.rng

        # Label embedding and mapping.
        if self.c_dim_ > 0:
            # Split before use so that no PRNG key is both consumed and split further.
            init_rng, init_key = random.split(init_rng)
            c = ops.LinearLayer(in_features=self.c_dim_,
                                out_features=mapping_fmaps,
                                lr_multiplier=self.mapping_lr_multiplier,
                                param_dict=self.param_dict,
                                layer_name='label_embedding',
                                dtype=self.dtype,
                                rng=init_key)(c)

            c = ops.normalize_2nd_moment(c)
            for i in range(self.mapping_layers):
                init_rng, init_key = random.split(init_rng)
                # The embedded labels have 'mapping_fmaps' features (not c_dim).
                c = ops.LinearLayer(in_features=mapping_fmaps,
                                    out_features=mapping_fmaps,
                                    lr_multiplier=self.mapping_lr_multiplier,
                                    param_dict=self.param_dict,
                                    layer_name=f'fc{i}',
                                    dtype=self.dtype,
                                    rng=init_key)(c)

        # Layers for >=8x8 resolutions.
        y = None
        for res in range(resolution_log2, 2, -1):
            res_str = f'block_{2**res}x{2**res}'
            if res == resolution_log2:
                init_rng, init_key = random.split(init_rng)
                x = FromRGBLayer(fmaps=nf(res - 1),
                                 kernel=1,
                                 activation=self.activation,
                                 param_dict=sub_params(res_str),
                                 clip_conv=self.clip_conv,
                                 dtype=block_dtype(res),
                                 rng=init_key)(x, y)

            init_rng, init_key = random.split(init_rng)
            x = DiscriminatorBlock(res=res,
                                   kernel=3,
                                   resample_kernel=self.resample_kernel,
                                   activation=self.activation,
                                   param_dict=sub_params(res_str),
                                   architecture=self.architecture_,
                                   nf=nf,
                                   clip_conv=self.clip_conv,
                                   dtype=block_dtype(res),
                                   rng=init_key)(x)

        # Layers for 4x4 resolution.
        dtype = jnp.float32
        x = x.astype(dtype)
        if self.mbstd_num_features > 0:
            x = ops.minibatch_stddev_layer(x, self.mbstd_group_size_, self.mbstd_num_features)
        init_rng, init_key = random.split(init_rng)
        x = DiscriminatorLayer(fmaps=nf(1),
                               kernel=3,
                               use_bias=True,
                               activation=self.activation,
                               layer_name='conv0',
                               param_dict=sub_params('block_4x4'),
                               clip_conv=self.clip_conv,
                               dtype=dtype,
                               rng=init_key)(x)

        # Switch to NCHW so that the pretrained weights still work after reshaping
        x = jnp.transpose(x, axes=(0, 3, 1, 2))
        x = jnp.reshape(x, (-1, x.shape[1] * x.shape[2] * x.shape[3]))

        init_rng, init_key = random.split(init_rng)
        x = ops.LinearLayer(in_features=x.shape[1],
                            out_features=nf(0),
                            activation=self.activation,
                            param_dict=sub_params('block_4x4'),
                            layer_name='fc0',
                            dtype=dtype,
                            rng=init_key)(x)

        # Output layer.
        init_rng, init_key = random.split(init_rng)
        x = ops.LinearLayer(in_features=x.shape[1],
                            out_features=1 if self.c_dim_ == 0 else mapping_fmaps,
                            param_dict=self.param_dict,
                            layer_name='output',
                            dtype=dtype,
                            rng=init_key)(x)

        if self.c_dim_ > 0:
            # Inner product with the mapped labels, scaled by 1/sqrt(mapping_fmaps), giving shape [N, 1].
            x = jnp.sum(x * c, axis=1, keepdims=True) / jnp.sqrt(mapping_fmaps)
        return x