# Source: iris-image-gen / iris/model.py
# Last commit 654d061 (verified) by asdf98: "Fix conv2d bf16 crash on T4: iris/model.py"
"""IRIS: Complete model — patchify, refinement core, unpatchify, tiny decoder."""
import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from .core import RefinementCore
class Patchify(nn.Module):
    """Turn a latent feature map into a sequence of patch tokens.

    A depthwise 3x3 conv first mixes each channel with its spatial
    neighbours; the map is then cut into non-overlapping
    ``patch_size`` x ``patch_size`` patches, and each flattened patch
    (channel-major: C, then patch row, then patch column) is linearly
    projected to ``dim``.

    Args:
        in_channels: channel count C of the input latent.
        dim: width of every output token.
        patch_size: side length p of each square patch.
    """

    def __init__(self, in_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.dw_conv = nn.Conv2d(in_channels, in_channels, 3, padding=1,
                                 groups=in_channels, bias=True)
        self.proj = nn.Linear(in_channels * patch_size * patch_size, dim, bias=True)

    def _dw_conv_fp32(self, z):
        """Apply ``self.dw_conv`` entirely in float32.

        The parameters are cast too (a no-op when they already are fp32), so
        this also works when the whole module was cast to bf16/fp16; feeding
        an fp32 input to a conv with low-precision weights would otherwise
        raise a dtype-mismatch error.
        """
        conv = self.dw_conv
        bias = None if conv.bias is None else conv.bias.float()
        return F.conv2d(z.float(), conv.weight.float(), bias,
                        conv.stride, conv.padding, conv.dilation, conv.groups)

    def forward(self, z):
        """Patchify a latent.

        Args:
            z: tensor of shape (B, C, H, W); H and W must be multiples of
                ``patch_size``.

        Returns:
            ``(tokens, H_tok, W_tok)`` where ``tokens`` has shape
            (B, H_tok * W_tok, dim), with H_tok = H // p and W_tok = W // p.

        Raises:
            ValueError: if H or W is not divisible by ``patch_size``.
        """
        B, C, H, W = z.shape
        p = self.patch_size
        if H % p or W % p:
            raise ValueError(
                f"Patchify: spatial size ({H}, {W}) must be divisible by patch_size={p}"
            )
        orig_dtype = z.dtype
        # Run grouped conv in float32 — cuDNN lacks bf16 kernels for grouped convs on T4
        with torch.amp.autocast(device_type='cuda', enabled=False):
            z = self._dw_conv_fp32(z)
        z = z.to(orig_dtype)
        H_tok, W_tok = H // p, W // p
        # (B, C, H, W) -> (B, H_tok, W_tok, C, p, p) -> (B, N, C*p*p)
        z = (z.view(B, C, H_tok, p, W_tok, p)
              .permute(0, 2, 4, 1, 3, 5)
              .reshape(B, H_tok * W_tok, C * p * p))
        return self.proj(z), H_tok, W_tok
class Unpatchify(nn.Module):
    """Map patch tokens back to a latent feature map.

    Each token is linearly projected to ``out_channels * p * p`` values and
    laid out as a p x p patch, using the inverse of ``Patchify``'s
    channel-major flattening. A depthwise 3x3 conv then smooths across patch
    borders.

    Args:
        out_channels: channel count C of the produced latent.
        dim: width of every input token.
        patch_size: side length p of each square patch.
    """

    def __init__(self, out_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = nn.Linear(dim, out_channels * patch_size * patch_size, bias=True)
        self.dw_conv = nn.Conv2d(out_channels, out_channels, 3, padding=1,
                                 groups=out_channels, bias=True)

    def _dw_conv_fp32(self, z):
        """Apply ``self.dw_conv`` entirely in float32.

        The parameters are cast too (a no-op when they already are fp32), so
        this also works when the whole module was cast to bf16/fp16; feeding
        an fp32 input to a conv with low-precision weights would otherwise
        raise a dtype-mismatch error.
        """
        conv = self.dw_conv
        bias = None if conv.bias is None else conv.bias.float()
        return F.conv2d(z.float(), conv.weight.float(), bias,
                        conv.stride, conv.padding, conv.dilation, conv.groups)

    def forward(self, tokens, H_tok, W_tok):
        """Unpatchify a token sequence.

        Args:
            tokens: tensor of shape (B, N, dim) with N == H_tok * W_tok.
            H_tok: number of patch rows.
            W_tok: number of patch columns.

        Returns:
            Tensor of shape (B, out_channels, H_tok * p, W_tok * p), in the
            dtype produced by ``self.proj``.

        Raises:
            ValueError: if N != H_tok * W_tok.
        """
        B, N, D = tokens.shape
        if N != H_tok * W_tok:
            raise ValueError(
                f"Unpatchify: got {N} tokens but H_tok * W_tok = {H_tok * W_tok}"
            )
        p = self.patch_size
        C = self.out_channels
        # (B, N, C*p*p) -> (B, H_tok, W_tok, C, p, p) -> (B, C, H_tok*p, W_tok*p)
        z = self.proj(tokens).view(B, H_tok, W_tok, C, p, p)
        z = z.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H_tok * p, W_tok * p)
        # Run grouped conv in float32 — cuDNN lacks bf16 kernels for grouped convs on T4
        orig_dtype = z.dtype
        with torch.amp.autocast(device_type='cuda', enabled=False):
            z = self._dw_conv_fp32(z)
        return z.to(orig_dtype)
class TinyDecoder(nn.Module):
    """Minimal latent->pixels decoder via PixelShuffle. ~0.1M params.

    Five stages of (3x3 conv -> PixelShuffle(2) -> SiLU), with no activation
    on the last stage, upsample the input by 2**5 = 32 along each spatial
    axis. A final 1x1 conv and ``tanh`` follow, so outputs lie in [-1, 1].
    The ~0.1M parameter figure holds for the default ``in_channels=32``.

    Args:
        in_channels: channel count of the input latent.
        out_channels: channel count of the output image.
    """

    def __init__(self, in_channels=32, out_channels=3):
        super().__init__()
        self.stages = nn.ModuleList()
        channels = [in_channels, 32, 32, 16, 8, out_channels]
        for i in range(5):
            self.stages.append(nn.Sequential(
                # 4x channels so PixelShuffle(2) yields channels[i+1] at 2x resolution
                nn.Conv2d(channels[i], channels[i+1]*4, 3, padding=1, bias=True),
                nn.PixelShuffle(2),
                nn.SiLU() if i < 4 else nn.Identity(),
            ))
        self.final = nn.Conv2d(out_channels, out_channels, 1, bias=True)

    @staticmethod
    def _conv_fp32(conv, x):
        """Apply ``conv`` to ``x`` with float32 copies of its parameters.

        Casting the parameters (a no-op when they already are fp32) keeps
        this working when the module was cast to bf16/fp16; feeding an fp32
        input to a conv with low-precision weights would otherwise raise a
        dtype-mismatch error.
        """
        bias = None if conv.bias is None else conv.bias.float()
        return F.conv2d(x, conv.weight.float(), bias,
                        conv.stride, conv.padding, conv.dilation, conv.groups)

    def forward(self, z):
        """Decode a latent of shape (B, in_channels, H, W) into an image of
        shape (B, out_channels, 32 * H, 32 * W), returned in ``z``'s dtype."""
        # Run decoder convs in float32 — cuDNN lacks bf16 kernels on T4
        orig_dtype = z.dtype
        with torch.amp.autocast(device_type='cuda', enabled=False):
            x = z.float()
            for stage in self.stages:
                conv, shuffle, act = stage
                x = act(shuffle(self._conv_fp32(conv, x)))
            x = torch.tanh(self._conv_fp32(self.final, x))
        return x.to(orig_dtype)
class IRIS(nn.Module):
    """
    IRIS: Iterative Refinement Image Synthesizer.
    Predicts velocity v_theta(z_t, t, c) for flow matching.

    Pipeline: Patchify -> RefinementCore -> Unpatchify. A separate
    TinyDecoder (``decode_latent``) maps latents to 3-channel images.

    Args:
        latent_channels: channel count of the latent z_t.
        dim: model (token) width.
        patch_size: side length of the square latent patches.
        num_blocks, num_heads, max_iterations, ffn_expansion,
        gradient_checkpointing: forwarded unchanged to RefinementCore.
        text_dim: dimension of text encoder output. If different from dim,
            a learned linear projection is applied. Set to 384 for
            all-MiniLM-L6-v2, 512 for CLIP, etc. Set to None or
            equal to dim to skip projection.
        spatial_size: forwarded to RefinementCore. None keeps the historical
            default of 4 (intended for a 16x16 latent with patch_size=4).
    """

    def __init__(self, latent_channels=32, dim=512, patch_size=4, num_blocks=6,
                 num_heads=8, max_iterations=8, ffn_expansion=2,
                 gradient_checkpointing=True, text_dim=None, spatial_size=None):
        super().__init__()
        self.latent_channels = latent_channels
        self.dim = dim
        self.patch_size = patch_size
        self.patchify = Patchify(latent_channels, dim, patch_size)
        self.unpatchify = Unpatchify(latent_channels, dim, patch_size)
        if spatial_size is None:
            spatial_size = 4  # default for 16x16 latent with ps=4
        self.core = RefinementCore(dim=dim, num_blocks=num_blocks, num_heads=num_heads,
                                   spatial_size=spatial_size, max_iterations=max_iterations,
                                   ffn_expansion=ffn_expansion,
                                   gradient_checkpointing=gradient_checkpointing)
        self.tiny_decoder = TinyDecoder(latent_channels, out_channels=3)
        # Text projection: maps text encoder dim to model dim if they differ
        if text_dim is not None and text_dim != dim:
            self.context_proj = nn.Linear(text_dim, dim, bias=False)
        else:
            self.context_proj = None
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every submodule, including ``self.core``.

        Linear -> xavier-uniform, Conv2d -> kaiming-normal (fan_out), norms ->
        weight 1 / bias 0, all biases 0. ``unpatchify.proj`` is then zeroed;
        with the conv biases also zeroed, the model's output starts at exactly 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        nn.init.zeros_(self.unpatchify.proj.weight)
        nn.init.zeros_(self.unpatchify.proj.bias)

    def _project_context(self, context):
        """Bring text embeddings to width ``self.dim``.

        Uses ``self.context_proj`` when configured; passes ``context`` through
        when its last dim already equals ``self.dim``. Otherwise falls back to
        a lazily created projection kept for backwards compatibility.

        Raises:
            ValueError: if the lazy projection already exists for a different
                input width.
        """
        if self.context_proj is not None:
            return self.context_proj(context)
        in_dim = context.shape[-1]
        if in_dim == self.dim:
            return context
        # Fallback: lazy projection for backwards compat. It is created after
        # construction, so optimizers built earlier never update it and fresh
        # models cannot strict-load checkpoints that contain it — warn loudly.
        lazy = getattr(self, '_lazy_context_proj', None)
        if lazy is None:
            warnings.warn(
                f"IRIS: context dim {in_dim} != model dim {self.dim} and no text_dim "
                f"was given; creating an untrained projection on the fly. Optimizers "
                f"created earlier will not train it. Pass text_dim={in_dim} to IRIS.",
                stacklevel=3,
            )
            lazy = nn.Linear(in_dim, self.dim, bias=False)
            # Match the initialisation _init_weights gives every other Linear.
            nn.init.xavier_uniform_(lazy.weight)
            lazy = lazy.to(context.device, context.dtype)
            self._lazy_context_proj = lazy
        elif lazy.in_features != in_dim:
            raise ValueError(
                f"IRIS: context dim {in_dim} does not match the lazily created "
                f"projection (in_features={lazy.in_features}); pass text_dim to IRIS."
            )
        return lazy(context)

    def forward(self, z_t, t, context, num_iterations=4):
        """Predict the flow-matching velocity for ``z_t``.

        Args:
            z_t: latent of shape (B, latent_channels, H, W); H and W must be
                multiples of ``patch_size``.
            t: timestep(s), passed through to RefinementCore unchanged.
            context: text embeddings whose last dim is ``dim`` or ``text_dim``
                (presumably (B, L, text_dim) — TODO confirm against
                RefinementCore).
            num_iterations: refinement iterations, passed to RefinementCore.

        Returns:
            Tensor with the same (B, latent_channels, H, W) layout as ``z_t``.
        """
        tokens, H_tok, W_tok = self.patchify(z_t)
        context = self._project_context(context)
        refined = self.core(tokens, context, t, H_tok, W_tok, num_iterations=num_iterations)
        return self.unpatchify(refined, H_tok, W_tok)

    def decode_latent(self, z):
        """Decode a latent to an image in [-1, 1] with the tiny decoder."""
        return self.tiny_decoder(z)

    def count_params(self):
        """Return parameter counts per direct child module, plus the
        de-duplicated ``"total"`` and the ``"trainable"`` subset."""
        counts = {
            name: sum(p.numel() for p in module.parameters())
            for name, module in self.named_children()
        }
        counts["total"] = sum(p.numel() for p in self.parameters())
        counts["trainable"] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return counts