| """Flow Matching training utilities for IRIS.""" |
|
|
| import torch |
| import torch.nn.functional as F |
| import math |
| from typing import Optional |
|
|
# Latent scaling factor: multiplied into clean latents before noising in
# `flow_matching_loss` and divided out of samples at the end of `euler_sample`.
# Name suggests it is tuned for the DC-AE f32c32 autoencoder variant,
# presumably to normalize latents to ~unit variance — TODO confirm against
# the autoencoder's latent statistics.
DCAE_F32C32_SCALE = 0.41407
|
|
|
|
def sample_timesteps_logit_normal(batch_size, device, mean=0.0, std=1.0):
    """Draw flow-matching timesteps from a logit-normal distribution.

    A Gaussian sample is squashed through a sigmoid, which concentrates
    timesteps away from the endpoints; the result is clamped so no value
    is exactly 0 or 1.
    """
    gaussian = torch.normal(mean=mean, std=std, size=(batch_size,), device=device)
    timesteps = gaussian.sigmoid()
    return torch.clamp(timesteps, min=1e-5, max=1.0 - 1e-5)
|
|
|
|
def sample_timesteps_uniform(batch_size, device):
    """Draw flow-matching timesteps uniformly, clamped away from 0 and 1."""
    timesteps = torch.rand(batch_size, device=device)
    return torch.clamp(timesteps, min=1e-5, max=1.0 - 1e-5)
|
|
|
|
def rectified_flow_forward(z_0, t, noise=None):
    """Build the rectified-flow training pair for clean latents ``z_0``.

    Interpolates linearly between data and noise,
    ``z_t = t * noise + (1 - t) * z_0``, and returns ``(z_t, target)``
    where ``target = noise - z_0`` is the constant velocity the model
    should predict. If ``noise`` is omitted, standard Gaussian noise of
    the same shape as ``z_0`` is drawn.
    """
    if noise is None:
        noise = torch.randn_like(z_0)
    # Broadcast the per-sample timestep over (C, H, W).
    t_b = t.reshape(-1, 1, 1, 1)
    velocity_target = noise - z_0
    z_t = t_b * noise + (1.0 - t_b) * z_0
    return z_t, velocity_target
|
|
|
|
def flow_matching_loss(model, z_0, context, num_iterations=4, timestep_sampling="logit_normal", scale_factor=DCAE_F32C32_SCALE):
    """Compute the rectified-flow (flow matching) training loss.

    Args:
        model: Velocity-prediction network, called as
            ``model(z_t, t, context, num_iterations=...)``.
        z_0: Clean latents of shape (B, ...); scaled by ``scale_factor``
            before noising.
        context: Conditioning passed through to the model unchanged.
        num_iterations: Forwarded to the model unchanged.
        timestep_sampling: ``"logit_normal"`` or ``"uniform"``.
        scale_factor: Latent scaling applied before the forward process.

    Returns:
        Dict with ``"loss"`` (differentiable MSE between predicted and
        target velocity) and ``"flow_loss"`` (detached copy for logging).

    Raises:
        ValueError: If ``timestep_sampling`` names an unknown scheme.
    """
    B = z_0.shape[0]
    device = z_0.device
    z_0_scaled = z_0 * scale_factor
    # Explicit dispatch: previously any unknown sampler name silently fell
    # back to uniform sampling, which hid configuration typos.
    if timestep_sampling == "logit_normal":
        t = sample_timesteps_logit_normal(B, device)
    elif timestep_sampling == "uniform":
        t = sample_timesteps_uniform(B, device)
    else:
        raise ValueError(f"Unknown timestep_sampling: {timestep_sampling!r}")
    noise = torch.randn_like(z_0_scaled)
    z_t, target = rectified_flow_forward(z_0_scaled, t, noise)
    v_pred = model(z_t, t, context, num_iterations=num_iterations)
    flow_loss = F.mse_loss(v_pred, target)
    return {"loss": flow_loss, "flow_loss": flow_loss.detach()}
|
|
|
|
@torch.no_grad()
def euler_sample(model, noise, context, num_steps=20, num_iterations=4, cfg_scale=1.0, scale_factor=DCAE_F32C32_SCALE):
    """Integrate the learned velocity field from t=1 (noise) to t=0 via Euler steps.

    When ``cfg_scale`` > 1, applies classifier-free guidance by running the
    conditional and unconditional (zeroed-context) predictions in a single
    doubled batch. The final latents are divided by ``scale_factor`` to undo
    the scaling used during training.
    """
    batch = noise.shape[0]
    step_size = -1.0 / num_steps
    use_cfg = cfg_scale > 1.0
    latents = noise.clone()
    for step in range(num_steps):
        current_t = 1.0 - step / num_steps
        t = torch.full((batch,), current_t, device=noise.device, dtype=noise.dtype)
        if use_cfg:
            # Conditional half first, unconditional (zero-context) half second.
            stacked = model(
                torch.cat([latents, latents], dim=0),
                torch.cat([t, t], dim=0),
                torch.cat([context, torch.zeros_like(context)], dim=0),
                num_iterations=num_iterations,
            )
            v_cond, v_uncond = stacked.chunk(2, dim=0)
            velocity = v_uncond + cfg_scale * (v_cond - v_uncond)
        else:
            velocity = model(latents, t, context, num_iterations=num_iterations)
        latents = latents + velocity * step_size
    return latents / scale_factor
|
|