| """ |
| Joint Embedding Predictive Architecture (JEPA) for PDE dynamics. |
| |
| Spatial JEPA: encoder produces spatial feature maps, predictor operates |
| on spatial features, loss computed on spatial latent representations. |
| Prevents collapse via VICReg regularization. |
| """ |
| import copy |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
|
|
class ConvBlock(nn.Module):
    """Two-convolution residual block with BatchNorm and GELU.

    The first convolution applies ``stride`` (optional downsampling); the
    shortcut is an identity when shape is preserved, otherwise a 1x1
    projection with its own BatchNorm.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # Shortcut: identity when channels and resolution match, else 1x1 conv.
        if in_ch == out_ch and stride == 1:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        out = self.conv1(x)
        out = F.gelu(self.bn1(out))
        out = self.bn2(self.conv2(out))
        # Residual add, then the final nonlinearity.
        return F.gelu(out + self.skip(x))
|
|
|
|
| |
| |
| |
|
|
|
|
class SpatialEncoder(nn.Module):
    """ResNet-style encoder outputting spatial latent maps.

    A stride-1 stem followed by three stride-2 residual stages, so the
    spatial resolution is reduced by a factor of 8 overall.

    Input: [B, C_in, H, W]
    Output: [B, lat_ch, H/8, W/8]
    """

    def __init__(self, in_channels, latent_channels=128, base_ch=32):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, base_ch, 3, padding=1),
            nn.BatchNorm2d(base_ch),
            nn.GELU(),
        )
        # Channel widths: base -> 2x -> 4x -> latent, halving resolution each stage.
        widths = [base_ch, base_ch * 2, base_ch * 4, latent_channels]
        self.layer1 = ConvBlock(widths[0], widths[1], stride=2)
        self.layer2 = ConvBlock(widths[1], widths[2], stride=2)
        self.layer3 = ConvBlock(widths[2], widths[3], stride=2)

    def forward(self, x):
        h = self.stem(x)
        for stage in (self.layer1, self.layer2, self.layer3):
            h = stage(h)
        return h
|
|
|
|
| |
| |
| |
|
|
|
|
class SpatialPredictor(nn.Module):
    """Lightweight CNN predictor on spatial latent maps.

    Three 3x3 convolutions (latent -> hidden -> hidden -> latent) with
    BatchNorm + GELU between them; spatial size is preserved throughout.

    Input/Output: [B, lat_ch, H', W']
    """

    def __init__(self, latent_channels=128, hidden_channels=256):
        super().__init__()
        layers = [
            nn.Conv2d(latent_channels, hidden_channels, 3, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.GELU(),
            nn.Conv2d(hidden_channels, hidden_channels, 3, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.GELU(),
            # No norm/activation on the output so the predicted latent is unconstrained.
            nn.Conv2d(hidden_channels, latent_channels, 3, padding=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
|
|
|
| |
| |
| |
|
|
|
|
def vicreg_loss(z_pred, z_target, sim_w=25.0, var_w=25.0, cov_w=1.0):
    """VICReg loss on spatial features (flattened to [B, D]).

    Combines an invariance (MSE) term, a hinge on each dimension's
    standard deviation, and an off-diagonal covariance penalty.

    Args:
        z_pred: [B, D] predicted latent.
        z_target: [B, D] target latent (detached).
        sim_w, var_w, cov_w: loss weights.

    Returns:
        total loss, dict of components.
    """
    # Invariance: predicted latent should match the target latent.
    sim_loss = F.mse_loss(z_pred, z_target)

    # Variance: hinge each dimension's std away from zero (anti-collapse).
    eps = 1e-4
    std_pred = torch.sqrt(z_pred.var(dim=0) + eps)
    std_tgt = torch.sqrt(z_target.var(dim=0) + eps)
    var_loss = F.relu(1 - std_pred).mean() + F.relu(1 - std_tgt).mean()

    # Covariance: penalize off-diagonal covariance so dimensions decorrelate.
    batch, dim = z_pred.shape
    denom = max(batch - 1, 1)
    centered_p = z_pred - z_pred.mean(0)
    centered_t = z_target - z_target.mean(0)
    cov_p = (centered_p.T @ centered_p) / denom
    cov_t = (centered_t.T @ centered_t) / denom
    off_diag = ~torch.eye(dim, device=z_pred.device, dtype=torch.bool)
    cov_loss = cov_p[off_diag].pow(2).sum() / dim + cov_t[off_diag].pow(2).sum() / dim

    total = sim_w * sim_loss + var_w * var_loss + cov_w * cov_loss
    return total, {"sim": sim_loss.item(), "var": var_loss.item(), "cov": cov_loss.item()}
|
|
|
|
| |
| |
| |
|
|
|
|
class JEPA(nn.Module):
    """Spatial JEPA for PDE dynamics prediction.

    Online encoder + predictor learn to predict the target encoder's
    representation of the next frame. The target encoder is an EMA
    copy of the online encoder.

    Args:
        in_channels: number of input field channels.
        latent_channels: spatial latent feature map channels.
        base_ch: encoder base width.
        pred_hidden: predictor hidden channels.
        ema_decay: starting EMA decay.
    """

    def __init__(
        self,
        in_channels,
        latent_channels=128,
        base_ch=32,
        pred_hidden=256,
        ema_decay=0.996,
    ):
        super().__init__()
        self.online_encoder = SpatialEncoder(in_channels, latent_channels, base_ch)
        self.predictor = SpatialPredictor(latent_channels, pred_hidden)
        # Target starts as an exact copy; it is updated only via EMA, never SGD.
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.ema_decay = ema_decay

        for p in self.target_encoder.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update_target(self):
        """EMA update of target encoder.

        Parameters are interpolated: target <- decay * target + (1 - decay) * online.

        BUGFIX: buffers (the encoder's BatchNorm running_mean / running_var /
        num_batches_tracked) were previously never synchronized, so the
        target encoder's normalization statistics drifted from the online
        encoder's. They are now copied directly — the standard practice for
        EMA target networks — which is also correct for the integer
        num_batches_tracked buffer that cannot be fractionally interpolated.
        """
        for pt, po in zip(self.target_encoder.parameters(), self.online_encoder.parameters()):
            pt.data.lerp_(po.data, 1 - self.ema_decay)
        for bt, bo in zip(self.target_encoder.buffers(), self.online_encoder.buffers()):
            bt.data.copy_(bo.data)

    def set_ema_decay(self, decay):
        """Update EMA decay (e.g. cosine schedule from 0.996 to 1.0)."""
        self.ema_decay = decay

    def forward(self, x_input, x_target):
        """Encode the current frame and predict the target's latent.

        Args:
            x_input: current frame(s) [B, C, H, W]
            x_target: next frame(s) [B, C, H, W]

        Returns:
            z_pred: predicted spatial latent [B, lat_ch, H', W']
            z_target: target spatial latent [B, lat_ch, H', W'] (no grad)
        """
        z_online = self.online_encoder(x_input)
        z_pred = self.predictor(z_online)

        # The target branch never receives gradients.
        with torch.no_grad():
            z_target = self.target_encoder(x_target)

        return z_pred, z_target

    def compute_loss(self, x_input, x_target):
        """Full forward + loss computation.

        VICReg is computed on channel vectors after spatial averaging
        to keep the covariance matrix small (D = latent_channels).

        Returns:
            loss: scalar.
            metrics: dict.
        """
        z_pred, z_target = self(x_input, x_target)

        # Main JEPA objective: match the target latent at every spatial location.
        spatial_mse = F.mse_loss(z_pred, z_target.detach())

        # Global average pool -> [B, latent_channels] for VICReg.
        zp_avg = z_pred.mean(dim=(-2, -1))
        zt_avg = z_target.mean(dim=(-2, -1))

        vicreg, vicreg_m = vicreg_loss(zp_avg, zt_avg.detach())

        # Small weight: VICReg is an anti-collapse regularizer, not the main loss.
        loss = spatial_mse + 0.1 * vicreg
        metrics = {
            "sim": vicreg_m["sim"],
            "var": vicreg_m["var"],
            "cov": vicreg_m["cov"],
            "spatial_mse": spatial_mse.item(),
        }
        return loss, metrics
|
|