| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
| def _choose_gn_groups(channels: int, max_groups: int = 8) -> int: |
| for g in range(min(max_groups, channels), 0, -1): |
| if channels % g == 0: |
| return g |
| return 1 |
|
|
|
|
| class _ChannelGate(nn.Module): |
| def __init__(self, channels: int, reduction: int = 4) -> None: |
| super().__init__() |
| hidden = max(channels // reduction, 8) |
| self.pool = nn.AdaptiveAvgPool3d(1) |
| self.fc1 = nn.Conv3d(channels, hidden, kernel_size=1, bias=True) |
| self.act = nn.GELU() |
| self.fc2 = nn.Conv3d(hidden, channels, kernel_size=1, bias=True) |
| self.gate = nn.Sigmoid() |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| s = self.pool(x) |
| s = self.fc1(s) |
| s = self.act(s) |
| s = self.fc2(s) |
| return x * self.gate(s) |
|
|
|
|
class _FastHyperBlock(nn.Module):
    """Residual block that expands the receptive field by one k=3 stage.

    The input is projected to ``mid_dim`` channels, then processed by three
    parallel branches whose outputs are summed:

    - a spatial depthwise conv with kernel (1, 3, 3),
    - a temporal depthwise conv with kernel (3, 1, 1),
    - a grouped full 3D conv with kernel (3, 3, 3).

    The sum is projected back to ``channels``, channel-gated, optionally
    dropped out, and added to the block input (residual connection).
    """

    def __init__(
        self,
        channels: int,
        mid_dim: int,
        mix_groups: int = 6,
        dropout_p: float = 0.02,
        gate_reduction: int = 4,
    ) -> None:
        super().__init__()
        norm_groups_in = _choose_gn_groups(channels)
        norm_groups_mid = _choose_gn_groups(mid_dim)
        # Largest valid group count <= requested that divides mid_dim
        # (Conv3d requires in_channels % groups == 0); g=1 always works.
        cap = max(1, min(mix_groups, mid_dim))
        mix_groups = next(g for g in range(cap, 0, -1) if mid_dim % g == 0)

        # NOTE: module creation order matches the original so parameter
        # initialization consumes the RNG identically.
        self.pre = nn.Sequential(
            nn.GroupNorm(norm_groups_in, channels),
            nn.Conv3d(channels, mid_dim, kernel_size=1, bias=True),
            nn.GELU(),
        )
        self.spatial = nn.Sequential(
            nn.Conv3d(
                mid_dim,
                mid_dim,
                kernel_size=(1, 3, 3),
                padding=(0, 1, 1),
                groups=mid_dim,
                bias=True,
            ),
            nn.GELU(),
        )
        self.temporal = nn.Sequential(
            nn.Conv3d(
                mid_dim,
                mid_dim,
                kernel_size=(3, 1, 1),
                padding=(1, 0, 0),
                groups=mid_dim,
                bias=True,
            ),
            nn.GELU(),
        )
        self.mixed = nn.Sequential(
            nn.GroupNorm(norm_groups_mid, mid_dim),
            nn.Conv3d(
                mid_dim,
                mid_dim,
                kernel_size=3,
                padding=1,
                groups=mix_groups,
                bias=True,
            ),
            nn.GELU(),
        )
        self.fuse = nn.Sequential(
            nn.Conv3d(mid_dim, channels, kernel_size=1, bias=True),
            nn.GELU(),
        )
        self.gate = _ChannelGate(channels, reduction=gate_reduction)
        self.dropout = nn.Dropout3d(dropout_p) if dropout_p > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three-branch stage and return ``x`` plus the update."""
        expanded = self.pre(x)
        branches = (
            self.spatial(expanded) + self.temporal(expanded) + self.mixed(expanded)
        )
        update = self.dropout(self.gate(self.fuse(branches)))
        return x + update
|
|
|
|
class PredecoderFastHyperRF13V1(nn.Module):
    """Faster-stronger candidate for model 6 under the public Ising-Decoding API.

    A conv stem lifts the input to ``hidden_dim`` channels, ``num_blocks``
    residual ``_FastHyperBlock`` stages expand the receptive field, and a
    1x1x1 head projects back to ``out_channels``. Spatial/temporal extents
    are preserved throughout.

    Input / output shape:
        (B, 4, T, D, D) -> (B, 4, T, D, D)
    """

    def __init__(
        self,
        input_channels: int = 4,
        out_channels: int = 4,
        hidden_dim: int = 96,
        mid_dim: int = 144,
        mix_groups: int = 6,
        num_blocks: int = 5,
        stem_kernel_size: int = 3,
        dropout_p: float = 0.02,
        gate_reduction: int = 4,
        **_: Any,  # absorb extra config keys from the shared model factory
    ) -> None:
        super().__init__()
        # Same-size padding for the (odd) stem kernel.
        stem_padding = stem_kernel_size // 2
        norm_groups = _choose_gn_groups(hidden_dim)
        self.stem = nn.Sequential(
            nn.Conv3d(
                input_channels,
                hidden_dim,
                kernel_size=stem_kernel_size,
                padding=stem_padding,
                bias=True,
            ),
            nn.GroupNorm(norm_groups, hidden_dim),
            nn.GELU(),
        )
        stages = [
            _FastHyperBlock(
                channels=hidden_dim,
                mid_dim=mid_dim,
                mix_groups=mix_groups,
                dropout_p=dropout_p,
                gate_reduction=gate_reduction,
            )
            for _ in range(num_blocks)
        ]
        self.blocks = nn.Sequential(*stages)
        self.head = nn.Sequential(
            nn.GroupNorm(norm_groups, hidden_dim),
            nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1, bias=True),
            nn.GELU(),
            nn.Conv3d(hidden_dim, out_channels, kernel_size=1, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run stem -> residual blocks -> head; shape is preserved except channels."""
        return self.head(self.blocks(self.stem(x)))
|
|