# geolip-diffusion-proto / modeling_flow_match.py
# Author: AbstractPhil — "Update modeling_flow_match.py" (commit 39b176f, verified)
"""
FlowMatchRelay model — HuggingFace compatible.
Usage:
from transformers import AutoModel
model = AutoModel.from_pretrained(
"AbstractPhil/geolip-diffusion-proto",
trust_remote_code=True
)
# Generate samples
samples = model.sample(n_samples=8, class_label=3) # 8 cats
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import PreTrainedModel
from .configuration_flow_match import FlowMatchRelayConfig
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION RELAY
# ══════════════════════════════════════════════════════════════════
class ConstellationRelay(nn.Module):
    """
    Geometric regulator for feature maps.

    Fixed anchors on S^(d-1), multi-phase stroboscope triangulation,
    gated residual correction.
    """
    def __init__(self, channels, patch_dim=16, n_anchors=16, n_phases=3,
                 pw_hidden=32, gate_init=-3.0, mode='channel'):
        super().__init__()
        assert channels % patch_dim == 0
        self.channels = channels
        self.patch_dim = patch_dim
        self.n_patches = channels // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        self.mode = mode
        n_p, n_a, dim = self.n_patches, n_anchors, patch_dim
        # Frozen reference directions ("home") on the unit sphere, one group
        # of n_a anchors per patch; the learnable copy starts at home and may
        # drift during training.
        home = torch.empty(n_p, n_a, dim)
        nn.init.xavier_normal_(home.view(n_p * n_a, dim))
        home = F.normalize(home.view(n_p, n_a, dim), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())
        # Per-patch two-layer MLP: triangulation features -> patch correction.
        tri_dim = n_phases * n_a
        self.pw_w1 = nn.Parameter(torch.empty(n_p, tri_dim, pw_hidden))
        self.pw_b1 = nn.Parameter(torch.zeros(1, n_p, pw_hidden))
        self.pw_w2 = nn.Parameter(torch.empty(n_p, pw_hidden, dim))
        self.pw_b2 = nn.Parameter(torch.zeros(1, n_p, dim))
        for idx in range(n_p):
            nn.init.xavier_normal_(self.pw_w1.data[idx])
            nn.init.xavier_normal_(self.pw_w2.data[idx])
        self.pw_norm = nn.LayerNorm(dim)
        # Per-patch residual gates; sigmoid(-3) ~ 0.047, so the correction
        # starts nearly disabled.
        self.gates = nn.Parameter(torch.full((n_p,), gate_init))
        self.norm = nn.LayerNorm(channels)

    def drift(self):
        """Angular distance (radians) between each learned anchor and its home."""
        ref = F.normalize(self.home, dim=-1)
        cur = F.normalize(self.anchors, dim=-1)
        cos_sim = (ref * cur).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7)
        return torch.acos(cos_sim)

    def at_phase(self, t):
        """Slerp between home anchors (t=0) and learned anchors (t=1)."""
        ref = F.normalize(self.home, dim=-1)
        cur = F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        sin_omega = omega.sin().clamp(min=1e-7)
        return (torch.sin((1 - t) * omega) / sin_omega * ref
                + torch.sin(t * omega) / sin_omega * cur)

    def _relay_core(self, x_flat):
        """Apply the gated, anchor-triangulated residual correction to (N, C) rows."""
        n_rows, n_feats = x_flat.shape
        n_p, n_a, dim = self.n_patches, self.n_anchors, self.patch_dim
        x_n = self.norm(x_flat)
        patches = x_n.reshape(n_rows, n_p, dim)
        patches_n = F.normalize(patches, dim=-1)
        # Cosine distances to the stroboscope anchors at each phase,
        # concatenated along the anchor axis.
        phase_vals = torch.linspace(0, 1, self.n_phases, device=x_flat.device).tolist()
        tri = torch.cat(
            [1.0 - torch.einsum('npd,pad->npa', patches_n,
                                F.normalize(self.at_phase(t), dim=-1))
             for t in phase_vals],
            dim=-1)
        hidden = F.gelu(torch.einsum('npt,pth->nph', tri, self.pw_w1) + self.pw_b1)
        correction = self.pw_norm(
            torch.einsum('nph,phd->npd', hidden, self.pw_w2) + self.pw_b2)
        gate = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)
        blended = gate * correction + (1 - gate) * patches
        return x_flat + blended.reshape(n_rows, n_feats)

    def forward(self, x):
        B, C, H, W = x.shape
        if self.mode != 'channel':
            # Token mode: relay every spatial position independently.
            tokens = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
            relayed = self._relay_core(tokens)
            return relayed.reshape(B, H, W, C).permute(0, 3, 1, 2)
        # Channel mode: relay the spatially pooled descriptor, then rescale
        # the map by the (clamped) per-channel ratio.
        pooled = x.mean(dim=(-2, -1))
        relayed = self._relay_core(pooled)
        scale = (relayed / (pooled + 1e-8)).unsqueeze(-1).unsqueeze(-1)
        return x * scale.clamp(-3, 3)
# ══════════════════════════════════════════════════════════════════
# BUILDING BLOCKS
# ══════════════════════════════════════════════════════════════════
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding: t (B,) -> (B, dim),
    first half sines, second half cosines, log-spaced frequencies."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000.
        step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device, dtype=t.dtype) * -step)
        args = t.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat((args.sin(), args.cos()), dim=-1)
class AdaGroupNorm(nn.Module):
    """GroupNorm whose affine scale/shift are predicted from a conditioning
    vector (FiLM-style). The projection is zero-initialized, so the layer
    behaves as a plain unconditioned GroupNorm at the start of training."""

    def __init__(self, channels, cond_dim, n_groups=8):
        super().__init__()
        self.gn = nn.GroupNorm(min(n_groups, channels), channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):
        normed = self.gn(x)
        # (B, 2C) -> (B, 2C, 1, 1) -> two (B, C, 1, 1) modulation maps.
        params = self.proj(cond).unsqueeze(-1).unsqueeze(-1)
        scale, shift = params.chunk(2, dim=1)
        return normed * (1 + scale) + shift
class ConvBlock(nn.Module):
    """ConvNeXt-style residual block: depthwise 7x7 conv, conditioned
    AdaGroupNorm, 1x1 expand (x4) -> GELU -> 1x1 contract, residual add,
    then an optional ConstellationRelay regulator on the summed output."""

    def __init__(self, channels, cond_dim, use_relay=False,
                 relay_patch_dim=16, relay_n_anchors=16, relay_n_phases=3,
                 relay_pw_hidden=32, relay_gate_init=-3.0, relay_mode='channel'):
        super().__init__()
        self.dw_conv = nn.Conv2d(channels, channels, 7, padding=3, groups=channels)
        self.norm = AdaGroupNorm(channels, cond_dim)
        self.pw1 = nn.Conv2d(channels, channels * 4, 1)
        self.pw2 = nn.Conv2d(channels * 4, channels, 1)
        self.act = nn.GELU()
        if use_relay:
            # Relay sizes are capped by the channel count so small blocks
            # cannot request oversized patches/anchor sets.
            self.relay = ConstellationRelay(
                channels,
                patch_dim=min(relay_patch_dim, channels),
                n_anchors=min(relay_n_anchors, channels),
                n_phases=relay_n_phases,
                pw_hidden=relay_pw_hidden,
                gate_init=relay_gate_init,
                mode=relay_mode)
        else:
            self.relay = None

    def forward(self, x, cond):
        branch = self.pw2(self.act(self.pw1(self.norm(self.dw_conv(x), cond))))
        out = x + branch
        return out if self.relay is None else self.relay(out)
class SelfAttnBlock(nn.Module):
    """Multi-head spatial self-attention over the H*W positions of a feature
    map, with a zero-initialized output projection so the block is an exact
    identity (residual-only) at initialization."""

    def __init__(self, channels, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = channels // n_heads
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.out = nn.Conv2d(channels, channels, 1)
        # Zero init: the block contributes nothing until trained.
        nn.init.zeros_(self.out.weight)
        nn.init.zeros_(self.out.bias)

    def forward(self, x):
        B, C, H, W = x.shape
        residual = x
        x = self.norm(x)
        qkv = self.qkv(x).reshape(B, 3, self.n_heads, self.head_dim, H * W)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
        # BUGFIX: scaled_dot_product_attention expects (..., L, E) with the
        # token axis before the embedding axis. The previous code passed
        # (B, heads, head_dim, HW), which attends over the head_dim axis with
        # an HW-sized "embedding" (and a 1/sqrt(HW) scale) instead of over
        # spatial tokens. Transpose so attention runs across H*W positions.
        # NOTE(review): checkpoints trained with the old layout will produce
        # different activations once self.out is non-zero — confirm against
        # the released weights before deploying.
        q, k, v = (t.transpose(-2, -1) for t in (q, k, v))  # (B, nh, HW, hd)
        attn = F.scaled_dot_product_attention(q, k, v)       # (B, nh, HW, hd)
        out = attn.transpose(-2, -1).reshape(B, C, H, W)
        return residual + self.out(out)
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)
class Upsample(nn.Module):
    """Double spatial resolution: nearest-neighbor upsample followed by a
    3x3 convolution to smooth the result."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(upsampled)
