File size: 4,183 Bytes
32b2aaa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import torch.nn.functional as F
from torch import nn
def _group_norm_groups(dim):
    """Return a valid ``nn.GroupNorm`` group count for ``dim`` channels.

    Aims for about 16 channels per group (``dim // 16`` groups), but never
    returns 0 and always returns a divisor of ``dim``. ``nn.GroupNorm``
    requires both. For any ``dim`` that is a multiple of 16 the result is
    exactly ``dim // 16``, so existing configurations are unchanged.
    """
    groups = max(dim // 16, 1)
    # Step down to the nearest divisor so channels split evenly into groups.
    while groups > 1 and dim % groups:
        groups -= 1
    return groups


class PreactResBlock(nn.Sequential):
    """Pre-activation residual block: ``x + conv(gelu(norm(conv(gelu(norm(x))))))``.

    Both 3x3 convolutions map ``dim`` -> ``dim`` channels with ``padding=1``,
    so the identity shortcut can be added without projection.

    Args:
        dim: number of input/output channels. Values below 16, or values where
            ``dim // 16`` does not divide ``dim``, are supported; the group count
            falls back to a valid divisor (see ``_group_norm_groups``).
    """

    def __init__(self, dim):
        groups = _group_norm_groups(dim)
        # Layer order/indices (0..5) are unchanged so saved state_dicts still load.
        super().__init__(
            nn.GroupNorm(groups, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.GroupNorm(groups, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the six-layer pre-activation stack."""
        return x + super().forward(x)
class UNetBlock(nn.Module):
    """One U-Net stage: [upsample] -> [+ skip] -> conv -> 2 res blocks -> [downsample].

    Args:
        input_dim: channels of the incoming tensor.
        output_dim: channels produced by the stage; ``None`` keeps ``input_dim``.
        scale_factor: ``> 1`` upsamples the input before processing; ``< 1``
            downsamples the result after processing; otherwise resolution is kept.
            Resampling uses ``nn.Upsample`` with its default mode.
    """

    def __init__(self, input_dim, output_dim=None, scale_factor=1.0):
        super().__init__()
        out_channels = input_dim if output_dim is None else output_dim

        self.pre_conv = nn.Conv2d(input_dim, out_channels, 3, padding=1)
        self.res_block1 = PreactResBlock(out_channels)
        self.res_block2 = PreactResBlock(out_channels)

        # Both resampling slots start as one shared no-op; at most one is replaced.
        passthrough = nn.Identity()
        self.downsample = passthrough
        self.upsample = passthrough
        if scale_factor < 1:
            self.downsample = nn.Upsample(scale_factor=scale_factor)
        elif scale_factor > 1:
            self.upsample = nn.Upsample(scale_factor=scale_factor)

    def forward(self, x, h=None):
        """
        Args:
            x: (b c h w), output of the previous stage
            h: (b c h w), optional skip tensor, added after upsampling
        Returns:
            o: (b c h w), result after the optional downsample
            s: (b c h w), result before downsampling (used as a skip)
        """
        y = self.upsample(x)
        if h is not None:
            assert y.shape == h.shape, f"{y.shape} != {h.shape}"
            y = y + h
        for layer in (self.pre_conv, self.res_block1, self.res_block2):
            y = layer(y)
        return self.downsample(y), y
class UNet(nn.Module):
    """Residual U-Net mapping ``input_dim`` channels to ``output_dim`` channels.

    The encoder halves resolution ``num_blocks`` times while doubling width.
    The decoder reverses this and adds (does not concatenate) the matching
    encoder activations.

    Args:
        input_dim: channels of the input tensor.
        output_dim: channels of the output tensor.
        hidden_dim: width after the input projection; width at depth d is
            ``hidden_dim * 2**d``.
        num_blocks: number of encoder (and decoder) stages.
        num_middle_blocks: number of same-resolution stages at the bottleneck.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        def width(depth):
            # Channel count at a given depth of the U; doubles per level.
            return hidden_dim * 2**depth

        self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.encoder_blocks = nn.ModuleList(
            UNetBlock(input_dim=width(d), output_dim=width(d + 1), scale_factor=0.5)
            for d in range(num_blocks)
        )
        self.middle_blocks = nn.ModuleList(
            UNetBlock(input_dim=width(num_blocks)) for _ in range(num_middle_blocks)
        )
        # Deepest level first, mirroring the encoder.
        self.decoder_blocks = nn.ModuleList(
            UNetBlock(input_dim=width(d + 1), output_dim=width(d), scale_factor=2)
            for d in range(num_blocks - 1, -1, -1)
        )
        self.head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, output_dim, 1),
        )

    @property
    def scale_factor(self):
        """Total encoder downsampling factor, i.e. ``2 ** len(encoder_blocks)``."""
        return 1 << len(self.encoder_blocks)

    def pad_to_fit(self, x):
        """
        Zero-pad the bottom/right so height and width divide ``scale_factor``.

        Args:
            x: (b c h w), input
        Returns:
            x: (b c h' w'), padded input
        """
        factor = self.scale_factor
        rows, cols = x.shape[2], x.shape[3]
        # -n % k is the amount needed to reach the next multiple of k (0 if aligned).
        return F.pad(x, (0, -cols % factor, 0, -rows % factor))

    def forward(self, x):
        """
        Args:
            x: (b c h w), input
        Returns:
            o: (b c h w), output, cropped back to the input's spatial size
        """
        rows, cols = x.shape[2], x.shape[3]
        y = self.input_proj(self.pad_to_fit(x))

        skips = []
        for encoder in self.encoder_blocks:
            y, skip = encoder(y)
            skips.append(skip)

        for middle in self.middle_blocks:
            y, _ = middle(y)

        for decoder, skip in zip(self.decoder_blocks, skips[::-1]):
            y, _ = decoder(y, skip)

        return self.head(y)[..., :rows, :cols]

    def test(self, shape=(3, 512, 256)):
        """Print per-layer and total MACs/params via ``ptflops`` for a (c h w) input."""
        import ptflops

        macs, params = ptflops.get_model_complexity_info(
            self, shape, as_strings=True, print_per_layer_stat=True, verbose=True
        )
        for label, value in (("macs", macs), ("params", params)):
            print(f"{label}: {value}")
def main():
    """Print the complexity report of a default UNet (3 input, 3 output channels)."""
    UNet(input_dim=3, output_dim=3).test()


if __name__ == "__main__":
    main()
|