MalikAyaanAhmed1123's picture
Create unet.py
c9625c9 verified
# src/model/unet.py
import torch
import torch.nn as nn
from einops import rearrange
class ResidualBlock(nn.Module):
    """Two 3x3 conv + GroupNorm layers wrapped in an additive skip connection.

    Computes ``skip(x) + net(x)``, where ``net`` is
    Conv3x3 -> GroupNorm -> GELU -> Conv3x3 -> GroupNorm and ``skip`` is a
    1x1 convolution when the channel count changes, otherwise the identity.
    Spatial size is preserved (3x3 kernels, stride 1, padding 1).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        num_groups: Number of GroupNorm groups. ``out_ch`` must be divisible
            by it (GroupNorm raises ``ValueError`` otherwise). Defaults to 8,
            the value previously hard-coded here.
    """

    def __init__(self, in_ch, out_ch, num_groups=8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, 1, 1),
            nn.GroupNorm(num_groups, out_ch),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, 3, 1, 1),
            nn.GroupNorm(num_groups, out_ch),
        )
        # A 1x1 projection is only needed when the residual branch changes
        # the channel count; otherwise the input can be added as-is.
        if in_ch != out_ch:
            self.skip = nn.Conv2d(in_ch, out_ch, 1)
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        """Return ``skip(x) + net(x)``: ``out_ch`` channels, same spatial size as ``x``."""
        return self.skip(x) + self.net(x)
class SimpleUNet(nn.Module):
    """Two-level U-Net built from ResidualBlock stages.

    Encoder: ``down1`` at full resolution, 2x2 average pool, ``down2`` at 1/2
    resolution, 2x2 average pool, ``mid`` at 1/4 resolution. Decoder: upsample,
    concatenate the matching encoder activation along channels, and refine
    with ``up2`` / ``up1``. A final 3x3 convolution maps back to ``in_ch``
    channels, so the output has the same shape as the input.

    This is a minimal baseline. For a large-scale model, replace it with an
    attention-augmented U-Net that supports cross-attention.

    Args:
        in_ch: Number of channels of both the input and the output.
        base_channels: Width of the first stage; the lower stages use twice
            this. ResidualBlock normalizes with GroupNorm using 8 groups, so
            this should be divisible by 8.
        cond_dim: Size of the optional conditioning vector. When falsy
            (``None`` or 0), no conditioning projector is created and any
            ``cond`` passed to ``forward`` is ignored.
    """

    def __init__(self, in_ch=4, base_channels=128, cond_dim=None):
        super().__init__()
        # Encoder / bottleneck.
        self.down1 = ResidualBlock(in_ch, base_channels)
        self.pool = nn.AvgPool2d(2)
        self.down2 = ResidualBlock(base_channels, base_channels * 2)
        self.mid = ResidualBlock(base_channels * 2, base_channels * 2)
        # Decoder: input widths include the concatenated skip activations.
        self.up2 = ResidualBlock(base_channels * 2 + base_channels * 2, base_channels)
        self.up1 = ResidualBlock(base_channels + base_channels, base_channels)
        self.out = nn.Conv2d(base_channels, in_ch, 3, 1, 1)
        # Optional conditioning projector onto the bottleneck width.
        if cond_dim:
            self.cond_proj = nn.Linear(cond_dim, base_channels * 2)
        else:
            self.cond_proj = None

    def forward(self, x, cond=None):
        """Run the U-Net.

        Args:
            x: Input feature map, presumably ``(B, in_ch, H, W)``. The encoder
                pools twice with a 2x2 window, so ``H`` and ``W`` must be at
                least 4. They do not need to be divisible by 4.
            cond: Optional conditioning tensor whose last dimension is
                ``cond_dim`` (e.g. ``(B, cond_dim)``). Ignored when the model
                was built without a conditioning projector.

        Returns:
            Tensor with ``in_ch`` channels at the same spatial size as ``x``.
        """
        d1 = self.down1(x)              # full resolution, C channels
        d2 = self.down2(self.pool(d1))  # 1/2 resolution (floored), 2C channels
        m = self.mid(self.pool(d2))     # 1/4 resolution (floored), 2C channels

        # Conditioning injection: project to 2C and broadcast-add over the
        # two trailing spatial dimensions.
        if self.cond_proj is not None and cond is not None:
            m = m + self.cond_proj(cond).unsqueeze(-1).unsqueeze(-1)

        # Upsample to the exact size of each skip activation instead of by a
        # fixed factor of 2. AvgPool2d floors odd sizes, so "x2" would not
        # recover the encoder resolution and torch.cat would fail. For sizes
        # divisible by 4 this is identical to scale_factor=2.
        u2 = nn.functional.interpolate(m, size=d2.shape[-2:], mode='nearest')
        u2 = self.up2(torch.cat([u2, d2], dim=1))
        u1 = nn.functional.interpolate(u2, size=d1.shape[-2:], mode='nearest')
        u1 = self.up1(torch.cat([u1, d1], dim=1))
        return self.out(u1)