|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
|
|
class PreactResBlock(nn.Sequential):
    """Pre-activation residual block.

    The sequential body is two (GroupNorm -> GELU -> 3x3 Conv2d) stages that
    keep the channel count and, thanks to ``padding=1``, the spatial size.
    ``forward`` adds the input back onto the body's output.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        # Target roughly 16 channels per group, but never fewer than one
        # group: ``dim // 16`` is 0 when dim < 16, and GroupNorm cannot be
        # built with zero groups.
        num_groups = max(1, dim // 16)
        super().__init__(
            nn.GroupNorm(num_groups, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.GroupNorm(num_groups, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the sequential body applied to ``x``."""
        return x + super().forward(x)
|
|
|
|
|
class UNetBlock(nn.Module):
    """One U-Net stage: optional upsample, skip merge, conv, two residual blocks.

    Args:
        input_dim: channels of the incoming feature map.
        output_dim: channels produced by the stage; falls back to
            ``input_dim`` when left as ``None``.
        scale_factor: values above 1 install an ``nn.Upsample`` applied to the
            input, values below 1 install an ``nn.Upsample`` applied to the
            output, and exactly 1 leaves both as identity.
    """

    def __init__(self, input_dim, output_dim=None, scale_factor=1.0):
        super().__init__()
        channels = input_dim if output_dim is None else output_dim

        self.pre_conv = nn.Conv2d(input_dim, channels, 3, padding=1)
        self.res_block1 = PreactResBlock(channels)
        self.res_block2 = PreactResBlock(channels)

        # Both resamplers start out as the same pass-through module; at most
        # one of them is swapped for a real resampler below.
        passthrough = nn.Identity()
        self.downsample = passthrough
        self.upsample = passthrough
        if scale_factor > 1:
            self.upsample = nn.Upsample(scale_factor=scale_factor)
        if scale_factor < 1:
            self.downsample = nn.Upsample(scale_factor=scale_factor)

    def forward(self, x, h=None):
        """Run the stage.

        Args:
            x: (b c h w), output of the previous stage
            h: (b c h w), optional skip tensor; must match ``x`` after upsampling
        Returns:
            o: (b c h w), stage output after ``self.downsample``
            s: (b c h w), the same features before ``self.downsample`` (skip)
        """
        feats = self.upsample(x)
        if h is not None:
            assert feats.shape == h.shape, f"{feats.shape} != {h.shape}"
            feats = feats + h
        skip = self.res_block2(self.res_block1(self.pre_conv(feats)))
        return self.downsample(skip), skip
|
|
|
|
|
class UNet(nn.Module): |
|
def __init__( |
|
self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2 |
|
): |
|
super().__init__() |
|
self.input_dim = input_dim |
|
self.output_dim = output_dim |
|
self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) |
|
self.encoder_blocks = nn.ModuleList( |
|
[ |
|
UNetBlock( |
|
input_dim=hidden_dim * 2**i, |
|
output_dim=hidden_dim * 2 ** (i + 1), |
|
scale_factor=0.5, |
|
) |
|
for i in range(num_blocks) |
|
] |
|
) |
|
self.middle_blocks = nn.ModuleList( |
|
[ |
|
UNetBlock(input_dim=hidden_dim * 2**num_blocks) |
|
for _ in range(num_middle_blocks) |
|
] |
|
) |
|
self.decoder_blocks = nn.ModuleList( |
|
[ |
|
UNetBlock( |
|
input_dim=hidden_dim * 2 ** (i + 1), |
|
output_dim=hidden_dim * 2**i, |
|
scale_factor=2, |
|
) |
|
for i in reversed(range(num_blocks)) |
|
] |
|
) |
|
self.head = nn.Sequential( |
|
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), |
|
nn.GELU(), |
|
nn.Conv2d(hidden_dim, output_dim, 1), |
|
) |
|
|
|
@property |
|
def scale_factor(self): |
|
return 2 ** len(self.encoder_blocks) |
|
|
|
def pad_to_fit(self, x): |
|
""" |
|
Args: |
|
x: (b c h w), input |
|
Returns: |
|
x: (b c h' w'), padded input |
|
""" |
|
hpad = (self.scale_factor - x.shape[2] % self.scale_factor) % self.scale_factor |
|
wpad = (self.scale_factor - x.shape[3] % self.scale_factor) % self.scale_factor |
|
return F.pad(x, (0, wpad, 0, hpad)) |
|
|
|
def forward(self, x): |
|
""" |
|
Args: |
|
x: (b c h w), input |
|
Returns: |
|
o: (b c h w), output |
|
""" |
|
shape = x.shape |
|
|
|
x = self.pad_to_fit(x) |
|
x = self.input_proj(x) |
|
|
|
s_list = [] |
|
for block in self.encoder_blocks: |
|
x, s = block(x) |
|
s_list.append(s) |
|
|
|
for block in self.middle_blocks: |
|
x, _ = block(x) |
|
|
|
for block, s in zip(self.decoder_blocks, reversed(s_list)): |
|
x, _ = block(x, s) |
|
|
|
x = self.head(x) |
|
x = x[..., : shape[2], : shape[3]] |
|
|
|
return x |
|
|
|
def test(self, shape=(3, 512, 256)): |
|
import ptflops |
|
|
|
macs, params = ptflops.get_model_complexity_info( |
|
self, |
|
shape, |
|
as_strings=True, |
|
print_per_layer_stat=True, |
|
verbose=True, |
|
) |
|
|
|
print(f"macs: {macs}") |
|
print(f"params: {params}") |
|
|
|
|
|
def main():
    """Build a 3-channel-in, 3-channel-out UNet and print its complexity report."""
    UNet(3, 3).test()
|
|
|
|
|
# Only run the complexity report when this file is executed as a script,
# so importing the module stays free of side effects.
if __name__ == "__main__":
    main()
|
|