"""Flux2 transformer layers for LightDiffusion-Next.
Core building blocks for the Flux2 architecture:
- Attention mechanisms
- Modulation layers
- Transformer blocks (double and single stream)
- Embedding layers
Adapted from ComfyUI's Flux implementation for LightDiffusion-Next.
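
Illustrative composition (a minimal sketch with toy sizes, not the real
Flux2 configuration):

    >>> import torch
    >>> ids = torch.zeros(1, 80, 3, dtype=torch.long)
    >>> pe = EmbedND(dim=32, theta=10000, axes_dim=[8, 12, 12])(ids)
    >>> img, txt = torch.randn(1, 64, 128), torch.randn(1, 16, 128)
    >>> vec = torch.randn(1, 128)
    >>> img, txt = DoubleStreamBlock(128, 4, 4.0)(img, txt, vec, pe)
    >>> x = SingleStreamBlock(128, 4)(torch.cat((txt, img), dim=1), vec, pe)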
"""
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from src.cond import cast as ops_module
from src.Device import Device
# Get operations module
def get_ops():
"""Get the operations module for weight initialization."""
return ops_module.disable_weight_init
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization.
Uses native PyTorch rms_norm when available for numerical consistency with ComfyUI.
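
    Example (illustrative; toy sizes, default unit scale):
        >>> norm = RMSNorm(8)
        >>> norm(torch.randn(2, 4, 8)).shape
        torch.Size([2, 4, 8])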
"""
def __init__(self, dim: int, eps: float = 1e-6, dtype=None, device=None):
super().__init__()
self.eps = eps
# Use 'scale' to match Flux2 checkpoint naming convention
self.scale = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
# Check if native rms_norm is available
self._use_native = hasattr(torch.nn.functional, 'rms_norm')
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Ensure scale is on the same device as input
scale = self.scale.to(x.device, x.dtype)
if self._use_native and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
# Use native PyTorch rms_norm for better precision (matches ComfyUI)
return torch.nn.functional.rms_norm(x, scale.shape, weight=scale, eps=self.eps)
else:
# Fallback implementation
rms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
return x * rms * scale
class EmbedND(nn.Module):
"""N-dimensional positional embedding using RoPE."""
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
"""Compute rotary positional embeddings.
Args:
ids: Position IDs tensor of shape [batch, seq_len, num_axes]
        Returns:
            Rotary embeddings of shape [batch, 1, seq_len, sum(axes_dim)//2, 2, 2]
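
        Example (illustrative; 16 = sum(axes_dim) // 2):
            >>> emb = EmbedND(dim=32, theta=10000, axes_dim=[8, 12, 12])
            >>> emb(torch.zeros(1, 80, 3, dtype=torch.long)).shape
            torch.Size([1, 1, 80, 16, 2, 2])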
"""
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
"""Compute rotary position embeddings.
Matches ComfyUI's implementation exactly for numerical precision.
Args:
pos: Position indices
dim: Embedding dimension
theta: Base frequency
    Returns:
        Float32 tensor of stacked 2x2 rotation matrices (built from cos and
        sin of the position-frequency products), shape [..., seq, dim//2, 2, 2]
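
    Example (illustrative; 8 = dim // 2 rotation matrices per position):
        >>> rope(torch.arange(5)[None], 16, 10000).shape
        torch.Size([1, 5, 8, 2, 2])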
"""
assert dim % 2 == 0
device = pos.device
# ComfyUI uses float64 for scale calculation for maximum precision
scale = torch.linspace(0, (dim - 2) / dim, dim // 2, dtype=torch.float64, device=device)
omega = 1.0 / (theta ** scale)
# Einsum for position-frequency interaction - cast pos to float32 like ComfyUI
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
# ComfyUI always returns float32 for RoPE embeddings
return out.to(dtype=torch.float32, device=pos.device)
class MLPEmbedder(nn.Module):
"""MLP for timestep and guidance embeddings."""
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
super().__init__()
if operations is None:
operations = get_ops()
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=ops_bias, dtype=dtype, device=device)
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=ops_bias, dtype=dtype, device=device)
self.silu = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class GatedMLP(nn.Module):
"""Gated MLP (SwiGLU) for Klein models.
Structure: hidden -> 2*intermediate -> SiLU gate -> intermediate -> hidden
The first linear produces gate and value activations,
SiLU is applied to gate, then gate * value, then final projection.
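
    Example (illustrative; toy sizes):
        >>> mlp = GatedMLP(hidden_size=64, intermediate_size=128)
        >>> mlp(torch.randn(2, 10, 64)).shape  # 64 -> 2*128 -> 128 -> 64
        torch.Size([2, 10, 64])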
"""
def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
super().__init__()
if operations is None:
operations = get_ops()
# First layer outputs 2x intermediate for gating
self.gate_up_proj = operations.Linear(hidden_size, intermediate_size * 2, bias=ops_bias, dtype=dtype, device=device)
self.down_proj = operations.Linear(intermediate_size, hidden_size, bias=ops_bias, dtype=dtype, device=device)
self.act = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up = self.gate_up_proj(x)
gate, up = gate_up.chunk(2, dim=-1)
return self.down_proj(self.act(gate) * up)
class QKNorm(nn.Module):
"""Query-Key normalization layer."""
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
# Use native RMSNorm instead of operations.RMSNorm
self.query_norm = RMSNorm(dim, dtype=dtype, device=device)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device)
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
q = self.query_norm(q)
k = self.key_norm(k)
# Cast to v's dtype and device to match ComfyUI (crucial for numerical consistency)
return q.to(v), k.to(v)
class SelfAttention(nn.Module):
"""Self-attention with rotary position embedding (RoPE)."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
dtype=None,
device=None,
operations=None,
ops_bias: bool = True,
):
super().__init__()
if operations is None:
operations = get_ops()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.proj = operations.Linear(dim, dim, bias=ops_bias, dtype=dtype, device=device)
def forward(self, x: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pe: Optional[torch.Tensor], mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply attention with rotary position embeddings.
Args:
q: Query tensor [batch, heads, seq, dim]
k: Key tensor [batch, heads, seq, dim]
v: Value tensor [batch, heads, seq, dim]
        pe: Rotary positional embeddings, or None to skip RoPE
        mask: Optional attention mask for padding tokens
Returns:
Attention output [batch, seq, heads*dim]
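
    Example (illustrative; head_dim 32 = 2 * 16 RoPE pairs):
        >>> q = k = v = torch.randn(1, 4, 80, 32)
        >>> pe = EmbedND(dim=32, theta=10000, axes_dim=[8, 12, 12])(
        ...     torch.zeros(1, 80, 3, dtype=torch.long))
        >>> attention(q, k, v, pe).shape
        torch.Size([1, 80, 128])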
"""
    # Validate the positional embedding sequence length to prevent RoPE shape
    # errors, then apply RoPE; both steps are skipped when pe is None
    if pe is not None:
        pe_seq = pe.shape[2] if pe.ndim >= 3 else None
        if pe_seq not in (1, q.shape[2]):
            raise ValueError(
                f"RoPE sequence length mismatch: pe.seq={pe_seq} != q.seq={q.shape[2]}. "
                "Transformer options (img_h/img_w) may not match the input token grid; "
                "check calc_cond_batch merging of transformer_options."
            )
        q, k = apply_rope(q, k, pe)
# Efficient attention implementation
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, mask=mask)
return x
def apply_rope1(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""Apply rotary position embedding to a single tensor.
Correctly applies the 2x2 rotation matrix:
y1 = x1 * cos - x2 * sin
y2 = x1 * sin + x2 * cos
Args:
x: Input tensor [batch, heads, seq, dim]
freqs_cis: Frequency tensor [batch, 1, seq, dim//2, 2, 2]
Returns:
Rotated tensor [batch, heads, seq, dim]
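
    Example (illustrative; at position 0 the rotation is the identity):
        >>> x = torch.randn(1, 1, 1, 8)
        >>> pe = EmbedND(dim=8, theta=10000, axes_dim=[8])(
        ...     torch.zeros(1, 1, 1, dtype=torch.long))
        >>> torch.allclose(apply_rope1(x, pe), x)
        True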
"""
# Reshape x to match RoPE components [batch, heads, seq, dim//2, 2]
x_reshaped = x.reshape(*x.shape[:-1], -1, 2)
# Handle differing sequence lengths between x and freqs_cis
# freqs_cis shape: [batch, 1, seq_pe, dim//2, 2, 2]
seq_x = x.shape[2]
seq_pe = freqs_cis.shape[2]
if seq_pe != seq_x:
if seq_pe < seq_x:
# Upsample by repeating along sequence dimension then slice to exact length
repeat = (seq_x + seq_pe - 1) // seq_pe
freqs_cis = freqs_cis.repeat_interleave(repeat, dim=2)[..., :seq_x, :, :, :]
else:
# Slice to match x sequence length
freqs_cis = freqs_cis[..., :seq_x, :, :, :]
# Sanity-check: feature dimension (half of head dim) must match freqs_cis
feat_half = x.shape[-1] // 2
if freqs_cis.shape[-3] != feat_half:
raise ValueError(
f"RoPE feature-dim mismatch: freqs_cis.dim={freqs_cis.shape[-3]} != x.dim/2={feat_half}. "
f"x.shape={x.shape}, freqs_cis.shape={freqs_cis.shape}"
)
# Extract rotation matrix components
# freqs_cis is [..., dim//2, row, col]
# row 0: [cos, -sin]
# row 1: [sin, cos]
cos = freqs_cis[..., 0, 0]
msin = freqs_cis[..., 0, 1] # -sin
sin = freqs_cis[..., 1, 0]
x1 = x_reshaped[..., 0]
x2 = x_reshaped[..., 1]
# Apply rotation
out1 = x1 * cos + x2 * msin
out2 = x1 * sin + x2 * cos
# Combine and reshape back to original
return torch.stack([out1, out2], dim=-1).reshape(*x.shape).type_as(x)
def apply_rope(q: torch.Tensor, k: torch.Tensor, pe: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Apply rotary position embeddings to queries and keys.
Args:
q: Query tensor [batch, heads, seq, dim]
k: Key tensor [batch, heads, seq, dim]
pe: Positional embeddings [..., dim//2, 2, 2]
Returns:
Rotated (q, k) tensors
"""
return apply_rope1(q, pe), apply_rope1(k, pe)
def optimized_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Optimized attention dispatching to the fastest available backend.
    Dispatch priority: native SDPA under the Device module's backend-priority
    context (cuDNN > Flash > memory-efficient) > xformers > a plain SDPA call.
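
    Example (illustrative; toy sizes):
        >>> q = k = v = torch.randn(1, 4, 80, 32)
        >>> optimized_attention(q, k, v, heads=4).shape
        torch.Size([1, 80, 128])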
"""
    b, _, seq_q, _ = q.shape
    # Convert an optional key-padding mask (1 = attend, 0 = masked padding)
    # into the additive form SDPA expects (0 for attend, -inf for masked),
    # broadcastable to [batch, heads, seq_q, seq_kv]
    attn_mask = None
    if mask is not None:
        if mask.ndim == 2:
            attn_mask = mask.unsqueeze(1).unsqueeze(1)  # [B, L] -> [B, 1, 1, L]
        elif mask.ndim == 3:
            attn_mask = mask.unsqueeze(1)  # [B, L, L] -> [B, 1, L, L]
        else:
            attn_mask = mask
        attn_mask = (1.0 - attn_mask.to(dtype=q.dtype)) * torch.finfo(q.dtype).min
    # Method 1: native scaled_dot_product_attention with backend priority.
    # This is the fastest path on modern PyTorch with GPU support.
    if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
        try:
            # Get the SDPA backend priority context manager from Device
            sdpa_context = Device.get_sdpa_context()
            # SDPA expects [batch, heads, seq, dim]; q/k/v are already in this format
            with sdpa_context:
                out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
            # Reshape: [batch, heads, seq, dim] -> [batch, seq, heads*dim]
            # Use transpose + reshape to avoid an extra copy where possible
            out = out.transpose(1, 2).reshape(b, seq_q, -1)
            return out
        except Exception:
            pass  # Fall through to xformers
    # Method 2: xformers memory-efficient attention.
    # Skipped when a mask is present: xformers uses a different mask format,
    # and silently dropping the mask would change the result.
    if Device.xformers_enabled() and mask is None:
        try:
            import xformers.ops as xops
            # xformers expects [batch, seq, heads, dim]
            q_xf = q.transpose(1, 2).contiguous()
            k_xf = k.transpose(1, 2).contiguous()
            v_xf = v.transpose(1, 2).contiguous()
            out = xops.memory_efficient_attention(q_xf, k_xf, v_xf)
            del q_xf, k_xf, v_xf  # Free memory early
            # Reshape: [batch, seq, heads, dim] -> [batch, seq, heads*dim]
            out = out.reshape(b, seq_q, -1)
            return out
        except Exception:
            pass  # Fall through to the final fallback
    # Method 3: plain SDPA call without backend selection (last resort)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    out = out.transpose(1, 2).reshape(b, seq_q, -1)
    return out
@dataclass
class ModulationOut:
"""Output of modulation layer."""
shift: torch.Tensor
scale: torch.Tensor
gate: torch.Tensor
class Modulation(nn.Module):
"""Adaptive layer normalization modulation.
Applies shift, scale, and gate from conditioning vector.
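
    Example (illustrative; single modulation, so the second output is None):
        >>> mod = Modulation(dim=64, double=False)
        >>> m1, m2 = mod(torch.randn(2, 64))
        >>> m1.shift.shape, m2 is None
        (torch.Size([2, 1, 64]), True)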
"""
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None, ops_bias: bool = True):
super().__init__()
if operations is None:
operations = get_ops()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = operations.Linear(dim, self.multiplier * dim, bias=ops_bias, dtype=dtype, device=device)
def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut | None]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
mod1 = ModulationOut(shift=out[0], scale=out[1], gate=out[2])
mod2 = ModulationOut(shift=out[3], scale=out[4], gate=out[5]) if self.is_double else None
return mod1, mod2
class GlobalModulation(nn.Module):
"""Global modulation for Flux2 (Klein) double stream blocks."""
def __init__(self, dim: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
super().__init__()
if operations is None:
operations = get_ops()
# 12 outputs: 6 for img stream, 6 for txt stream
self.lin = operations.Linear(dim, 12 * dim, bias=ops_bias, dtype=dtype, device=device)
def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut, ModulationOut, ModulationOut]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(12, dim=-1)
mod1_img = ModulationOut(shift=out[0], scale=out[1], gate=out[2])
mod2_img = ModulationOut(shift=out[3], scale=out[4], gate=out[5])
mod1_txt = ModulationOut(shift=out[6], scale=out[7], gate=out[8])
mod2_txt = ModulationOut(shift=out[9], scale=out[10], gate=out[11])
return mod1_img, mod2_img, mod1_txt, mod2_txt
class DoubleStreamBlock(nn.Module):
"""Transformer block with separate image and text streams.
Uses joint attention but separate MLPs for image and text.
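
    Example (illustrative; toy sizes with per-block modulation):
        >>> blk = DoubleStreamBlock(hidden_size=128, num_heads=4, mlp_ratio=4.0)
        >>> pe = EmbedND(dim=32, theta=10000, axes_dim=[8, 12, 12])(
        ...     torch.zeros(1, 80, 3, dtype=torch.long))
        >>> img, txt = blk(torch.randn(1, 64, 128), torch.randn(1, 16, 128),
        ...                torch.randn(1, 128), pe)
        >>> img.shape, txt.shape
        (torch.Size([1, 64, 128]), torch.Size([1, 16, 128]))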
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float,
qkv_bias: bool = False,
global_modulation: bool = False,
dtype=None,
device=None,
operations=None,
flax_compatible: bool = False,
silu_mlp: bool = False,
gated_mlp: bool = False,
ops_bias: bool = True, # Whether to use bias in linear layers
):
super().__init__()
if operations is None:
operations = get_ops()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.flax_compatible = flax_compatible
self.silu_mlp = silu_mlp
self.gated_mlp = gated_mlp
# For gated MLP (Klein), mlp_ratio is the true ratio
# First layer outputs 2x for gating: hidden -> 2*intermediate
# Second layer: intermediate -> hidden
if gated_mlp:
mlp_intermediate = int(hidden_size * mlp_ratio)
mlp_hidden_dim = mlp_intermediate * 2 # Double for gate+up projection
else:
mlp_hidden_dim = int(hidden_size * mlp_ratio)
mlp_intermediate = mlp_hidden_dim
        self.double_stream_modulation = None
        if global_modulation:
            # Modulation is computed once at the model level and passed into forward()
            self.img_mod = None
            self.txt_mod = None
            self.use_global_modulation = True
        else:
            # Per-block modulation (Flux1 style)
            self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
            self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
            self.use_global_modulation = False
# Image stream
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(hidden_size, num_heads, qkv_bias, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
if gated_mlp:
# Gated MLP with naming compatible with checkpoint: .0, .1 (identity), .2
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
nn.Identity(), # Placeholder for index 1
operations.Linear(mlp_intermediate, hidden_size, bias=ops_bias, dtype=dtype, device=device),
)
else:
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
nn.SiLU() if silu_mlp else nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=ops_bias, dtype=dtype, device=device),
)
# Text stream
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(hidden_size, num_heads, qkv_bias, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
if gated_mlp:
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
nn.Identity(),
operations.Linear(mlp_intermediate, hidden_size, bias=ops_bias, dtype=dtype, device=device),
)
else:
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
nn.SiLU() if silu_mlp else nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=ops_bias, dtype=dtype, device=device),
)
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
pe: torch.Tensor,
attn_mask=None,
img_mod: tuple = None, # (img_mod1, img_mod2) from global modulation
txt_mod: tuple = None, # (txt_mod1, txt_mod2) from global modulation
) -> tuple[torch.Tensor, torch.Tensor]:
# Get modulation parameters
if self.use_global_modulation and img_mod is not None and txt_mod is not None:
# Use global modulation passed from model level
img_mod1, img_mod2 = img_mod
txt_mod1, txt_mod2 = txt_mod
elif self.img_mod is not None and self.txt_mod is not None:
# Use per-block modulation (Flux1 style)
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
else:
raise ValueError("No modulation available - either provide global or use per-block modulation")
# Prepare normed inputs
img_normed = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_normed + img_mod1.shift
del img_normed # Free memory early
txt_normed = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_normed + txt_mod1.shift
del txt_normed # Free memory early
# Run joint attention - use view+permute for efficiency instead of rearrange
img_qkv = self.img_attn.qkv(img_modulated)
del img_modulated
q_img, k_img, v_img = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del img_qkv
txt_qkv = self.txt_attn.qkv(txt_modulated)
del txt_modulated
q_txt, k_txt, v_txt = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del txt_qkv
q_img, k_img = self.img_attn.norm(q_img, k_img, v_img)
q_txt, k_txt = self.txt_attn.norm(q_txt, k_txt, v_txt)
# Concatenate for joint attention
q = torch.cat((q_txt, q_img), dim=2)
del q_txt, q_img
k = torch.cat((k_txt, k_img), dim=2)
del k_txt, k_img
v = torch.cat((v_txt, v_img), dim=2)
del v_txt, v_img
attn_out = attention(q, k, v, pe=pe, mask=attn_mask)
del q, k, v
txt_attn, img_attn = attn_out[:, : txt.shape[1]], attn_out[:, txt.shape[1] :]
del attn_out
# Apply residual connections with gating
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
del img_attn
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
del txt_attn
# MLP with modulation
img_mlp_in = (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
img = img + img_mod2.gate * self._forward_mlp(self.img_mlp, img_mlp_in)
del img_mlp_in
txt_mlp_in = (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
txt = txt + txt_mod2.gate * self._forward_mlp(self.txt_mlp, txt_mlp_in)
del txt_mlp_in
# Handle fp16 numerical issues (matches ComfyUI exactly)
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
def _forward_mlp(self, mlp: nn.Sequential, x: torch.Tensor) -> torch.Tensor:
"""Forward through MLP, handling both standard and gated variants."""
if self.gated_mlp:
# Gated MLP: split into gate and up, apply SiLU to gate, multiply, project
gate_up = mlp[0](x)
gate, up = gate_up.chunk(2, dim=-1)
hidden = F.silu(gate) * up
return mlp[2](hidden)
else:
return mlp(x)
class SingleStreamBlock(nn.Module):
"""Transformer block with merged image and text stream.
Used after the double stream blocks have processed both modalities.
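
    Example (illustrative; toy sizes with per-block modulation):
        >>> blk = SingleStreamBlock(hidden_size=128, num_heads=4)
        >>> pe = EmbedND(dim=32, theta=10000, axes_dim=[8, 12, 12])(
        ...     torch.zeros(1, 80, 3, dtype=torch.long))
        >>> blk(torch.randn(1, 80, 128), torch.randn(1, 128), pe).shape
        torch.Size([1, 80, 128])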
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
        qk_scale: Optional[float] = None,
dtype=None,
device=None,
operations=None,
silu_mlp: bool = False,
gated_mlp: bool = False,
ops_bias: bool = True,
global_modulation: bool = False,
):
super().__init__()
if operations is None:
operations = get_ops()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.silu_mlp = silu_mlp
self.gated_mlp = gated_mlp
self.use_global_modulation = global_modulation
# For gated MLP, mlp_ratio gives intermediate size
# linear1 outputs gate+up (2x intermediate), linear2 takes intermediate
if gated_mlp:
self.mlp_intermediate = int(hidden_size * mlp_ratio)
self.mlp_gate_up_dim = self.mlp_intermediate * 2
linear1_out = hidden_size * 3 + self.mlp_gate_up_dim
linear2_in = hidden_size + self.mlp_intermediate
else:
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
linear1_out = hidden_size * 3 + self.mlp_hidden_dim
linear2_in = hidden_size + self.mlp_hidden_dim
# Joint QKV and MLP projection
self.linear1 = operations.Linear(
hidden_size, linear1_out, bias=ops_bias, dtype=dtype, device=device
)
self.linear2 = operations.Linear(
linear2_in, hidden_size, bias=ops_bias, dtype=dtype, device=device
)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
# Only create per-block modulation if not using global modulation
if not global_modulation:
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
else:
self.modulation = None
def forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
pe: torch.Tensor,
attn_mask=None,
modulation=None, # ModulationOut from global modulation
) -> torch.Tensor:
# Get modulation
if self.use_global_modulation and modulation is not None:
mod = modulation
elif self.modulation is not None:
mod, _ = self.modulation(vec)
else:
raise ValueError("No modulation available - either provide global or use per-block modulation")
x_normed = self.pre_norm(x)
x_mod = (1 + mod.scale) * x_normed + mod.shift
del x_normed # Free memory early
# Joint projection - split QKV from MLP part
qkv_mlp = self.linear1(x_mod)
del x_mod
if self.gated_mlp:
qkv, mlp_gate_up = qkv_mlp.split([self.hidden_size * 3, self.mlp_gate_up_dim], dim=-1)
del qkv_mlp
# Gated MLP: split into gate and up, apply SiLU to gate, multiply
gate, up = mlp_gate_up.chunk(2, dim=-1)
del mlp_gate_up
mlp = F.silu(gate) * up
del gate, up
else:
qkv, mlp = qkv_mlp.split([self.hidden_size * 3, self.mlp_hidden_dim], dim=-1)
del qkv_mlp
# Standard activation
if self.silu_mlp:
mlp = F.silu(mlp)
else:
mlp = F.gelu(mlp, approximate="tanh")
# Attention - use view+permute for efficiency instead of rearrange
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del qkv
q, k = self.norm(q, k, v)
attn = attention(q, k, v, pe=pe, mask=attn_mask)
del q, k, v
# Combine and project
output = self.linear2(torch.cat((attn, mlp), dim=-1))
del attn, mlp
result = x + mod.gate * output
# Handle fp16 numerical issues (matches ComfyUI exactly)
if result.dtype == torch.float16:
result = torch.nan_to_num(result, nan=0.0, posinf=65504, neginf=-65504)
return result
class LastLayer(nn.Module):
"""Final layer for unpatchifying and producing output."""
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
super().__init__()
if operations is None:
operations = get_ops()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(
hidden_size, patch_size * patch_size * out_channels, bias=ops_bias, dtype=dtype, device=device
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=ops_bias, dtype=dtype, device=device),
)
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x