"""DiT (Diffusion Transformer) model components implemented with Flax linen.

The module defines the embedders, the adaLN transformer block, the final
projection layer, a label-embedding noising helper and the top-level ``DiT``
module that ties them together.
"""
import math
from typing import Any, Callable, Optional, Tuple, Type, Sequence, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from einops import rearrange
from jax._src import core
from jax._src import dtypes
from jax._src.nn.initializers import _compute_fans

# NOTE: ``modulate`` is imported here but re-defined further down in this
# module; the local definition is the one every caller below actually uses.
from math_utils import get_2d_sincos_pos_embed, modulate

Array = Any
PRNGKey = Any
Shape = Tuple[int]
Dtype = Any


def xavier_uniform_pytorchlike():
    """Build a Xavier/Glorot-uniform initializer using a PyTorch-style fan computation.

    Returns:
        An ``init(key, shape, dtype)`` callable that samples uniformly from
        ``[-a, a)`` with ``a = sqrt(6 / (fan_in + fan_out))``.

    Raises:
        ValueError: (from the returned initializer) if ``shape`` is neither
            2-D (dense kernel) nor 4-D (conv kernel).
    """

    def init(key, shape, dtype=jnp.float32):
        dtype = dtypes.canonicalize_dtype(dtype)
        if len(shape) == 2:
            # Dense kernel laid out as [in, out].
            fan_in, fan_out = shape
        elif len(shape) == 4:
            # Conv kernel laid out as [k, k, in, out]. Assumes a patch-embed
            # style conv, so fan_out counts only the output channels.
            fan_in = shape[0] * shape[1] * shape[2]
            fan_out = shape[3]
        else:
            raise ValueError(f"Invalid shape {shape}")
        variance = 2 / (fan_in + fan_out)
        # A uniform distribution on [-a, a) has variance a^2 / 3.
        scale = jnp.sqrt(3 * variance)
        return jax.random.uniform(key, shape, dtype, minval=-1) * scale

    return init


