| """IRIS: Complete model — patchify, refinement core, unpatchify, tiny decoder.""" |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from .core import RefinementCore |
|
|
|
|
class Patchify(nn.Module):
    """Split a latent (B, C, H, W) into non-overlapping patches and project to tokens."""

    def __init__(self, in_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.dw_conv = nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels, bias=True)
        self.proj = nn.Linear(in_channels * patch_size * patch_size, dim, bias=True)

    def forward(self, z):
        B, C, H, W = z.shape
        p = self.patch_size
        orig_dtype = z.dtype
        # Run the depthwise conv in float32 even under mixed precision, then cast back.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            z = self.dw_conv(z.float())
        z = z.to(orig_dtype)
        H_tok, W_tok = H // p, W // p
        # (B, C, H, W) -> (B, H_tok * W_tok, C * p * p): one flattened vector per patch.
        z = z.view(B, C, H_tok, p, W_tok, p).permute(0, 2, 4, 1, 3, 5).reshape(B, H_tok * W_tok, C * p * p)
        return self.proj(z), H_tok, W_tok
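
    # Shape walk-through with the defaults (a sketch, not part of the module API):
    # a (2, 32, 64, 64) latent with patch_size=4 yields H_tok = W_tok = 16 and
    # tokens of shape (2, 256, 512).
    #
    #   patchify = Patchify(in_channels=32, dim=512, patch_size=4)
    #   tokens, H_tok, W_tok = patchify(torch.randn(2, 32, 64, 64))
    #   assert tokens.shape == (2, 16 * 16, 512)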


class Unpatchify(nn.Module):
    """Project tokens back to patch pixels and reassemble the latent (B, C, H, W)."""

    def __init__(self, out_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = nn.Linear(dim, out_channels * patch_size * patch_size, bias=True)
        self.dw_conv = nn.Conv2d(out_channels, out_channels, 3, padding=1, groups=out_channels, bias=True)

    def forward(self, tokens, H_tok, W_tok):
        B, N, D = tokens.shape
        p = self.patch_size
        C = self.out_channels
        # (B, N, D) -> (B, H_tok, W_tok, C, p, p) -> (B, C, H_tok * p, W_tok * p)
        z = self.proj(tokens).view(B, H_tok, W_tok, C, p, p)
        z = z.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H_tok * p, W_tok * p)
        orig_dtype = z.dtype
        # Depthwise smoothing conv in float32, matching Patchify's precision handling.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            z = self.dw_conv(z.float())
        return z.to(orig_dtype)
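
    # Inverse of Patchify up to the two depthwise convs (a sketch): with the
    # defaults, (2, 256, 512) tokens and H_tok = W_tok = 16 reassemble into a
    # (2, 32, 64, 64) latent.
    #
    #   unpatchify = Unpatchify(out_channels=32, dim=512, patch_size=4)
    #   z = unpatchify(tokens, 16, 16)
    #   assert z.shape == (2, 32, 64, 64)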


class TinyDecoder(nn.Module):
    """Minimal latent->pixels decoder via PixelShuffle. ~0.1M params."""

    def __init__(self, in_channels=32, out_channels=3):
        super().__init__()
        self.stages = nn.ModuleList()
        channels = [in_channels, 32, 32, 16, 8, out_channels]
        # Five conv + PixelShuffle(2) stages: 32x total upsampling.
        for i in range(5):
            self.stages.append(nn.Sequential(
                nn.Conv2d(channels[i], channels[i + 1] * 4, 3, padding=1, bias=True),
                nn.PixelShuffle(2),
                nn.SiLU() if i < 4 else nn.Identity(),
            ))
        self.final = nn.Conv2d(out_channels, out_channels, 1, bias=True)

    def forward(self, z):
        orig_dtype = z.dtype
        # Decode entirely in float32; tanh bounds the output to [-1, 1].
        with torch.amp.autocast(device_type='cuda', enabled=False):
            x = z.float()
            for stage in self.stages:
                x = stage(x)
            x = torch.tanh(self.final(x))
        return x.to(orig_dtype)
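
    # Each stage doubles the spatial resolution, so five stages give 32x
    # (a sketch with the defaults):
    #
    #   decoder = TinyDecoder(in_channels=32, out_channels=3)
    #   x = decoder(torch.randn(1, 32, 8, 8))
    #   assert x.shape == (1, 3, 256, 256)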


class IRIS(nn.Module):
    """
    IRIS: Iterative Refinement Image Synthesizer.
    Predicts the velocity v_theta(z_t, t, c) for flow matching.

    Args:
        text_dim: dimension of the text encoder output. If different from dim,
            a learned linear projection is applied. Set to 384 for
            all-MiniLM-L6-v2, 512 for CLIP, etc. Set to None or equal to dim
            to skip the projection.
    """

    def __init__(self, latent_channels=32, dim=512, patch_size=4, num_blocks=6,
                 num_heads=8, max_iterations=8, ffn_expansion=2,
                 gradient_checkpointing=True, text_dim=None):
        super().__init__()
        self.latent_channels = latent_channels
        self.dim = dim
        self.patch_size = patch_size

        self.patchify = Patchify(latent_channels, dim, patch_size)
        self.unpatchify = Unpatchify(latent_channels, dim, patch_size)
        spatial_size = 4
        self.core = RefinementCore(dim=dim, num_blocks=num_blocks, num_heads=num_heads,
                                   spatial_size=spatial_size, max_iterations=max_iterations,
                                   ffn_expansion=ffn_expansion, gradient_checkpointing=gradient_checkpointing)
        self.tiny_decoder = TinyDecoder(latent_channels, out_channels=3)

        if text_dim is not None and text_dim != dim:
            self.context_proj = nn.Linear(text_dim, dim, bias=False)
        else:
            self.context_proj = None

        self._init_weights()
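
    # Construction sketch (hypothetical values; all-MiniLM-L6-v2 embeddings are
    # 384-dimensional, so text_dim=384 enables the learned projection to dim=512):
    #
    #   model = IRIS(latent_channels=32, dim=512, text_dim=384)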

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Zero-init the output projection so the model predicts zero velocity at start.
        nn.init.zeros_(self.unpatchify.proj.weight)
        nn.init.zeros_(self.unpatchify.proj.bias)

    def forward(self, z_t, t, context, num_iterations=4):
        tokens, H_tok, W_tok = self.patchify(z_t)

        if self.context_proj is not None:
            context = self.context_proj(context)
        elif context.shape[-1] != self.dim:
            # Fallback: build a projection lazily on first use. Note that a module
            # created here is not covered by _init_weights and will be missed by
            # any optimizer constructed before the first forward pass.
            if not hasattr(self, '_lazy_context_proj'):
                self._lazy_context_proj = nn.Linear(
                    context.shape[-1], self.dim, bias=False
                ).to(context.device, context.dtype)
            context = self._lazy_context_proj(context)

        refined = self.core(tokens, context, t, H_tok, W_tok, num_iterations=num_iterations)
        return self.unpatchify(refined, H_tok, W_tok)
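
    # A minimal Euler sampler sketch under the flow-matching reading above
    # (hypothetical helper, not part of this module; assumes the convention
    # z_0 = noise, z_1 = data, dz/dt = v_theta(z_t, t, c)):
    #
    #   @torch.no_grad()
    #   def sample(model, context, shape, steps=16, device="cuda"):
    #       z = torch.randn(shape, device=device)
    #       for i in range(steps):
    #           t = torch.full((shape[0],), i / steps, device=device)
    #           z = z + model(z, t, context) / steps
    #       return model.decode_latent(z)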

    def decode_latent(self, z):
        return self.tiny_decoder(z)

    def count_params(self):
        counts = {}
        for name, module in self.named_children():
            counts[name] = sum(p.numel() for p in module.parameters())
        counts["total"] = sum(p.numel() for p in self.parameters())
        counts["trainable"] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return counts
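

if __name__ == "__main__":
    # Smoke test for the components defined in this file (RefinementCore lives
    # in .core, so the full IRIS forward pass is not exercised here). Because
    # of the relative import above, run as a module, e.g. `python -m <package>.model`.
    z = torch.randn(2, 32, 64, 64)
    tokens, h, w = Patchify(32, 512, 4)(z)
    assert tokens.shape == (2, h * w, 512)
    z_rec = Unpatchify(32, 512, 4)(tokens, h, w)
    assert z_rec.shape == z.shape
    x = TinyDecoder(32, 3)(torch.randn(1, 32, 8, 8))
    assert x.shape == (1, 3, 256, 256)
    print("shape checks passed")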