|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from einops import rearrange |
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Residual block built from two 3x3 conv + GroupNorm stages.

    The main path is Conv2d -> GroupNorm -> GELU -> Conv2d -> GroupNorm.
    Both convolutions use kernel 3, stride 1, padding 1, so the spatial size
    of the input is preserved. The shortcut is the identity when ``in_ch``
    equals ``out_ch`` and a 1x1 convolution otherwise. No activation is
    applied after the residual sum.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels produced by the block.
        num_groups: Number of groups for both GroupNorm layers. Defaults to
            8, the value that was previously hard-coded; pass a divisor of
            ``out_ch`` when ``out_ch`` is not a multiple of 8.

    Raises:
        ValueError: Raised by ``nn.GroupNorm`` when ``out_ch`` is not
            divisible by ``num_groups``.
    """

    def __init__(self, in_ch, out_ch, num_groups=8):
        super().__init__()
        # Attribute names (`net`, `skip`) and the layer order inside the
        # Sequential are kept stable so state_dict keys do not change.
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_groups, out_ch),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_groups, out_ch),
        )
        # A 1x1 projection is only needed when the channel count changes;
        # otherwise the input can be added to the main path unchanged.
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        """Return ``skip(x) + net(x)``.

        Args:
            x: Input tensor accepted by ``nn.Conv2d`` with ``in_ch`` channels
                (presumably ``(N, in_ch, H, W)`` - confirm against callers).

        Returns:
            Tensor with ``out_ch`` channels and the same spatial size as
            ``x``.
        """
        shortcut = self.skip(x)
        return shortcut + self.net(x)
|
|
|
|
|
class SimpleUNet(nn.Module):
    """Simple U-Net style architecture.

    Two residual encoder stages, each followed by 2x average pooling, a
    residual bottleneck, and two residual decoder stages that upsample with
    nearest-neighbour interpolation and concatenate the matching encoder
    output along the channel axis. A final 3x3 convolution maps back to
    ``in_ch`` channels.

    For a real trillion-scale model, replace with an attention-augmented
    UNet that supports cross-attention.

    Args:
        in_ch: Channels of the input and of the returned tensor.
        base_channels: Width of the first encoder stage; the second stage
            and the bottleneck use twice this width.
        cond_dim: Size of the optional conditioning vector. When falsy
            (``None`` or ``0``) no conditioning projection is created.
    """

    def __init__(self, in_ch=4, base_channels=128, cond_dim=None):
        super().__init__()
        wide = base_channels * 2

        # Submodules are registered in the same order as before so that
        # parameter initialisation order and state_dict layout are unchanged.
        self.down1 = ResidualBlock(in_ch, base_channels)
        self.pool = nn.AvgPool2d(2)
        self.down2 = ResidualBlock(base_channels, wide)
        self.mid = ResidualBlock(wide, wide)
        # Decoder inputs are the upsampled features concatenated with the
        # encoder skip tensor, hence the summed channel counts.
        self.up2 = ResidualBlock(wide + wide, base_channels)
        self.up1 = ResidualBlock(base_channels + base_channels, base_channels)
        self.out = nn.Conv2d(base_channels, in_ch, 3, 1, 1)

        # The conditioning vector is projected to the bottleneck width so it
        # can be added to the bottleneck feature map.
        if cond_dim:
            self.cond_proj = nn.Linear(cond_dim, wide)
        else:
            self.cond_proj = None

    def forward(self, x, cond=None):
        """Run the U-Net on ``x``, optionally conditioned on ``cond``.

        Args:
            x: Input feature map with ``in_ch`` channels (presumably
                ``(N, in_ch, H, W)`` - confirm against callers). Because the
                encoder pools by 2 twice, H and W presumably need to be at
                least 4 - TODO confirm.
            cond: Optional conditioning tensor whose last dimension is
                ``cond_dim`` (assumed ``(N, cond_dim)`` - confirm). It is
                ignored when the model was built without ``cond_dim``.

        Returns:
            Tensor with ``in_ch`` channels and the same spatial size as
            ``x``.
        """
        d1 = self.down1(x)
        d2 = self.down2(self.pool(d1))
        m = self.mid(self.pool(d2))

        if self.cond_proj is not None and cond is not None:
            # Append two singleton axes so the projection broadcasts over
            # the spatial dimensions of the bottleneck.
            c = self.cond_proj(cond).unsqueeze(-1).unsqueeze(-1)
            m = m + c

        # Upsample to the exact spatial size of the skip tensor instead of a
        # fixed 2x factor: AvgPool2d(2) floors odd sizes, so a fixed factor
        # made torch.cat fail whenever H or W was not divisible by 4. For
        # sizes divisible by 4 the result is identical to scale_factor=2.
        u2 = nn.functional.interpolate(m, size=d2.shape[-2:], mode='nearest')
        u2 = self.up2(torch.cat([u2, d2], dim=1))

        u1 = nn.functional.interpolate(u2, size=d1.shape[-2:], mode='nearest')
        u1 = self.up1(torch.cat([u1, d1], dim=1))

        return self.out(u1)