class TrainConfig:
    """Bundle of the compute dtype and the initializers shared by the layers below."""

    def __init__(self, dtype):
        self.dtype = dtype

    def kern_init(self, name='default', zero=False):
        """Return a zero initializer if ``zero`` is set or ``name`` contains 'bias'.

        Otherwise return the PyTorch-like Xavier-uniform initializer.
        """
        if zero or 'bias' in name:
            return nn.initializers.constant(0)
        return xavier_uniform_pytorchlike()

    def default_config(self):
        """Return the keyword arguments used for the 'standard' ``nn.Dense`` layers."""
        return {
            'kernel_init': self.kern_init(),
            'bias_init': self.kern_init('bias', zero=True),
            'dtype': self.dtype,
        }


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    hidden_size: int
    tc: TrainConfig
    frequency_embedding_size: int = 256

    @nn.compact
    def __call__(self, t):
        """Map a 1-D array of scalars to ``(N, hidden_size)`` through a two-layer MLP."""
        x = self.timestep_embedding(t)
        x = nn.Dense(self.hidden_size,
                     kernel_init=nn.initializers.normal(0.02),
                     bias_init=self.tc.kern_init('time_bias'),
                     dtype=self.tc.dtype)(x)
        x = nn.silu(x)
        # NOTE(review): unlike the first layer, this Dense is not given
        # ``dtype=self.tc.dtype`` -- confirm whether that is intentional.
        x = nn.Dense(self.hidden_size,
                     kernel_init=nn.initializers.normal(0.02),
                     bias_init=self.tc.kern_init('time_bias'))(x)
        return x

    # t is between [0, 1].
    def timestep_embedding(self, t, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        The output width is ``self.frequency_embedding_size``; the first half
        holds cosines and the second half sines.

        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        t = jax.lax.convert_element_type(t, jnp.float32)
        # t = t * max_period
        dim = self.frequency_embedding_size
        half = dim // 2
        freqs = jnp.exp(
            -math.log(max_period)
            * jnp.arange(start=0, stop=half, dtype=jnp.float32) / half)
        args = t[:, None] * freqs[None]
        embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
        embedding = embedding.astype(self.tc.dtype)
        return embedding


class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations.

    The table has ``num_classes + 1`` rows. This module only performs the
    lookup; it does not itself apply any label dropout.
    """
    num_classes: int
    hidden_size: int
    tc: TrainConfig

    @nn.compact
    def __call__(self, labels):
        embedding_table = nn.Embed(self.num_classes + 1, self.hidden_size,
                                   embedding_init=nn.initializers.normal(0.02),
                                   dtype=self.tc.dtype)
        embeddings = embedding_table(labels)
        return embeddings


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding """
    patch_size: int
    hidden_size: int
    tc: TrainConfig
    bias: bool = True

    @nn.compact
    def __call__(self, x):
        """Convert a ``(B, H, W, C)`` image into ``(B, num_patches, hidden_size)`` tokens."""
        _, H, W, _ = x.shape
        patch_tuple = (self.patch_size, self.patch_size)
        # Kernel size == stride with VALID padding: one output pixel per patch.
        x = nn.Conv(self.hidden_size, patch_tuple, patch_tuple,
                    use_bias=self.bias, padding="VALID",
                    kernel_init=self.tc.kern_init('patch'),
                    bias_init=self.tc.kern_init('patch_bias', zero=True),
                    dtype=self.tc.dtype)(x)  # (B, H/P, W/P, hidden_size)
        # The grid size is derived per axis so non-square inputs flatten
        # correctly; for square inputs this equals the previous behaviour.
        x = rearrange(x, 'b h w c -> b (h w) c',
                      h=H // self.patch_size, w=W // self.patch_size)
        return x


class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block."""
    mlp_dim: int
    tc: TrainConfig
    out_dim: Optional[int] = None
    # ``None`` is accepted for backward compatibility and means "no dropout".
    dropout_rate: Optional[float] = None
    train: bool = False

    @nn.compact
    def __call__(self, inputs):
        """It's just an MLP, so the input shape is (batch, len, emb)."""
        actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
        # nn.Dropout does arithmetic on ``rate`` when not deterministic, so a
        # ``None`` rate must be mapped to a number before it is used.
        rate = 0.0 if self.dropout_rate is None else self.dropout_rate
        deterministic = not self.train
        x = nn.Dense(features=self.mlp_dim, **self.tc.default_config())(inputs)
        x = nn.gelu(x)
        x = nn.Dropout(rate=rate, deterministic=deterministic)(x)
        output = nn.Dense(features=actual_out_dim, **self.tc.default_config())(x)
        output = nn.Dropout(rate=rate, deterministic=deterministic)(output)
        return output


def modulate(x, shift, scale):
    """Apply ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a new axis inserted at position 1 so that one
    value per batch row broadcasts across axis 1 of ``x``.
    """
    # scale = jnp.clip(scale, -1, 1)
    return x * (1 + scale[:, None]) + shift[:, None]


################################################################################
#                                 Core DiT Model                               #
################################################################################


class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm (adaLN) conditioning.

    NOTE(review): the modulation projection below uses the default (Xavier)
    kernel init rather than a zero init -- confirm that this is intended for
    an "adaLN-Zero" block.
    """
    hidden_size: int
    num_heads: int
    tc: TrainConfig
    mlp_ratio: float = 4.0
    dropout: float = 0.0
    train: bool = False

    # @functools.partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
    @nn.compact
    def __call__(self, x, c):
        """Run gated self-attention and a gated MLP on tokens ``x``, conditioned on ``c``."""
        # Calculate adaLN modulation parameters.
        c = nn.silu(c)
        c = nn.Dense(6 * self.hidden_size, **self.tc.default_config())(c)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(c, 6, axis=-1)

        # Attention Residual.
        x_norm = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
        x_modulated = modulate(x_norm, shift_msa, scale_msa)
        channels_per_head = self.hidden_size // self.num_heads
        # The k, q, v creation order fixes the auto-generated Flax parameter
        # names (Dense_1, Dense_2, Dense_3), so it must not be reordered.
        k = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
        q = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
        v = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
        k = jnp.reshape(k, (k.shape[0], k.shape[1], self.num_heads, channels_per_head))
        q = jnp.reshape(q, (q.shape[0], q.shape[1], self.num_heads, channels_per_head))
        v = jnp.reshape(v, (v.shape[0], v.shape[1], self.num_heads, channels_per_head))
        q = q / q.shape[3]  # (1/d) scaling.
        w = jnp.einsum('bqhc,bkhc->bhqk', q, k)  # [B, num_heads, Q, K]
        # Softmax is evaluated in float32 regardless of the compute dtype.
        w = w.astype(jnp.float32)
        w = nn.softmax(w, axis=-1)
        y = jnp.einsum('bhqk,bkhc->bqhc', w, v)  # [B, Q, num_heads, channels_per_head]
        y = jnp.reshape(y, x.shape)  # [B, N, C] (C = heads * channels_per_head)
        attn_x = nn.Dense(self.hidden_size, **self.tc.default_config())(y)
        x = x + (gate_msa[:, None] * attn_x)

        # MLP Residual.
        x_norm2 = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
        x_modulated2 = modulate(x_norm2, shift_mlp, scale_mlp)
        mlp_x = MlpBlock(mlp_dim=int(self.hidden_size * self.mlp_ratio), tc=self.tc,
                         dropout_rate=self.dropout, train=self.train)(x_modulated2)
        x = x + (gate_mlp[:, None] * mlp_x)
        return x


class FinalLayer(nn.Module):
    """
    The final layer of DiT.

    Both Dense layers are zero-initialized (kernel and bias).
    """
    patch_size: int
    out_channels: int
    hidden_size: int
    tc: TrainConfig

    @nn.compact
    def __call__(self, x, c):
        """Modulate normalized tokens with ``c`` and project to ``patch_size**2 * out_channels``."""
        c = nn.silu(c)
        c = nn.Dense(2 * self.hidden_size,
                     kernel_init=self.tc.kern_init(zero=True),
                     bias_init=self.tc.kern_init('bias', zero=True),
                     dtype=self.tc.dtype)(c)
        shift, scale = jnp.split(c, 2, axis=-1)
        x = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
        x = modulate(x, shift, scale)
        x = nn.Dense(self.patch_size * self.patch_size * self.out_channels,
                     kernel_init=self.tc.kern_init('final', zero=True),
                     bias_init=self.tc.kern_init('final_bias', zero=True),
                     dtype=self.tc.dtype)(x)
        return x


def apply_label_embedding_noise(key, label_embeddings, interpolation="linear"):
    """
    Applies Gaussian noise to label embeddings based on specified probabilities.

    Each sample independently receives one of three treatments, drawn with
    probabilities (0.00, 0.10, 0.90): full replacement by noise (alpha = 0),
    a blend with noise (alpha ~ U[0, 1)), or no change (alpha = 1).

    Args:
        key: A JAX random key.
        label_embeddings: A JAX array of shape (batch_size, embedding_dim),
            representing the label embeddings.
        interpolation: How the blend for the "partial noise" case is formed.
            "linear" (default, the previous behaviour):
                ``alpha * label + (1 - alpha) * noise``.
            "normalized": the same blend after L2-normalizing both the label
                embedding and the noise along the last axis.
            "slerp": spherical interpolation between the L2-normalized label
                embedding and noise, oriented so alpha = 1 returns the label
                direction (consistent with ``noise_levels``). The result is
                undefined if the two directions are exactly (anti-)parallel.

    Returns:
        A tuple containing:
            - noisy_label_embeddings: The label embeddings with noise applied.
            - noise_levels: A JAX array of shape (batch_size,), indicating the
              alpha value used for each sample (1.0 for no noise, 0.0 for 100%
              noise, or a uniform sample for partial noise).
            - key: The first of the four sub-keys split from the input key,
              for the caller to carry forward.

    Raises:
        ValueError: If ``interpolation`` is not one of the supported modes.
    """
    if interpolation not in ("linear", "normalized", "slerp"):
        raise ValueError(
            f"Unknown interpolation {interpolation!r}; "
            "expected 'linear', 'normalized' or 'slerp'")

    batch_size, _ = label_embeddings.shape

    # Split key for different random operations.
    key, noise_type_key, alpha_key, normal_key = jax.random.split(key, 4)

    # Determine noise application type for each sample:
    #   0: 100% noise (alpha = 0)
    #   1: Partial noise (alpha uniformly 0-1)
    #   2: No noise (do nothing)
    noise_type_choices = jax.random.choice(
        noise_type_key,
        a=jnp.array([0, 1, 2]),
        shape=(batch_size,),
        p=jnp.array([0.00, 0.10, 0.90])
    )

    # Initialize noise_levels to 1.0 (no noise).
    noise_levels = jnp.ones(batch_size, dtype=label_embeddings.dtype)

    # Generate alpha values for partial noise.
    sampled_alphas = jax.random.uniform(alpha_key, shape=(batch_size,), minval=0.0, maxval=1.0)

    # Generate Gaussian noise for the entire batch.
    # We assume a standard deviation of 1 for the noise, you might want to adjust this.
    gaussian_noise = jax.random.normal(normal_key, shape=label_embeddings.shape)

    # Apply 100% noise.
    cond_100_percent_noise = (noise_type_choices == 0)
    noisy_label_embeddings = jnp.where(
        cond_100_percent_noise[:, None],  # Expand dim for broadcasting
        gaussian_noise,
        label_embeddings
    )
    noise_levels = jnp.where(cond_100_percent_noise, 0.0, noise_levels)

    # Apply partial noise.
    cond_partial_noise = (noise_type_choices == 1)
    # Reshape sampled_alphas for broadcasting against (batch, dim).
    alpha_reshaped = sampled_alphas[:, None]

    if interpolation == "linear":
        new_noise = label_embeddings * alpha_reshaped + gaussian_noise * (1.0 - alpha_reshaped)
    else:
        unit_labels = label_embeddings / jnp.linalg.norm(label_embeddings, axis=-1, keepdims=True)
        unit_noise = gaussian_noise / jnp.linalg.norm(gaussian_noise, axis=-1, keepdims=True)
        if interpolation == "normalized":
            new_noise = unit_labels * alpha_reshaped + unit_noise * (1.0 - alpha_reshaped)
        else:  # "slerp"
            dot_product = jnp.sum(unit_labels * unit_noise, axis=-1, keepdims=True)
            # Rounding can push the cosine marginally outside [-1, 1], which
            # would make arccos return NaN.
            theta = jnp.arccos(jnp.clip(dot_product, -1.0, 1.0))
            sin_theta = jnp.sin(theta)
            new_noise = (jnp.sin(alpha_reshaped * theta) / sin_theta) * unit_labels + \
                        (jnp.sin((1.0 - alpha_reshaped) * theta) / sin_theta) * unit_noise

    noisy_label_embeddings = jnp.where(
        cond_partial_noise[:, None],
        new_noise,
        noisy_label_embeddings
    )
    noise_levels = jnp.where(cond_partial_noise, sampled_alphas, noise_levels)

    # For cond_no_noise (noise_type_choices == 2), noisy_label_embeddings remains
    # label_embeddings and noise_levels remains 1.0, so no specific action needed.
    return noisy_label_embeddings, noise_levels, key


class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    patch_size: int
    hidden_size: int
    depth: int
    num_heads: int
    mlp_ratio: float
    out_channels: int
    class_dropout_prob: float  # Not read anywhere in this module.
    num_classes: int
    ignore_dt: bool = False
    dropout: float = 0.0
    dtype: Dtype = jnp.bfloat16

    @nn.compact
    def __call__(self, x, t, dt, y, train=False, return_activations=False, perturbe = True):
        """Run the model.

        Args:
            x: (B, H, W, C) image; the output shape assertion below requires
                H == W.
            t: (B,) timesteps.
            dt: (B,) step sizes; replaced by zeros when ``ignore_dt`` is set.
            y: (B,) class labels.
            train: When true, the label embedding is noised and the sampled
                noise level is used as an extra conditioning signal. Also
                enables dropout in the blocks.
            return_activations: When true, also return the log-variances and
                a dict of intermediate activations.
            perturbe: Only used when ``train`` is false: true conditions on a
                noise level of ones ("no noise"), false on zeros ("full noise").

        Returns:
            ``x`` of shape (B, H, W, out_channels), or
            ``(x, logvars, activations)`` if ``return_activations`` is set.

        An RNG stream named "label" must always be supplied, since
        ``make_rng("label")`` is called unconditionally.
        """
        # (x = (B, H, W, C) image, t = (B,) timesteps, y = (B,) class labels)
        print("DiT: Input of shape", x.shape, "dtype", x.dtype)
        activations = {}
        key = self.make_rng("label")

        batch_size = x.shape[0]
        input_size = x.shape[1]
        num_patches_side = input_size // self.patch_size
        num_patches = num_patches_side ** 2
        tc = TrainConfig(dtype=self.dtype)

        if self.ignore_dt:
            dt = jnp.zeros_like(t)

        # pos_embed = self.param("pos_embed", get_2d_sincos_pos_embed, self.hidden_size, num_patches)
        # pos_embed = jax.lax.stop_gradient(pos_embed)
        pos_embed = get_2d_sincos_pos_embed(None, self.hidden_size, num_patches)
        x = PatchEmbed(self.patch_size, self.hidden_size, tc=tc)(x)  # (B, num_patches, hidden_size)
        print("DiT: After patch embed, shape is", x.shape, "dtype", x.dtype)
        activations['patch_embed'] = x
        x = x + pos_embed
        x = x.astype(self.dtype)

        te = TimestepEmbedder(self.hidden_size, tc=tc)(t)  # (B, hidden_size)
        dte = TimestepEmbedder(self.hidden_size, tc=tc)(dt)  # (B, hidden_size)
        ye = LabelEmbedder(self.num_classes, self.hidden_size, tc=tc)(y)  # (B, hidden_size)

        # CFG-free conditioning: instead of dropping labels, Gaussian noise is
        # blended into the label embedding during training, and the model is
        # additionally conditioned on how much of the label survived
        # (``condition_amount``: ones = clean label, zeros = pure noise).
        # At inference the label embedding itself is left untouched and only
        # ``condition_amount`` is switched via ``perturbe``.

        def adjust_condition_amount(train, perturbe, condition_amount):
            """Keep the sampled amount when training; otherwise use ones/zeros per ``perturbe``."""

            def all_ones(_):
                return jnp.ones_like(condition_amount)  # perturbe is True -> ones

            def all_zeros(_):
                return jnp.zeros_like(condition_amount)  # perturbe is False -> zeros

            def eval_branch(_):
                return jax.lax.cond(perturbe, all_ones, all_zeros, operand=None)

            def train_branch(_):
                return condition_amount  # leave it unchanged during training

            return jax.lax.cond(train, train_branch, eval_branch, operand=None)

        def apply_fn(key, ye, train):
            """Noise the label embedding when training; pass it through unchanged otherwise."""

            def true_branch(args):
                key, ye = args
                ye_new, condition_amount, key_new = apply_label_embedding_noise(key, ye)
                return ye_new.astype(jnp.float32), condition_amount, key_new

            def false_branch(args):
                # The noise is still computed so both branches return values
                # of identical structure, but the embedding itself is kept.
                key, ye = args
                ye_new, condition_amount, key_new = apply_label_embedding_noise(key, ye)
                return ye.astype(jnp.float32), condition_amount, key_new

            return jax.lax.cond(train, true_branch, false_branch, (key, ye))

        ye, condition_amount, key = apply_fn(key, ye, train)
        condition_amount = adjust_condition_amount(train, perturbe, condition_amount)

        # The noise level is embedded with the same sinusoidal + MLP embedder
        # used for the timesteps.
        ye_g = TimestepEmbedder(self.hidden_size, tc=tc)(condition_amount)

        # freeze embedding table.
        # ye = ye / ye.std()
        # ye = jax.lax.stop_gradient(ye)
        c = te + ye + dte + ye_g

        activations['pos_embed'] = pos_embed
        activations['time_embed'] = te
        activations['dt_embed'] = dte
        activations['label_embed'] = ye
        activations['conditioning'] = c

        print("DiT: Patch Embed of shape", x.shape, "dtype", x.dtype)
        print("DiT: Conditioning of shape", c.shape, "dtype", c.dtype)
        for i in range(self.depth):
            x = DiTBlock(self.hidden_size, self.num_heads, tc,
                         self.mlp_ratio, self.dropout, train)(x, c)
            activations[f'dit_block_{i}'] = x
        x = FinalLayer(self.patch_size, self.out_channels, self.hidden_size, tc)(x, c)  # (B, num_patches, p*p*c)
        activations['final_layer'] = x
        # print("DiT: FinalLayer of shape", x.shape, "dtype", x.dtype)

        # Un-patchify: (B, N, p*p*c) -> (B, H, W, c).
        x = jnp.reshape(x, (batch_size, num_patches_side, num_patches_side,
                            self.patch_size, self.patch_size, self.out_channels))
        x = jnp.einsum('bhwpqc->bhpwqc', x)
        x = rearrange(x, 'B H P W Q C -> B (H P) (W Q) C',
                      H=int(num_patches_side), W=int(num_patches_side))
        assert x.shape == (batch_size, input_size, input_size, self.out_channels)

        # Per-timestep learned log-variance, looked up from a 256-bin table.
        # t == 1.0 would map to bin 256, one past the end of the table, so the
        # index is clipped into the valid range.
        num_logvar_bins = 256
        t_discrete = jnp.clip(jnp.floor(t * num_logvar_bins), 0, num_logvar_bins - 1).astype(jnp.int32)
        logvars = nn.Embed(num_logvar_bins, 1,
                           embedding_init=nn.initializers.constant(0))(t_discrete) * 100

        if return_activations:
            return x, logvars, activations
        return x  # , dte, te