# ══════════════════════════════════════════════════════════════════
# FLOW MATCHING UNET
# ══════════════════════════════════════════════════════════════════
class FlowMatchUNet(nn.Module):
    """Class-conditional UNet that predicts the flow-matching velocity field.

    Conditioning is the sum of a sinusoidal timestep embedding and a learned
    class embedding, injected into every ConvBlock through AdaGroupNorm.
    Constellation relays (when config.use_relay is set) are applied only in
    the two middle blocks; the middle stage also has one self-attention block.
    """
    def __init__(self, config):
        super().__init__()
        in_channels = config.in_channels
        base_channels = config.base_channels
        channel_mults = config.channel_mults
        n_classes = config.n_classes
        cond_dim = config.cond_dim
        use_relay = config.use_relay
        self.channel_mults = channel_mults
        # Relay kwargs
        rk = dict(
            relay_patch_dim=config.relay_patch_dim,
            relay_n_anchors=config.relay_n_anchors,
            relay_n_phases=config.relay_n_phases,
            relay_pw_hidden=config.relay_pw_hidden,
            relay_gate_init=config.relay_gate_init,
            relay_mode=config.relay_mode,
        )
        # Timestep embedding: sinusoidal features refined by a 2-layer MLP.
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))
        self.class_emb = nn.Embedding(n_classes, cond_dim)
        self.in_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        # Encoder
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        ch_in = base_channels
        enc_channels = [base_channels]
        for i, mult in enumerate(channel_mults):
            ch_out = base_channels * mult
            # When the channel count changes at this level, a 1x1 conv
            # projects into the new width before the first ConvBlock.
            self.enc.append(nn.ModuleList([
                ConvBlock(ch_in, cond_dim) if ch_in == ch_out
                else nn.Sequential(nn.Conv2d(ch_in, ch_out, 1),
                                   ConvBlock(ch_out, cond_dim)),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch_in = ch_out
            enc_channels.append(ch_out)
            # Downsample between levels only (no downsample after the last).
            if i < len(channel_mults) - 1:
                self.enc_down.append(Downsample(ch_out))
        # Middle
        mid_ch = ch_in
        self.mid_block1 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay, **rk)
        self.mid_attn = SelfAttnBlock(mid_ch, n_heads=4)
        self.mid_block2 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay, **rk)
        # Decoder
        self.dec_up = nn.ModuleList()
        self.dec_skip_proj = nn.ModuleList()
        self.dec = nn.ModuleList()
        # Walk the levels in reverse; each level concatenates the matching
        # encoder skip, then projects back down with a 1x1 conv.
        for i in range(len(channel_mults) - 1, -1, -1):
            ch_out = base_channels * channel_mults[i]
            skip_ch = enc_channels.pop()
            self.dec_skip_proj.append(nn.Conv2d(ch_in + skip_ch, ch_out, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(ch_out, cond_dim),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch_in = ch_out
            if i > 0:
                self.dec_up.append(Upsample(ch_out))
        self.out_norm = nn.GroupNorm(8, ch_in)
        self.out_conv = nn.Conv2d(ch_in, in_channels, 3, padding=1)
        # Zero-init output so the network predicts zero velocity at start.
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, t, class_labels):
        """Return the predicted velocity for x (B,C,H,W) at times t (B,)."""
        # One conditioning vector shared by every conditioned block.
        cond = self.time_emb(t) + self.class_emb(class_labels)
        h = self.in_conv(x)
        skips = [h]
        for i in range(len(self.channel_mults)):
            for block in self.enc[i]:
                if isinstance(block, ConvBlock):
                    h = block(h, cond)
                elif isinstance(block, nn.Sequential):
                    # 1x1 channel projection, then the conditioned ConvBlock.
                    h = block[0](h)
                    h = block[1](h, cond)
                else:
                    h = block(h)
            skips.append(h)
            if i < len(self.enc_down):
                h = self.enc_down[i](h)
        h = self.mid_block1(h, cond)
        h = self.mid_attn(h)
        h = self.mid_block2(h, cond)
        # NOTE(review): skips holds len(channel_mults)+1 entries but only
        # len(channel_mults) are popped — the in_conv skip is never consumed.
        for i in range(len(self.channel_mults)):
            skip = skips.pop()
            if i > 0:
                h = self.dec_up[i - 1](h)
            h = torch.cat([h, skip], dim=1)
            h = self.dec_skip_proj[i](h)
            for block in self.dec[i]:
                h = block(h, cond)
        h = self.out_norm(h)
        h = F.silu(h)
        return self.out_conv(h)
# ══════════════════════════════════════════════════════════════════
# HUGGINGFACE PRETRAINED MODEL WRAPPER
# ══════════════════════════════════════════════════════════════════
class FlowMatchRelayModel(PreTrainedModel):
    """
    HuggingFace-compatible wrapper for flow matching with constellation relay.
    Load:
        model = AutoModel.from_pretrained(
            "AbstractPhil/geolip-diffusion-proto", trust_remote_code=True)
    Generate:
        images = model.sample(n_samples=8, class_label=3)
    """
    config_class = FlowMatchRelayConfig
    _tied_weights_keys = []
    _keys_to_ignore_on_load_missing = []
    _keys_to_ignore_on_load_unexpected = []
    _no_split_modules = []
    supports_gradient_checkpointing = False

    def __init__(self, config):
        super().__init__(config)
        self.unet = FlowMatchUNet(config)
        self.post_init()

    def _init_weights(self, module):
        """No-op — weights loaded from checkpoint or already initialized."""
        pass

    def forward(self, x, t, class_labels):
        """
        Predict velocity field for flow matching.
        Args:
            x: (B, 3, H, W) noisy images
            t: (B,) timesteps in [0, 1]
            class_labels: (B,) integer class labels
        Returns:
            v_pred: (B, 3, H, W) predicted velocity
        """
        return self.unet(x, t, class_labels)

    @torch.no_grad()
    def sample(self, n_samples=8, n_steps=None, class_label=None, device=None):
        """
        Generate images via Euler ODE integration.
        Args:
            n_samples: number of images to generate
            n_steps: ODE integration steps (default from config)
            class_label: optional class conditioning (0-9 for CIFAR-10)
            device: target device
        Returns:
            images: (n_samples, 3, 32, 32) in [0, 1]
        """
        if device is None:
            device = next(self.parameters()).device
        if n_steps is None:
            n_steps = self.config.n_sample_steps
        self.eval()
        # Start from pure Gaussian noise at t = 1.
        x = torch.randn(n_samples, self.config.in_channels,
                        self.config.image_size, self.config.image_size,
                        device=device)
        if class_label is not None:
            labels = torch.full((n_samples,), class_label,
                                dtype=torch.long, device=device)
        else:
            # Unconditioned request: draw random class labels.
            labels = torch.randint(0, self.config.n_classes,
                                   (n_samples,), device=device)
        dt = 1.0 / n_steps
        # Euler integration from t=1 (noise) down to t=0 (data). The update
        # x <- x - v*dt assumes v points from data toward noise as t grows —
        # NOTE(review): confirm against the training-time velocity convention.
        for step in range(n_steps):
            t_val = 1.0 - step * dt
            t = torch.full((n_samples,), t_val, device=device)
            v = self.unet(x, t, labels)
            x = x - v * dt
        # [-1, 1] → [0, 1]
        return (x.clamp(-1, 1) + 1) / 2

    def get_relay_diagnostics(self):
        """Report constellation relay drift and gate values.

        Returns a dict keyed by module path with mean anchor drift (radians
        and degrees) and mean sigmoid gate value per ConstellationRelay.
        """
        diagnostics = {}
        for name, module in self.named_modules():
            if isinstance(module, ConstellationRelay):
                drift = module.drift().mean().item()
                gate = module.gates.sigmoid().mean().item()
                diagnostics[name] = {
                    'drift_rad': drift,
                    'drift_deg': math.degrees(drift),
                    'gate': gate,
                }
        return diagnostics