import jax
import jax.numpy as jnp
from jaxlib.xla_extension import DeviceArray
import flax
from flax.optim import dynamic_scale as dynamic_scale_lib
from flax.core import frozen_dict
from flax.training import train_state
from flax import struct
import numpy as np
from PIL import Image
from urllib.request import Request, urlopen
import urllib.error
from typing import Any, Callable


def sync_moving_stats(state):
    """
    Sync moving statistics across devices.

    Args:
        state (train_state.TrainState): Training state.

    Returns:
        (train_state.TrainState): Updated training state.
    """
    cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x')
    return state.replace(moving_stats=cross_replica_mean(state.moving_stats))


def update_generator_ema(state_G, params_ema_G, config, ema_beta=None):
    """
    Update the exponentially moving average of the generator weights.
    Moving stats and noise constants are copied over unchanged.

    Args:
        state_G (train_state.TrainState): Generator state.
        params_ema_G (frozen_dict.FrozenDict): Parameters of the EMA generator.
        config (Any): Config object.
        ema_beta (float): Beta parameter of the EMA. If None, it is computed
                          from 'ema_kimg' and 'batch_size'.

    Returns:
        (frozen_dict.FrozenDict): Updated parameters of the EMA generator.
    """
    def _update_ema(src, trg, beta):
        for name, src_child in src.items():
            if isinstance(src_child, DeviceArray):
                trg[name] = src[name] + beta * (trg[name] - src[name])
            else:
                _update_ema(src_child, trg[name], beta)

    if ema_beta is None:
        ema_nimg = config.ema_kimg * 1000
        ema_beta = 0.5 ** (config.batch_size / max(ema_nimg, 1e-8))

    params_ema_G = params_ema_G.unfreeze()

    # Copy over the moving stats and noise constants.
    params_ema_G['moving_stats']['mapping_network'] = state_G.moving_stats
    params_ema_G['noise_consts']['synthesis_network'] = state_G.noise_consts

    # Update the exponentially moving average of the trainable parameters.
    _update_ema(state_G.params['mapping'], params_ema_G['params']['mapping_network'], ema_beta)
    _update_ema(state_G.params['synthesis'], params_ema_G['params']['synthesis_network'], ema_beta)

    params_ema_G = frozen_dict.freeze(params_ema_G)
    return params_ema_G


class TrainStateG(train_state.TrainState):
    """
    Generator train state for a single Optax optimizer.

    Attributes:
        apply_mapping (Callable): Apply function of the Mapping Network.
        apply_synthesis (Callable): Apply function of the Synthesis Network.
        dynamic_scale_main (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for mixed-precision gradients of the main loss.
        dynamic_scale_reg (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for mixed-precision gradients of the regularization loss.
        epoch (int): Current epoch.
        moving_stats (Any): Moving average of the latent W.
        noise_consts (Any): Noise constants from the synthesis layers.
    """
    apply_mapping: Callable = struct.field(pytree_node=False)
    apply_synthesis: Callable = struct.field(pytree_node=False)
    dynamic_scale_main: dynamic_scale_lib.DynamicScale
    dynamic_scale_reg: dynamic_scale_lib.DynamicScale
    epoch: int
    moving_stats: Any = None
    noise_consts: Any = None


class TrainStateD(train_state.TrainState):
    """
    Discriminator train state for a single Optax optimizer.

    Attributes:
        dynamic_scale_main (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for mixed-precision gradients of the main loss.
        dynamic_scale_reg (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for mixed-precision gradients of the regularization loss.
        epoch (int): Current epoch.
    """
    dynamic_scale_main: dynamic_scale_lib.DynamicScale
    dynamic_scale_reg: dynamic_scale_lib.DynamicScale
    epoch: int
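# Illustrative sketch only (not part of the original pipeline): one way a
# TrainStateG could be assembled with an Optax optimizer. `mapping_net`,
# `synthesis_net`, `params`, and `config` are hypothetical placeholders for
# objects created elsewhere in the training code.
#
#   import optax
#
#   tx = optax.adam(learning_rate=config.learning_rate, b1=0.0, b2=0.99)
#   state_G = TrainStateG.create(apply_fn=None,
#                                apply_mapping=mapping_net.apply,
#                                apply_synthesis=synthesis_net.apply,
#                                params=params,
#                                tx=tx,
#                                dynamic_scale_main=dynamic_scale_lib.DynamicScale(),
#                                dynamic_scale_reg=dynamic_scale_lib.DynamicScale(),
#                                epoch=0)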
def get_training_snapshot(image_real, image_gen, max_num=10):
    """
    Creates a snapshot of generated images and real images.

    Args:
        image_real (DeviceArray): Batch of real images, shape [B, H, W, C].
        image_gen (DeviceArray): Batch of generated images, shape [B, H, W, C].
        max_num (int): Maximum number of images used for the snapshot.

    Returns:
        (PIL.Image): Training snapshot. Top row: generated images,
                     bottom row: real images.
    """
    if image_real.shape[0] > max_num:
        image_real = image_real[:max_num]
    if image_gen.shape[0] > max_num:
        image_gen = image_gen[:max_num]

    # Tile each batch into a single row of images.
    image_real = jnp.split(image_real, image_real.shape[0], axis=0)
    image_gen = jnp.split(image_gen, image_gen.shape[0], axis=0)

    image_real = [jnp.squeeze(x, axis=0) for x in image_real]
    image_gen = [jnp.squeeze(x, axis=0) for x in image_gen]

    image_real = jnp.concatenate(image_real, axis=1)
    image_gen = jnp.concatenate(image_gen, axis=1)

    # Normalize each row to [0, 1] before converting to uint8.
    image_gen = (image_gen - np.min(image_gen)) / (np.max(image_gen) - np.min(image_gen))
    image_real = (image_real - np.min(image_real)) / (np.max(image_real) - np.min(image_real))

    image = jnp.concatenate((image_gen, image_real), axis=0)
    image = np.uint8(image * 255)
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    return Image.fromarray(image)


def get_eval_snapshot(image, max_num=10):
    """
    Creates a snapshot of generated images.

    Args:
        image (DeviceArray): Generated images, shape [B, H, W, C].
        max_num (int): Maximum number of images used for the snapshot.

    Returns:
        (PIL.Image): Eval snapshot.
    """
    if image.shape[0] > max_num:
        image = image[:max_num]

    # Tile the batch into a single row of images.
    image = jnp.split(image, image.shape[0], axis=0)
    image = [jnp.squeeze(x, axis=0) for x in image]
    image = jnp.concatenate(image, axis=1)

    # Normalize to [0, 1] before converting to uint8.
    image = (image - np.min(image)) / (np.max(image) - np.min(image))
    image = np.uint8(image * 255)
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    return Image.fromarray(image)


def get_vm_name():
    """
    Queries the GCP metadata server for the 'instance-id' attribute of the VM.

    Returns:
        (str or None): The instance id, or None if the metadata server is not
                       reachable (i.e. not running on a GCP VM).
    """
    gcp_metadata_url = "http://metadata.google.internal/computeMetadata/v1/instance/attributes/instance-id"
    req = Request(gcp_metadata_url)
    req.add_header('Metadata-Flavor', 'Google')
    instance_id = None
    try:
        with urlopen(req) as url:
            instance_id = url.read().decode()
    except urllib.error.URLError:
        # metadata.google.internal is not reachable: assume we are not running on GCP.
        pass
    return instance_id
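# Illustrative usage sketch: turning a batch of real and generated images into
# a snapshot image on disk. `image_real` and `image_fake` are assumed to be
# [B, H, W, C] arrays produced elsewhere; the file name is arbitrary.
#
#   snapshot = get_training_snapshot(image_real, image_fake, max_num=8)
#   snapshot.save('training_snapshot.png')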