import math
from typing import Any, Callable, Optional, Tuple, Type, Sequence, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from einops import rearrange
Array = Any
PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any
from math_utils import get_2d_sincos_pos_embed, modulate
from jax._src import core
from jax._src import dtypes
from jax._src.nn.initializers import _compute_fans
def xavier_uniform_pytorchlike():
    """Xavier-uniform initializer that computes fans the way PyTorch does,
    for Dense kernels [in, out] and patch-embed style Conv kernels [k, k, in, out]."""
    def init(key, shape, dtype):
dtype = dtypes.canonicalize_dtype(dtype)
#named_shape = core.as_named_shape(shape)
if len(shape) == 2: # Dense, [in, out]
fan_in = shape[0]
fan_out = shape[1]
elif len(shape) == 4: # Conv, [k, k, in, out]. Assumes patch-embed style conv.
fan_in = shape[0] * shape[1] * shape[2]
fan_out = shape[3]
else:
raise ValueError(f"Invalid shape {shape}")
variance = 2 / (fan_in + fan_out)
scale = jnp.sqrt(3 * variance)
        param = jax.random.uniform(key, shape, dtype, minval=-1) * scale  # uniform in [-scale, scale)
return param
return init
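# A minimal usage sketch (not part of the model): drawing a Dense kernel with the
# PyTorch-style Xavier-uniform initializer above. The shapes and PRNG key are
# illustrative assumptions.
def _example_xavier_uniform():
    init_fn = xavier_uniform_pytorchlike()
    kernel = init_fn(jax.random.PRNGKey(0), (128, 256), jnp.float32)  # Dense kernel [in, out]
    # Entries are uniform in [-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))).
    return kernel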
class TrainConfig:
def __init__(self, dtype):
self.dtype = dtype
def kern_init(self, name='default', zero=False):
if zero or 'bias' in name:
return nn.initializers.constant(0)
return xavier_uniform_pytorchlike()
def default_config(self):
return {
'kernel_init': self.kern_init(),
'bias_init': self.kern_init('bias', zero=True),
'dtype': self.dtype,
}
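# A minimal usage sketch (illustrative only): TrainConfig.default_config() is meant to
# be splatted into Flax layers, as the blocks below do.
def _example_train_config():
    tc = TrainConfig(dtype=jnp.float32)
    dense = nn.Dense(features=32, **tc.default_config())
    params = dense.init(jax.random.PRNGKey(0), jnp.zeros((1, 16)))
    return params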
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
hidden_size: int
tc: TrainConfig
frequency_embedding_size: int = 256
@nn.compact
def __call__(self, t):
x = self.timestep_embedding(t)
x = nn.Dense(self.hidden_size, kernel_init=nn.initializers.normal(0.02),
bias_init=self.tc.kern_init('time_bias'), dtype=self.tc.dtype)(x)
x = nn.silu(x)
x = nn.Dense(self.hidden_size, kernel_init=nn.initializers.normal(0.02),
bias_init=self.tc.kern_init('time_bias'))(x)
return x
# t is between [0, 1].
    def timestep_embedding(self, t, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings, where D = self.frequency_embedding_size.
        """
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
t = jax.lax.convert_element_type(t, jnp.float32)
# t = t * max_period
dim = self.frequency_embedding_size
half = dim // 2
freqs = jnp.exp( -math.log(max_period) * jnp.arange(start=0, stop=half, dtype=jnp.float32) / half)
args = t[:, None] * freqs[None]
embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
embedding = embedding.astype(self.tc.dtype)
return embedding
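# A minimal usage sketch (illustrative sizes): embedding a batch of scalar timesteps
# in [0, 1] into hidden_size-dimensional vectors.
def _example_timestep_embedder():
    embedder = TimestepEmbedder(hidden_size=64, tc=TrainConfig(dtype=jnp.float32))
    t = jnp.array([0.0, 0.5, 1.0])                   # (B,) timesteps in [0, 1]
    params = embedder.init(jax.random.PRNGKey(0), t)
    emb = embedder.apply(params, t)                  # (B, hidden_size)
    assert emb.shape == (3, 64)
    return emb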
class LabelEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
num_classes: int
hidden_size: int
tc: TrainConfig
@nn.compact
def __call__(self, labels):
embedding_table = nn.Embed(self.num_classes + 1, self.hidden_size,
embedding_init=nn.initializers.normal(0.02), dtype=self.tc.dtype)
embeddings = embedding_table(labels)
return embeddings
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding """
patch_size: int
hidden_size: int
tc: TrainConfig
bias: bool = True
@nn.compact
def __call__(self, x):
B, H, W, C = x.shape
patch_tuple = (self.patch_size, self.patch_size)
        num_patches_side = H // self.patch_size
        x = nn.Conv(self.hidden_size, patch_tuple, patch_tuple, use_bias=self.bias, padding="VALID",
                    kernel_init=self.tc.kern_init('patch'), bias_init=self.tc.kern_init('patch_bias', zero=True),
                    dtype=self.tc.dtype)(x)  # (B, H/P, W/P, hidden_size)
        x = rearrange(x, 'b h w c -> b (h w) c', h=num_patches_side, w=num_patches_side)
        return x
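# A minimal shape-check sketch (illustrative assumptions: 32x32 inputs, 3 channels,
# patch size 4): images are turned into a sequence of patch tokens.
def _example_patch_embed():
    tc = TrainConfig(dtype=jnp.float32)
    patcher = PatchEmbed(patch_size=4, hidden_size=64, tc=tc)
    x = jnp.zeros((2, 32, 32, 3))                    # (B, H, W, C)
    params = patcher.init(jax.random.PRNGKey(0), x)
    tokens = patcher.apply(params, x)                # (B, (H/4)*(W/4), hidden_size)
    assert tokens.shape == (2, 64, 64)
    return tokens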
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block."""
mlp_dim: int
tc: TrainConfig
out_dim: Optional[int] = None
    dropout_rate: Optional[float] = None
train: bool = False
@nn.compact
def __call__(self, inputs):
"""It's just an MLP, so the input shape is (batch, len, emb)."""
actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
x = nn.Dense(features=self.mlp_dim, **self.tc.default_config())(inputs)
x = nn.gelu(x)
x = nn.Dropout(rate=self.dropout_rate, deterministic=(not self.train))(x)
output = nn.Dense(features=actual_out_dim, **self.tc.default_config())(x)
output = nn.Dropout(rate=self.dropout_rate, deterministic=(not self.train))(output)
return output
def modulate(x, shift, scale):
    # Note: this local definition shadows the `modulate` imported from math_utils.
    # scale = jnp.clip(scale, -1, 1)
    return x * (1 + scale[:, None]) + shift[:, None]
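# A minimal sketch of the adaLN modulation broadcast: per-sample (B, C) shift and
# scale are applied across all N tokens of a (B, N, C) sequence.
def _example_modulate():
    x = jnp.ones((2, 16, 8))         # (B, N, C)
    shift = jnp.zeros((2, 8))        # (B, C)
    scale = jnp.full((2, 8), 0.5)    # (B, C)
    out = modulate(x, shift, scale)  # scale/shift broadcast over the token axis
    assert out.shape == (2, 16, 8)
    return out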
#################################################################################
#                               Core DiT Model                                  #
#################################################################################
class DiTBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
hidden_size: int
num_heads: int
tc: TrainConfig
mlp_ratio: float = 4.0
dropout: float = 0.0
train: bool = False
# @functools.partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable)
@nn.compact
def __call__(self, x, c):
# Calculate adaLn modulation parameters.
c = nn.silu(c)
c = nn.Dense(6 * self.hidden_size, **self.tc.default_config())(c)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(c, 6, axis=-1)
# Attention Residual.
x_norm = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
x_modulated = modulate(x_norm, shift_msa, scale_msa)
channels_per_head = self.hidden_size // self.num_heads
k = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
q = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
v = nn.Dense(self.hidden_size, **self.tc.default_config())(x_modulated)
k = jnp.reshape(k, (k.shape[0], k.shape[1], self.num_heads, channels_per_head))
q = jnp.reshape(q, (q.shape[0], q.shape[1], self.num_heads, channels_per_head))
v = jnp.reshape(v, (v.shape[0], v.shape[1], self.num_heads, channels_per_head))
        q = q / q.shape[3]  # (1/d) scaling (note: 1/d rather than the usual 1/sqrt(d)).
        w = jnp.einsum('bqhc,bkhc->bhqk', q, k)  # [B, num_heads, N, N]
        w = w.astype(jnp.float32)
        w = nn.softmax(w, axis=-1)
        y = jnp.einsum('bhqk,bkhc->bqhc', w, v)  # [B, N, num_heads, channels_per_head]
        y = jnp.reshape(y, x.shape)  # [B, N, hidden_size] (hidden_size = num_heads * channels_per_head)
        attn_x = nn.Dense(self.hidden_size, **self.tc.default_config())(y)
x = x + (gate_msa[:, None] * attn_x)
# MLP Residual.
x_norm2 = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
x_modulated2 = modulate(x_norm2, shift_mlp, scale_mlp)
mlp_x = MlpBlock(mlp_dim=int(self.hidden_size * self.mlp_ratio), tc=self.tc,
dropout_rate=self.dropout, train=self.train)(x_modulated2)
x = x + (gate_mlp[:, None] * mlp_x)
return x
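# A minimal shape-check sketch (illustrative sizes): one adaLN-Zero block maps token
# features (B, N, hidden) plus conditioning (B, hidden) back to (B, N, hidden).
def _example_dit_block():
    tc = TrainConfig(dtype=jnp.float32)
    block = DiTBlock(hidden_size=64, num_heads=4, tc=tc)
    x = jnp.zeros((2, 16, 64))    # (B, N, hidden_size)
    c = jnp.zeros((2, 64))        # (B, hidden_size) conditioning
    params = block.init(jax.random.PRNGKey(0), x, c)
    y = block.apply(params, x, c)
    assert y.shape == x.shape
    return y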
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
patch_size: int
out_channels: int
hidden_size: int
tc: TrainConfig
@nn.compact
def __call__(self, x, c):
c = nn.silu(c)
c = nn.Dense(2 * self.hidden_size, kernel_init=self.tc.kern_init(zero=True),
bias_init=self.tc.kern_init('bias', zero=True), dtype=self.tc.dtype)(c)
shift, scale = jnp.split(c, 2, axis=-1)
x = nn.LayerNorm(use_bias=False, use_scale=False, dtype=self.tc.dtype)(x)
x = modulate(x, shift, scale)
x = nn.Dense(self.patch_size * self.patch_size * self.out_channels,
kernel_init=self.tc.kern_init('final', zero=True),
bias_init=self.tc.kern_init('final_bias', zero=True), dtype=self.tc.dtype)(x)
return x
def apply_label_embedding_noise(key, label_embeddings):
"""
Applies Gaussian noise to label embeddings based on specified probabilities.
Args:
key: A JAX random key.
label_embeddings: A JAX array of shape (batch_size, embedding_dim),
representing the label embeddings.
    Returns:
        A tuple containing:
        - noisy_label_embeddings: The label embeddings with noise applied.
        - noise_levels: A JAX array of shape (batch_size,), indicating the alpha
          value used for each sample (1.0 for no noise, 0.0 for 100% noise, or a
          uniform sample for partial noise).
        - key: A fresh JAX random key (split from the input key) for further use.
    """
batch_size, embedding_dim = label_embeddings.shape
# Split key for different random operations
key, noise_type_key, alpha_key, normal_key = jax.random.split(key, 4)
# Determine noise application type for each sample
# 0: 100% noise (alpha = 0)
# 1: Partial noise (alpha uniformly 0-1)
# 2: No noise (do nothing)
noise_type_choices = jax.random.choice(
noise_type_key,
a=jnp.array([0, 1, 2]),
shape=(batch_size,),
p=jnp.array([0.00, 0.10, 0.90])
)
# Initialize noise_levels to 1.0 (no noise)
noise_levels = jnp.ones(batch_size, dtype=label_embeddings.dtype)
# Generate alpha values for partial noise
sampled_alphas = jax.random.uniform(alpha_key, shape=(batch_size,), minval=0.0, maxval=1.0)
    # Generate Gaussian noise for the entire batch.
    # The noise has a standard deviation of 1; adjust if needed.
gaussian_noise = jax.random.normal(normal_key, shape=label_embeddings.shape)
# Initialize noisy_label_embeddings
noisy_label_embeddings = label_embeddings
# Apply 100% noise
cond_100_percent_noise = (noise_type_choices == 0)
noisy_label_embeddings = jnp.where(
cond_100_percent_noise[:, None], # Expand dim for broadcasting
gaussian_noise,
noisy_label_embeddings
)
noise_levels = jnp.where(cond_100_percent_noise, 0.0, noise_levels)
# Apply partial noise
cond_partial_noise = (noise_type_choices == 1)
# Reshape sampled_alphas for broadcasting
alpha_reshaped = sampled_alphas[:, None]
    # Active mixing rule: a plain linear interpolation between the embedding and the noise.
    new_noise = label_embeddings * alpha_reshaped + gaussian_noise * (1.0 - alpha_reshaped)
    # Alternative mixing rules (disabled): lerp between L2-normalized vectors, or slerp.
    # Normalized lerp:
    #   label_embeddings = label_embeddings / jnp.linalg.norm(label_embeddings, axis=-1, keepdims=True)
    #   gaussian_noise = gaussian_noise / jnp.linalg.norm(gaussian_noise, axis=-1, keepdims=True)
    #   new_noise = label_embeddings * alpha_reshaped + gaussian_noise * (1.0 - alpha_reshaped)
    # Slerp:
    #   label_embeddings = label_embeddings / jnp.linalg.norm(label_embeddings, axis=-1, keepdims=True)
    #   gaussian_noise = gaussian_noise / jnp.linalg.norm(gaussian_noise, axis=-1, keepdims=True)
    #   dot_product = jnp.sum(label_embeddings * gaussian_noise, axis=-1, keepdims=True)
    #   theta = jnp.arccos(dot_product)
    #   sin_theta = jnp.sin(theta)
    #   new_noise = (jnp.sin((1.0 - alpha_reshaped) * theta) / sin_theta) * label_embeddings + \
    #               (jnp.sin(alpha_reshaped * theta) / sin_theta) * gaussian_noise
noisy_label_embeddings = jnp.where(
cond_partial_noise[:, None],
new_noise,
noisy_label_embeddings
)
noise_levels = jnp.where(cond_partial_noise, sampled_alphas, noise_levels)
# For cond_no_noise (noise_type_choices == 2), noisy_label_embeddings remains
# label_embeddings and noise_levels remains 1.0, so no specific action needed.
return noisy_label_embeddings, noise_levels, key
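# A minimal usage sketch (illustrative shapes): per sample, the embeddings are left
# alone, partially mixed with Gaussian noise, or fully replaced by noise, and the
# corresponding alpha (1.0 = clean, 0.0 = pure noise) is returned as noise_levels.
def _example_label_embedding_noise():
    key = jax.random.PRNGKey(0)
    ye = jax.random.normal(jax.random.PRNGKey(1), (8, 32))  # (B, D) label embeddings
    noisy_ye, noise_levels, key = apply_label_embedding_noise(key, ye)
    assert noisy_ye.shape == (8, 32) and noise_levels.shape == (8,)
    return noisy_ye, noise_levels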
class DiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
patch_size: int
hidden_size: int
depth: int
num_heads: int
mlp_ratio: float
out_channels: int
class_dropout_prob: float
num_classes: int
ignore_dt: bool = False
dropout: float = 0.0
dtype: Dtype = jnp.bfloat16
@nn.compact
    def __call__(self, x, t, dt, y, train=False, return_activations=False, perturbe=True):
# (x = (B, H, W, C) image, t = (B,) timesteps, y = (B,) class labels)
print("DiT: Input of shape", x.shape, "dtype", x.dtype)
activations = {}
key = self.make_rng("label")
batch_size = x.shape[0]
input_size = x.shape[1]
in_channels = x.shape[-1]
num_patches = (input_size // self.patch_size) ** 2
num_patches_side = input_size // self.patch_size
tc = TrainConfig(dtype=self.dtype)
if self.ignore_dt:
dt = jnp.zeros_like(t)
# pos_embed = self.param("pos_embed", get_2d_sincos_pos_embed, self.hidden_size, num_patches)
# pos_embed = jax.lax.stop_gradient(pos_embed)
pos_embed = get_2d_sincos_pos_embed(None, self.hidden_size, num_patches)
x = PatchEmbed(self.patch_size, self.hidden_size, tc=tc)(x) # (B, num_patches, hidden_size)
print("DiT: After patch embed, shape is", x.shape, "dtype", x.dtype)
activations['patch_embed'] = x
x = x + pos_embed
x = x.astype(self.dtype)
te = TimestepEmbedder(self.hidden_size, tc=tc)(t) # (B, hidden_size)
dte = TimestepEmbedder(self.hidden_size, tc=tc)(dt) # (B, hidden_size)
ye = LabelEmbedder(self.num_classes, self.hidden_size, tc=tc)(y) # (B, hidden_size)
        # Classifier-free guidance is handled here: label dropout (class_dropout_prob) is
        # set to 0 during training. Instead, Gaussian noise is mixed into the label
        # embeddings, and the model is additionally conditioned on the noise level.
        # At inference, the "unconditional" pass uses condition_amount = zeros (full noise)
        # and the conditional pass uses condition_amount = ones (no noise).
        def adjust_condition_amount(train, perturbe, condition_amount):
            """During training, pass condition_amount through unchanged. Otherwise,
            return ones when perturbe is True (signal "no noise") and zeros when it
            is False (signal "full noise")."""
            def true_fn(_):
                return jnp.ones_like(condition_amount)   # perturbe is True -> ones
            def false_fn(_):
                return jnp.zeros_like(condition_amount)  # perturbe is False -> zeros
            def train_false_branch(_):
                return jax.lax.cond(perturbe, true_fn, false_fn, operand=None)
            def train_true_branch(_):
                return condition_amount  # leave it unchanged during training
            return jax.lax.cond(train, train_true_branch, train_false_branch, operand=None)
        # During training the labels themselves are noised; at inference only the
        # conditioning signal is adjusted and the labels are left untouched.
        def apply_fn(key, ye, train):
            def true_branch(args):
                key, ye = args
                ye_new, condition_amount, key_new = apply_label_embedding_noise(key, ye)
                return ye_new.astype(jnp.float32), condition_amount, key_new
            def false_branch(args):
                key, ye = args
                # Still call the noise function so both cond branches return the same
                # structure, but keep the original (un-noised) embeddings.
                ye_new, condition_amount, key_new = apply_label_embedding_noise(key, ye)
                return ye.astype(jnp.float32), condition_amount, key_new
            return jax.lax.cond(train, true_branch, false_branch, (key, ye))
#print("train is", train)#False
#print("perturbe is", perturbe)#False right now (it's getting passed properly)
#print("initial ye", ye[0][0:10])
ye, condition_amount, key = apply_fn(key, ye, train)
#print("new ye", ye[0][0:10])
#print("condition amount", condition_amount)
condition_amount = adjust_condition_amount(train, perturbe, condition_amount)
#print("adjusted", condition_amount)
        # Idea for inference: add a small amount of noise (e.g. ~0.05) to the 'uncond'
        # input in addition to the conditioning, and possibly make te and dte less
        # trustworthy in the same way.
ye_g = TimestepEmbedder(self.hidden_size, tc=tc)(condition_amount)
#freeze embedding table.
#ye = ye / ye.std()
#ye = jax.lax.stop_gradient(ye)
c = te + ye + dte + ye_g
activations['pos_embed'] = pos_embed
activations['time_embed'] = te
activations['dt_embed'] = dte
activations['label_embed'] = ye
activations['conditioning'] = c
print("DiT: Patch Embed of shape", x.shape, "dtype", x.dtype)
print("DiT: Conditioning of shape", c.shape, "dtype", c.dtype)
for i in range(self.depth):
x = DiTBlock(self.hidden_size, self.num_heads, tc, self.mlp_ratio, self.dropout, train)(x, c)
activations[f'dit_block_{i}'] = x
x = FinalLayer(self.patch_size, self.out_channels, self.hidden_size, tc)(x, c) # (B, num_patches, p*p*c)
activations['final_layer'] = x
# print("DiT: FinalLayer of shape", x.shape, "dtype", x.dtype)
x = jnp.reshape(x, (batch_size, num_patches_side, num_patches_side,
self.patch_size, self.patch_size, self.out_channels))
x = jnp.einsum('bhwpqc->bhpwqc', x)
x = rearrange(x, 'B H P W Q C -> B (H P) (W Q) C', H=int(num_patches_side), W=int(num_patches_side))
assert x.shape == (batch_size, input_size, input_size, self.out_channels)
t_discrete = jnp.floor(t * 256).astype(jnp.int32)
logvars = nn.Embed(256, 1, embedding_init=nn.initializers.constant(0))(t_discrete) * 100
if return_activations:
return x, logvars, activations
return x#, dte, te
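# A minimal end-to-end sketch (all sizes below are illustrative assumptions): initializing
# the DiT and running one forward pass. The model draws from a "label" RNG stream for the
# label-noise logic, so that stream must be supplied to both init and apply.
def _example_dit_forward():
    model = DiT(patch_size=2, hidden_size=64, depth=2, num_heads=4, mlp_ratio=4.0,
                out_channels=4, class_dropout_prob=0.0, num_classes=10,
                dtype=jnp.float32)
    rng = jax.random.PRNGKey(0)
    x = jnp.zeros((2, 16, 16, 4))   # (B, H, W, C) latents
    t = jnp.array([0.1, 0.9])       # (B,) timesteps in [0, 1]
    dt = jnp.zeros((2,))            # (B,)
    y = jnp.array([3, 7])           # (B,) class labels
    params = model.init({'params': rng, 'label': rng}, x, t, dt, y)
    out = model.apply(params, x, t, dt, y, rngs={'label': rng})
    assert out.shape == (2, 16, 16, 4)
    return out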