# stylegan2-flax-tpu / training_utils.py
import jax
import jax.numpy as jnp
from jaxlib.xla_extension import DeviceArray
import flax
from flax.optim import dynamic_scale as dynamic_scale_lib
from flax.core import frozen_dict
from flax.training import train_state
from flax import struct
import numpy as np
from PIL import Image
from urllib.request import Request, urlopen
import urllib.error
from typing import Any, Callable
def sync_moving_stats(state):
    """
    Average the moving statistics over all devices.

    Args:
        state (train_state.TrainState): Training state holding replicated moving stats.

    Returns:
        (train_state.TrainState): Training state whose moving_stats are the
        cross-device mean.
    """
    pmean_all = jax.pmap(lambda stats: jax.lax.pmean(stats, 'x'), 'x')
    synced_stats = pmean_all(state.moving_stats)
    return state.replace(moving_stats=synced_stats)
def update_generator_ema(state_G, params_ema_G, config, ema_beta=None):
    """
    Update exponentially moving average of the generator weights.

    Moving stats and noise constants are copied over verbatim (not averaged);
    only the trainable parameters of the mapping and synthesis networks are
    interpolated.

    Args:
        state_G (train_state.TrainState): Generator state.
        params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator.
        config (Any): Config object providing 'ema_kimg' and 'batch_size'.
        ema_beta (float): Beta parameter of the ema. If None, will be computed
                          from 'ema_kimg' and 'batch_size'.

    Returns:
        (frozen_dict.FrozenDict): Updated parameters of the ema generator.
    """
    def _update_ema(src, trg, beta):
        # Recursively walk the nested param dicts in lock step and interpolate
        # every array leaf in place: trg = src + beta * (trg - src).
        for name, src_child in src.items():
            if isinstance(src_child, jnp.ndarray):
                # Fix: honor the 'beta' argument instead of silently closing
                # over the outer 'ema_beta' (the original helper ignored its
                # own parameter). jnp.ndarray replaces the deprecated
                # jaxlib.xla_extension.DeviceArray leaf check.
                trg[name] = src_child + beta * (trg[name] - src_child)
            else:
                _update_ema(src_child, trg[name], beta)
    if ema_beta is None:
        ema_nimg = config.ema_kimg * 1000
        # Half-life formulation: after ema_nimg images the old weights
        # contribute 50%; max() guards against division by zero.
        ema_beta = 0.5 ** (config.batch_size / max(ema_nimg, 1e-8))
    params_ema_G = params_ema_G.unfreeze()
    # Copy over moving stats and noise constants unchanged.
    params_ema_G['moving_stats']['mapping_network'] = state_G.moving_stats
    params_ema_G['noise_consts']['synthesis_network'] = state_G.noise_consts
    # Update exponentially moving average of the trainable parameters.
    _update_ema(state_G.params['mapping'], params_ema_G['params']['mapping_network'], ema_beta)
    _update_ema(state_G.params['synthesis'], params_ema_G['params']['synthesis_network'], ema_beta)
    params_ema_G = frozen_dict.freeze(params_ema_G)
    return params_ema_G
class TrainStateG(train_state.TrainState):
    """
    Generator train state for a single Optax optimizer.

    Attributes:
        apply_mapping (Callable): Apply function of the Mapping Network.
        apply_synthesis (Callable): Apply function of the Synthesis Network.
        dynamic_scale_main (dynamic_scale_lib.DynamicScale): Dynamic loss scaling
            for mixed-precision gradients of the main (adversarial) loss.
        dynamic_scale_reg (dynamic_scale_lib.DynamicScale): Dynamic loss scaling
            for mixed-precision gradients of the regularization loss.
        epoch (int): Current epoch.
        moving_stats (Any): Moving average of the latent W.
        noise_consts (Any): Noise constants from synthesis layers.
    """
    # pytree_node=False keeps the apply callables out of jax transformations:
    # they are static metadata, not traced/replicated arrays.
    apply_mapping: Callable = struct.field(pytree_node=False)
    apply_synthesis: Callable = struct.field(pytree_node=False)
    dynamic_scale_main: dynamic_scale_lib.DynamicScale
    dynamic_scale_reg: dynamic_scale_lib.DynamicScale
    epoch: int
    # Optional state; defaults keep construction backward-compatible.
    moving_stats: Any=None
    noise_consts: Any=None
