| """Flux2 transformer layers for LightDiffusion-Next. | |
| Core building blocks for the Flux2 architecture: | |
| - Attention mechanisms | |
| - Modulation layers | |
| - Transformer blocks (double and single stream) | |
| - Embedding layers | |
| Adapted from ComfyUI's Flux implementation for LightDiffusion-Next. | |
| """ | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| from src.cond import cast as ops_module | |
| from src.Device import Device | |
| # Get operations module | |
| def get_ops(): | |
| """Get the operations module for weight initialization.""" | |
| return ops_module.disable_weight_init | |
| class RMSNorm(nn.Module): | |
| """Root Mean Square Layer Normalization. | |
| Uses native PyTorch rms_norm when available for numerical consistency with ComfyUI. | |
| """ | |
| def __init__(self, dim: int, eps: float = 1e-6, dtype=None, device=None): | |
| super().__init__() | |
| self.eps = eps | |
| # Use 'scale' to match Flux2 checkpoint naming convention | |
| self.scale = nn.Parameter(torch.ones(dim, dtype=dtype, device=device)) | |
| # Check if native rms_norm is available | |
| self._use_native = hasattr(torch.nn.functional, 'rms_norm') | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| # Ensure scale is on the same device as input | |
| scale = self.scale.to(x.device, x.dtype) | |
| if self._use_native and not (torch.jit.is_tracing() or torch.jit.is_scripting()): | |
| # Use native PyTorch rms_norm for better precision (matches ComfyUI) | |
| return torch.nn.functional.rms_norm(x, scale.shape, weight=scale, eps=self.eps) | |
| else: | |
| # Fallback implementation | |
| rms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps) | |
| return x * rms * scale | |
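

# Illustrative usage sketch (not part of the module): RMSNorm preserves the
# input shape and normalizes each feature vector to unit RMS before applying
# `scale`:
#
#   norm = RMSNorm(dim=64)
#   x = torch.randn(2, 16, 64)
#   y = norm(x)                  # shape: [2, 16, 64]
#   y.pow(2).mean(-1).sqrt()     # ~1.0 everywhere while scale stays at its initial ones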


class EmbedND(nn.Module):
    """N-dimensional positional embedding using RoPE."""

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Compute rotary positional embeddings.

        Args:
            ids: Position IDs tensor of shape [batch, seq_len, num_axes]

        Returns:
            Rotary embeddings of shape [batch, 1, seq_len, sum(axes_dim) // 2, 2, 2]
        """
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        return emb.unsqueeze(1)


def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Compute rotary position embeddings.

    Matches ComfyUI's implementation exactly for numerical precision.

    Args:
        pos: Position indices
        dim: Embedding dimension
        theta: Base frequency

    Returns:
        Float32 tensor of per-position 2x2 rotation matrices, shape [..., dim // 2, 2, 2]
    """
    assert dim % 2 == 0
    device = pos.device
    # ComfyUI uses float64 for the scale calculation for maximum precision
    scale = torch.linspace(0, (dim - 2) / dim, dim // 2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta**scale)
    # Einsum for the position-frequency interaction - cast pos to float32 like ComfyUI
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    # ComfyUI always returns float32 for RoPE embeddings
    return out.to(dtype=torch.float32, device=pos.device)
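

# Worked shape example (illustrative): for positions of shape [1, 8] and
# dim=16, rope() builds dim // 2 = 8 frequencies and returns one 2x2 rotation
# matrix per (position, frequency) pair:
#
#   pos = torch.arange(8)[None]            # [1, 8]
#   emb = rope(pos, dim=16, theta=10000)   # [1, 8, 8, 2, 2], float32
#   # emb[..., 0, 0] is cos, emb[..., 0, 1] is -sin, emb[..., 1, 0] is sin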


class MLPEmbedder(nn.Module):
    """MLP for timestep and guidance embeddings."""

    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        if operations is None:
            operations = get_ops()
        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=ops_bias, dtype=dtype, device=device)
        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=ops_bias, dtype=dtype, device=device)
        self.silu = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class GatedMLP(nn.Module):
    """Gated MLP (SwiGLU) for Klein models.

    Structure: hidden -> 2*intermediate -> SiLU gate -> intermediate -> hidden

    The first linear produces gate and value activations;
    SiLU is applied to the gate, then gate * value, then the final projection.
    """

    def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        if operations is None:
            operations = get_ops()
        # First layer outputs 2x intermediate for gating
        self.gate_up_proj = operations.Linear(hidden_size, intermediate_size * 2, bias=ops_bias, dtype=dtype, device=device)
        self.down_proj = operations.Linear(intermediate_size, hidden_size, bias=ops_bias, dtype=dtype, device=device)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = self.gate_up_proj(x)
        gate, up = gate_up.chunk(2, dim=-1)
        return self.down_proj(self.act(gate) * up)
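

# SwiGLU shape sketch (hypothetical sizes): the fused first projection doubles
# the intermediate width, which the forward pass splits into gate and up halves:
#
#   mlp = GatedMLP(hidden_size=64, intermediate_size=128)
#   x = torch.randn(2, 16, 64)
#   out = mlp(x)   # gate_up: [2, 16, 256] -> gate, up: [2, 16, 128] -> out: [2, 16, 64]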


class QKNorm(nn.Module):
    """Query-Key normalization layer."""

    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        # Use native RMSNorm instead of operations.RMSNorm
        self.query_norm = RMSNorm(dim, dtype=dtype, device=device)
        self.key_norm = RMSNorm(dim, dtype=dtype, device=device)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        q = self.query_norm(q)
        k = self.key_norm(k)
        # Cast to v's dtype and device to match ComfyUI (crucial for numerical consistency)
        return q.to(v), k.to(v)


class SelfAttention(nn.Module):
    """Self-attention with rotary position embedding (RoPE)."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        dtype=None,
        device=None,
        operations=None,
        ops_bias: bool = True,
    ):
        super().__init__()
        if operations is None:
            operations = get_ops()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
        self.proj = operations.Linear(dim, dim, bias=ops_bias, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pe: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
    """Apply attention with rotary position embeddings.

    Args:
        q: Query tensor [batch, heads, seq, dim]
        k: Key tensor [batch, heads, seq, dim]
        v: Value tensor [batch, heads, seq, dim]
        pe: Positional embeddings (RoPE is skipped if None)
        mask: Optional attention mask for padding tokens

    Returns:
        Attention output [batch, seq, heads*dim]
    """
    # Validate the positional embedding sequence length to prevent RoPE shape errors
    if pe is not None:
        pe_seq = pe.shape[2] if pe.ndim >= 3 else None
        if pe_seq not in (1, q.shape[2]):
            raise ValueError(
                f"RoPE sequence length mismatch: pe.seq={pe_seq} != q.seq={q.shape[2]}. "
                "Transformer options (img_h/img_w) may not match the input token grid; "
                "check calc_cond_batch merging of transformer_options."
            )
        q, k = apply_rope(q, k, pe)
    # Efficient attention implementation
    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, mask=mask)
    return x
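

# End-to-end shape sketch (illustrative): 2 batches, 4 heads of dim 32 over 16
# tokens; the RoPE table from rope() must cover the same 16 positions:
#
#   q = k = v = torch.randn(2, 4, 16, 32)
#   pos = torch.arange(16)[None].expand(2, -1)   # [2, 16]
#   pe = rope(pos, 32, 10000).unsqueeze(1)       # [2, 1, 16, 16, 2, 2]
#   out = attention(q, k, v, pe=pe)              # [2, 16, 128], heads folded into channels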


def apply_rope1(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to a single tensor.

    Correctly applies the 2x2 rotation matrix:
        y1 = x1 * cos - x2 * sin
        y2 = x1 * sin + x2 * cos

    Args:
        x: Input tensor [batch, heads, seq, dim]
        freqs_cis: Frequency tensor [batch, 1, seq, dim//2, 2, 2]

    Returns:
        Rotated tensor [batch, heads, seq, dim]
    """
    # Reshape x to match the RoPE components [batch, heads, seq, dim//2, 2]
    x_reshaped = x.reshape(*x.shape[:-1], -1, 2)
    # Handle differing sequence lengths between x and freqs_cis
    # freqs_cis shape: [batch, 1, seq_pe, dim//2, 2, 2]
    seq_x = x.shape[2]
    seq_pe = freqs_cis.shape[2]
    if seq_pe != seq_x:
        if seq_pe < seq_x:
            # Upsample by repeating along the sequence dimension, then slice to the exact length
            repeat = (seq_x + seq_pe - 1) // seq_pe
            freqs_cis = freqs_cis.repeat_interleave(repeat, dim=2)[..., :seq_x, :, :, :]
        else:
            # Slice to match x's sequence length
            freqs_cis = freqs_cis[..., :seq_x, :, :, :]
    # Sanity check: the feature dimension (half of head dim) must match freqs_cis
    feat_half = x.shape[-1] // 2
    if freqs_cis.shape[-3] != feat_half:
        raise ValueError(
            f"RoPE feature-dim mismatch: freqs_cis.dim={freqs_cis.shape[-3]} != x.dim/2={feat_half}. "
            f"x.shape={x.shape}, freqs_cis.shape={freqs_cis.shape}"
        )
    # Extract the rotation matrix components
    # freqs_cis is [..., dim//2, row, col]:
    #   row 0: [cos, -sin]
    #   row 1: [sin,  cos]
    cos = freqs_cis[..., 0, 0]
    msin = freqs_cis[..., 0, 1]  # -sin
    sin = freqs_cis[..., 1, 0]
    x1 = x_reshaped[..., 0]
    x2 = x_reshaped[..., 1]
    # Apply the rotation
    out1 = x1 * cos + x2 * msin
    out2 = x1 * sin + x2 * cos
    # Combine and reshape back to the original layout
    return torch.stack([out1, out2], dim=-1).reshape(*x.shape).type_as(x)
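

# Rotation sanity check (illustrative): with a 90-degree rotation matrix at
# every (position, frequency) slot, each feature pair (x1, x2) maps to
# (-x2, x1). Uses the module's `math` import:
#
#   c, s = math.cos(math.pi / 2), math.sin(math.pi / 2)
#   m = torch.tensor([[c, -s], [s, c]])    # [2, 2]
#   freqs = m.expand(1, 1, 4, 8, 2, 2)     # [B=1, 1, seq=4, dim//2=8, 2, 2]
#   x = torch.randn(1, 2, 4, 16)
#   y = apply_rope1(x, freqs)   # y[..., 0::2] ~ -x[..., 1::2], y[..., 1::2] ~ x[..., 0::2]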


def apply_rope(q: torch.Tensor, k: torch.Tensor, pe: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to queries and keys.

    Args:
        q: Query tensor [batch, heads, seq, dim]
        k: Key tensor [batch, heads, seq, dim]
        pe: Positional embeddings [..., dim//2, 2, 2]

    Returns:
        Rotated (q, k) tensors
    """
    return apply_rope1(q, pe), apply_rope1(k, pe)


def optimized_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None) -> torch.Tensor:
    """Optimized attention using Flash/SDPA with fallback to xformers.

    Performance priority: cuDNN > Flash > SDPA > xformers > plain SDPA.
    Uses the SDPA backend priority from the Device module for optimal dispatch.
    """
    b, _, seq_q, dim = q.shape
    _, _, seq_kv, _ = k.shape

    # Method 1: Native scaled_dot_product_attention with backend priority.
    # This is the fastest path on modern PyTorch with GPU support.
    if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        try:
            # Get the SDPA backend priority context manager from Device
            sdpa_context = Device.get_sdpa_context()
            # Process the attention mask for SDPA if provided
            attn_mask = None
            if mask is not None:
                # Add dimensions as needed: [B, L] -> [B, 1, 1, L] for broadcasting
                if mask.ndim == 2:
                    attn_mask = mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, L]
                elif mask.ndim == 3:
                    attn_mask = mask.unsqueeze(1)  # [B, 1, L, L]
                else:
                    attn_mask = mask
                # Convert the mask to additive form (0 for attend, -inf for masked).
                # The input mask is 1 for valid tokens, 0 for padding.
                attn_mask = attn_mask.to(dtype=q.dtype)
                attn_mask = (1.0 - attn_mask) * torch.finfo(q.dtype).min
            # SDPA expects [batch, heads, seq, dim] - q/k/v are already in this format
            with sdpa_context:
                out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
            # Reshape: [batch, heads, seq, dim] -> [batch, seq, heads*dim]
            out = out.transpose(1, 2).reshape(b, seq_q, -1)
            return out
        except Exception:
            pass  # Fall through to xformers

    # Method 2: xformers memory-efficient attention
    if Device.xformers_enabled():
        try:
            import xformers.ops as xops

            # xformers expects [batch, seq, heads, dim]
            q_xf = q.transpose(1, 2).contiguous()
            k_xf = k.transpose(1, 2).contiguous()
            v_xf = v.transpose(1, 2).contiguous()
            # Note: xformers uses a different mask format; a conversion would be needed
            out = xops.memory_efficient_attention(q_xf, k_xf, v_xf)
            del q_xf, k_xf, v_xf  # Free memory early
            # Reshape: [batch, seq, heads, dim] -> [batch, seq, heads*dim]
            out = out.reshape(b, seq_q, -1)
            return out
        except Exception:
            pass  # Fall through to plain SDPA

    # Method 3: Last resort - plain SDPA without a backend context.
    # Note: the mask is not applied on this path.
    out = F.scaled_dot_product_attention(q, k, v)
    out = out.transpose(1, 2).reshape(b, seq_q, -1)
    return out
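

# Padding-mask sketch (illustrative): a [B, L] key mask with 1 = keep, 0 = pad
# is broadcast to [B, 1, 1, L] and converted to an additive bias before SDPA:
#
#   mask = torch.tensor([[1, 1, 1, 0]])                      # last token is padding
#   q = k = v = torch.randn(1, 2, 4, 8)
#   out = optimized_attention(q, k, v, heads=2, mask=mask)   # [1, 4, 16]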


@dataclass
class ModulationOut:
    """Output of a modulation layer."""

    shift: torch.Tensor
    scale: torch.Tensor
    gate: torch.Tensor


class Modulation(nn.Module):
    """Adaptive layer normalization (adaLN) modulation.

    Applies shift, scale, and gate derived from a conditioning vector.
    """

    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        if operations is None:
            operations = get_ops()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = operations.Linear(dim, self.multiplier * dim, bias=ops_bias, dtype=dtype, device=device)

    def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
        mod1 = ModulationOut(shift=out[0], scale=out[1], gate=out[2])
        mod2 = ModulationOut(shift=out[3], scale=out[4], gate=out[5]) if self.is_double else None
        return mod1, mod2
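

# adaLN usage sketch (illustrative): a modulated sub-layer computes
# (1 + scale) * norm(x) + shift, and `gate` scales its residual contribution:
#
#   mod = Modulation(dim=64, double=False)
#   vec = torch.randn(2, 64)               # pooled conditioning vector
#   m, _ = mod(vec)                        # m.shift / m.scale / m.gate: [2, 1, 64]
#   x = torch.randn(2, 16, 64)
#   x_mod = (1 + m.scale) * x + m.shift    # broadcasts over the sequence dimension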


class GlobalModulation(nn.Module):
    """Global modulation for Flux2 (Klein) double stream blocks."""

    def __init__(self, dim: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        if operations is None:
            operations = get_ops()
        # 12 outputs: 6 for the img stream, 6 for the txt stream
        self.lin = operations.Linear(dim, 12 * dim, bias=ops_bias, dtype=dtype, device=device)

    def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut, ModulationOut, ModulationOut]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(12, dim=-1)
        mod1_img = ModulationOut(shift=out[0], scale=out[1], gate=out[2])
        mod2_img = ModulationOut(shift=out[3], scale=out[4], gate=out[5])
        mod1_txt = ModulationOut(shift=out[6], scale=out[7], gate=out[8])
        mod2_txt = ModulationOut(shift=out[9], scale=out[10], gate=out[11])
        return mod1_img, mod2_img, mod1_txt, mod2_txt


class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image and text streams.

    Uses joint attention but separate MLPs for image and text.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool = False,
        global_modulation: bool = False,
        dtype=None,
        device=None,
        operations=None,
        flax_compatible: bool = False,
        silu_mlp: bool = False,
        gated_mlp: bool = False,
        ops_bias: bool = True,  # Whether to use bias in linear layers
    ):
        super().__init__()
        if operations is None:
            operations = get_ops()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.flax_compatible = flax_compatible
        self.silu_mlp = silu_mlp
        self.gated_mlp = gated_mlp
        # For the gated MLP (Klein), mlp_ratio is the true ratio.
        # The first layer outputs 2x for gating: hidden -> 2*intermediate;
        # the second layer maps intermediate -> hidden.
        if gated_mlp:
            mlp_intermediate = int(hidden_size * mlp_ratio)
            mlp_hidden_dim = mlp_intermediate * 2  # Doubled for the gate+up projection
        else:
            mlp_hidden_dim = int(hidden_size * mlp_ratio)
            mlp_intermediate = mlp_hidden_dim
        if global_modulation:
            # When using global modulation at the model level, don't create per-block modulation
            self.double_stream_modulation = None
            self.img_mod = None
            self.txt_mod = None
            self.use_global_modulation = True
        else:
            self.double_stream_modulation = None
            self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
            self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
            self.use_global_modulation = False
        # Image stream
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(hidden_size, num_heads, qkv_bias, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        if gated_mlp:
            # Gated MLP with naming compatible with the checkpoint: .0, .1 (identity), .2
            self.img_mlp = nn.Sequential(
                operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
                nn.Identity(),  # Placeholder for index 1
                operations.Linear(mlp_intermediate, hidden_size, bias=ops_bias, dtype=dtype, device=device),
            )
        else:
            self.img_mlp = nn.Sequential(
                operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
                nn.SiLU() if silu_mlp else nn.GELU(approximate="tanh"),
                operations.Linear(mlp_hidden_dim, hidden_size, bias=ops_bias, dtype=dtype, device=device),
            )
        # Text stream
        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(hidden_size, num_heads, qkv_bias, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        if gated_mlp:
            self.txt_mlp = nn.Sequential(
                operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
                nn.Identity(),
                operations.Linear(mlp_intermediate, hidden_size, bias=ops_bias, dtype=dtype, device=device),
            )
        else:
            self.txt_mlp = nn.Sequential(
                operations.Linear(hidden_size, mlp_hidden_dim, bias=ops_bias, dtype=dtype, device=device),
                nn.SiLU() if silu_mlp else nn.GELU(approximate="tanh"),
                operations.Linear(mlp_hidden_dim, hidden_size, bias=ops_bias, dtype=dtype, device=device),
            )
    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        pe: torch.Tensor,
        attn_mask=None,
        img_mod: tuple | None = None,  # (img_mod1, img_mod2) from global modulation
        txt_mod: tuple | None = None,  # (txt_mod1, txt_mod2) from global modulation
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Get the modulation parameters
        if self.use_global_modulation and img_mod is not None and txt_mod is not None:
            # Use global modulation passed from the model level
            img_mod1, img_mod2 = img_mod
            txt_mod1, txt_mod2 = txt_mod
        elif self.img_mod is not None and self.txt_mod is not None:
            # Use per-block modulation (Flux1 style)
            img_mod1, img_mod2 = self.img_mod(vec)
            txt_mod1, txt_mod2 = self.txt_mod(vec)
        else:
            raise ValueError("No modulation available - either provide global modulation or use per-block modulation")
        # Prepare the normed inputs
        img_normed = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_normed + img_mod1.shift
        del img_normed  # Free memory early
        txt_normed = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_normed + txt_mod1.shift
        del txt_normed  # Free memory early
        # Run joint attention - use view+permute for efficiency instead of rearrange
        img_qkv = self.img_attn.qkv(img_modulated)
        del img_modulated
        q_img, k_img, v_img = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del img_qkv
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        del txt_modulated
        q_txt, k_txt, v_txt = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del txt_qkv
        q_img, k_img = self.img_attn.norm(q_img, k_img, v_img)
        q_txt, k_txt = self.txt_attn.norm(q_txt, k_txt, v_txt)
        # Concatenate for joint attention (text tokens first)
        q = torch.cat((q_txt, q_img), dim=2)
        del q_txt, q_img
        k = torch.cat((k_txt, k_img), dim=2)
        del k_txt, k_img
        v = torch.cat((v_txt, v_img), dim=2)
        del v_txt, v_img
        attn_out = attention(q, k, v, pe=pe, mask=attn_mask)
        del q, k, v
        txt_attn, img_attn = attn_out[:, : txt.shape[1]], attn_out[:, txt.shape[1] :]
        del attn_out
        # Apply the residual connections with gating
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        del img_attn
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        del txt_attn
        # MLP with modulation
        img_mlp_in = (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        img = img + img_mod2.gate * self._forward_mlp(self.img_mlp, img_mlp_in)
        del img_mlp_in
        txt_mlp_in = (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        txt = txt + txt_mod2.gate * self._forward_mlp(self.txt_mlp, txt_mlp_in)
        del txt_mlp_in
        # Handle fp16 numerical issues (matches ComfyUI exactly)
        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
        return img, txt

    def _forward_mlp(self, mlp: nn.Sequential, x: torch.Tensor) -> torch.Tensor:
        """Forward through the MLP, handling both standard and gated variants."""
        if self.gated_mlp:
            # Gated MLP: split into gate and up, apply SiLU to the gate, multiply, project
            gate_up = mlp[0](x)
            gate, up = gate_up.chunk(2, dim=-1)
            hidden = F.silu(gate) * up
            return mlp[2](hidden)
        else:
            return mlp(x)
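

# Instantiation sketch (hypothetical Flux1-style sizes): text tokens lead the
# joint sequence, so the RoPE table must cover txt_len + img_len positions:
#
#   block = DoubleStreamBlock(hidden_size=128, num_heads=4, mlp_ratio=4.0, qkv_bias=True)
#   img, txt = torch.randn(1, 64, 128), torch.randn(1, 8, 128)
#   vec = torch.randn(1, 128)
#   ids = torch.zeros(1, 8 + 64, 3)        # real use fills per-axis positions
#   pe = EmbedND(dim=128 // 4, theta=10000, axes_dim=[8, 12, 12])(ids)   # head_dim = 32
#   img, txt = block(img, txt, vec, pe)    # both shapes preserved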


class SingleStreamBlock(nn.Module):
    """Transformer block with a merged image and text stream.

    Used after the double stream blocks have processed both modalities.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
        dtype=None,
        device=None,
        operations=None,
        silu_mlp: bool = False,
        gated_mlp: bool = False,
        ops_bias: bool = True,
        global_modulation: bool = False,
    ):
        super().__init__()
        if operations is None:
            operations = get_ops()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.silu_mlp = silu_mlp
        self.gated_mlp = gated_mlp
        self.use_global_modulation = global_modulation
        # For the gated MLP, mlp_ratio gives the intermediate size:
        # linear1 outputs gate+up (2x intermediate), linear2 takes intermediate.
        if gated_mlp:
            self.mlp_intermediate = int(hidden_size * mlp_ratio)
            self.mlp_gate_up_dim = self.mlp_intermediate * 2
            linear1_out = hidden_size * 3 + self.mlp_gate_up_dim
            linear2_in = hidden_size + self.mlp_intermediate
        else:
            self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
            linear1_out = hidden_size * 3 + self.mlp_hidden_dim
            linear2_in = hidden_size + self.mlp_hidden_dim
        # Joint QKV and MLP projection
        self.linear1 = operations.Linear(
            hidden_size, linear1_out, bias=ops_bias, dtype=dtype, device=device
        )
        self.linear2 = operations.Linear(
            linear2_in, hidden_size, bias=ops_bias, dtype=dtype, device=device
        )
        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        # Only create per-block modulation if not using global modulation
        if not global_modulation:
            self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations, ops_bias=ops_bias)
        else:
            self.modulation = None
    def forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        pe: torch.Tensor,
        attn_mask=None,
        modulation=None,  # ModulationOut from global modulation
    ) -> torch.Tensor:
        # Get the modulation
        if self.use_global_modulation and modulation is not None:
            mod = modulation
        elif self.modulation is not None:
            mod, _ = self.modulation(vec)
        else:
            raise ValueError("No modulation available - either provide global modulation or use per-block modulation")
        x_normed = self.pre_norm(x)
        x_mod = (1 + mod.scale) * x_normed + mod.shift
        del x_normed  # Free memory early
        # Joint projection - split the QKV part from the MLP part
        qkv_mlp = self.linear1(x_mod)
        del x_mod
        if self.gated_mlp:
            qkv, mlp_gate_up = qkv_mlp.split([self.hidden_size * 3, self.mlp_gate_up_dim], dim=-1)
            del qkv_mlp
            # Gated MLP: split into gate and up, apply SiLU to the gate, multiply
            gate, up = mlp_gate_up.chunk(2, dim=-1)
            del mlp_gate_up
            mlp = F.silu(gate) * up
            del gate, up
        else:
            qkv, mlp = qkv_mlp.split([self.hidden_size * 3, self.mlp_hidden_dim], dim=-1)
            del qkv_mlp
            # Standard activation
            if self.silu_mlp:
                mlp = F.silu(mlp)
            else:
                mlp = F.gelu(mlp, approximate="tanh")
        # Attention - use view+permute for efficiency instead of rearrange
        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del qkv
        q, k = self.norm(q, k, v)
        attn = attention(q, k, v, pe=pe, mask=attn_mask)
        del q, k, v
        # Combine and project
        output = self.linear2(torch.cat((attn, mlp), dim=-1))
        del attn, mlp
        result = x + mod.gate * output
        # Handle fp16 numerical issues (matches ComfyUI exactly)
        if result.dtype == torch.float16:
            result = torch.nan_to_num(result, nan=0.0, posinf=65504, neginf=-65504)
        return result
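

# Fused-projection arithmetic (illustrative): with hidden_size=128 and the
# default mlp_ratio=4.0, linear1 maps 128 -> 128*3 + 512 = 896 in one matmul,
# which the forward pass splits into QKV (384) and the MLP input (512):
#
#   block = SingleStreamBlock(hidden_size=128, num_heads=4)
#   x, vec = torch.randn(1, 72, 128), torch.randn(1, 128)
#   pe = EmbedND(dim=32, theta=10000, axes_dim=[8, 12, 12])(torch.zeros(1, 72, 3))
#   out = block(x, vec, pe)   # [1, 72, 128]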


class LastLayer(nn.Module):
    """Final layer for unpatchifying and producing the output."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None, ops_bias: bool = True):
        super().__init__()
        if operations is None:
            operations = get_ops()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=ops_bias, dtype=dtype, device=device
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 2 * hidden_size, bias=ops_bias, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
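

# Output-head sketch (illustrative): with patch_size=2 and 16 latent channels,
# each token is projected to a flattened 2x2 patch of output channels, ready
# for unpatchifying back into the latent grid:
#
#   head = LastLayer(hidden_size=128, patch_size=2, out_channels=16)
#   x = torch.randn(1, 64, 128)   # 64 image tokens
#   vec = torch.randn(1, 128)
#   out = head(x, vec)            # [1, 64, 2 * 2 * 16] = [1, 64, 64]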