class TrainStateD(train_state.TrainState):
    """
    Discriminator train state for a single Optax optimizer.

    Attributes:
        dynamic_scale_main (dynamic_scale_lib.DynamicScale): Dynamic loss scaling
            for mixed-precision gradients of the main (adversarial) loss.
        dynamic_scale_reg (dynamic_scale_lib.DynamicScale): Dynamic loss scaling
            for mixed-precision gradients of the regularization loss.
        epoch (int): Current epoch.
    """
    dynamic_scale_main: dynamic_scale_lib.DynamicScale
    dynamic_scale_reg: dynamic_scale_lib.DynamicScale
    epoch: int
def get_training_snapshot(image_real, image_gen, max_num=10):
    """
    Creates a snapshot of generated images and real images.

    Args:
        image_real (DeviceArray): Batch of real images, shape [B, H, W, C].
        image_gen (DeviceArray): Batch of generated images, shape [B, H, W, C].
        max_num (int): Maximum number of images used for snapshot.

    Returns:
        (PIL.Image): Training snapshot. Top row: generated images,
        bottom row: real images.
    """
    # Keep at most max_num examples from each batch.
    image_real = image_real[:max_num]
    image_gen = image_gen[:max_num]

    def _as_row(batch):
        # Lay the batch out as one horizontal strip: [B, H, W, C] -> [H, B*W, C].
        # Equivalent to splitting along the batch axis and concatenating
        # the squeezed images along the width axis.
        num, height, width, channels = batch.shape
        return jnp.transpose(batch, (1, 0, 2, 3)).reshape(height, num * width, channels)

    def _to_unit_range(img):
        # Rescale to [0, 1] using the strip's own min/max.
        lo, hi = np.min(img), np.max(img)
        return (img - lo) / (hi - lo)

    row_gen = _to_unit_range(_as_row(image_gen))
    row_real = _to_unit_range(_as_row(image_real))

    # Generated images on top, real images below.
    snapshot = jnp.concatenate((row_gen, row_real), axis=0)
    snapshot = np.uint8(snapshot * 255)
    if snapshot.shape[-1] == 1:
        # Grayscale -> RGB so PIL renders a consistent image mode.
        snapshot = np.repeat(snapshot, 3, axis=-1)
    return Image.fromarray(snapshot)
def get_eval_snapshot(image, max_num=10):
    """
    Creates a snapshot of generated images.

    Args:
        image (DeviceArray): Generated images, shape [B, H, W, C].
        max_num (int): Maximum number of images used for snapshot.

    Returns:
        (PIL.Image): Eval snapshot.
    """
    # Keep at most max_num examples, then lay them out as one horizontal
    # strip: [B, H, W, C] -> [H, B*W, C]. Equivalent to splitting along the
    # batch axis and concatenating the squeezed images along the width axis.
    image = image[:max_num]
    num, height, width, channels = image.shape
    strip = jnp.transpose(image, (1, 0, 2, 3)).reshape(height, num * width, channels)

    # Rescale to [0, 1] using the strip's own min/max, then to uint8.
    lo, hi = np.min(strip), np.max(strip)
    strip = (strip - lo) / (hi - lo)
    strip = np.uint8(strip * 255)
    if strip.shape[-1] == 1:
        # Grayscale -> RGB so PIL renders a consistent image mode.
        strip = np.repeat(strip, 3, axis=-1)
    return Image.fromarray(strip)
def get_vm_name(timeout=10):
    """
    Query the GCP metadata server for this VM's instance id.

    Args:
        timeout (float): Maximum number of seconds to wait for the metadata
                         server before giving up.

    Returns:
        (str | None): The instance id, or None when the metadata server is
        unreachable (e.g. when not running on GCP).
    """
    gcp_metadata_url = "http://metadata.google.internal/computeMetadata/v1/instance/attributes/instance-id"
    req = Request(gcp_metadata_url)
    # The GCP metadata server rejects requests lacking this header.
    req.add_header('Metadata-Flavor', 'Google')
    instance_id = None
    try:
        # Fix: the original call had no timeout and could hang indefinitely
        # on networks that black-hole the metadata address.
        with urlopen(req, timeout=timeout) as url:
            instance_id = url.read().decode()
    except OSError:
        # Covers urllib.error.URLError and socket timeouts (both OSError
        # subclasses): metadata.google.internal not reachable -> best effort,
        # return None.
        pass
    return instance_id