# TEDM-demo/models/global_local_cl.py
from models.unet_model import Unet, default
from torch import Tensor, nn
import torch
from typing import Optional, List
from einops.layers.torch import Rearrange
class GlobalCL(Unet):
    """Global contrastive head: Unet encoder plus an MLP projection (g_1) to a 128-d embedding."""

    def __init__(self,
                 img_size,
                 dim: int = 64,
                 init_dim: Optional[int] = None,
                 dim_mults: List[int] = [1, 2, 4, 8],
                 **kwargs):
        super().__init__(**kwargs)
        init_dim = default(init_dim, dim)
        # projection-head sizes from the paper
        g_emb = 1024
        g_out = 128
        # channel widths per resolution; the bottleneck has dims[-1] channels
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        mid_dim = dims[-1]
        # spatial size at the bottleneck: the loop assumes the encoder halves the
        # resolution len(dims) - 2 times (the final stage keeps the spatial size)
        mid_img_size = img_size
        for _ in range(len(dims) - 2):
            mid_img_size = int((mid_img_size - 1) / 2) + 1  # ceil(size / 2)
        # g_1: flatten the bottleneck feature map and project to the embedding
        self.g1 = nn.Sequential(
            Rearrange('b c h w -> b (c h w)'),
            nn.Linear(mid_dim * mid_img_size ** 2, g_emb, bias=False),
            nn.ReLU(),
            nn.Linear(g_emb, g_out, bias=False),
        )
    def forward(self, x: Tensor) -> Tensor:
        x = self.init_conv(x)
        t = None  # no time conditioning during contrastive pre-training
        # encoder path
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = downsample(x)
        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
        # project to the global embedding
        x = self.g1(x)
        return x
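
# Usage sketch (illustrative addition, not in the original file). GlobalCL is
# meant to return one embedding vector per image for a global contrastive loss.
# This assumes the base Unet in models.unet_model can be constructed from the
# forwarded **kwargs; GlobalCL uses its own `dim`/`dim_mults` only to size g_1,
# so they must agree with the base Unet's channel widths.
#
#     encoder = GlobalCL(img_size=128)            # hypothetical arguments
#     z = encoder(images)                         # images: (B, C, 128, 128)
#     # z: (B, 128), one global embedding per image
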
class LocalCL(Unet):
    """Local contrastive head: Unet encoder-decoder truncated after two
    upsampling stages, followed by a 1x1-conv projection (g_2)."""

    def __init__(self,
                 img_size,
                 dim: int = 64,
                 init_dim: Optional[int] = None,
                 dim_mults: List[int] = [1, 2, 4, 8],
                 **kwargs):
        super().__init__(**kwargs)
        init_dim = default(init_dim, dim)
        # channel widths per resolution, as in the paper
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        # decode only the first l upsampling stages
        self.l = 2
        mid_dim = dims[-self.l - 1]
        # g_2: small projection network with two 1x1 convolutions
        self.g2 = nn.Sequential(
            nn.Conv2d(mid_dim, mid_dim, 1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(mid_dim),
            nn.Conv2d(mid_dim, mid_dim, 1, bias=False),
        )
    def forward(self, x: Tensor) -> Tensor:
        x = self.init_conv(x)
        r = x.clone()  # unused in this truncated forward pass
        t = None  # no time conditioning during contrastive pre-training
        h = []
        # encoder path, storing skip connections
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)
        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
        # decode only the first self.l upsampling stages, consuming the most
        # recent skip connections first
        for block1, block2, attn, upsample in self.ups[:self.l]:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)
        # project each spatial location with g_2
        x = self.g2(x)
        return x
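

# Usage sketch (illustrative addition, not in the original file). LocalCL stops
# the decoder after self.l = 2 upsampling stages and projects every spatial
# location with g_2, yielding a dense feature map for a local (patch-wise)
# contrastive loss. Constructor arguments are assumed to mirror GlobalCL above.
#
#     encoder = LocalCL(img_size=128)             # hypothetical arguments
#     f = encoder(images)                         # images: (B, C, 128, 128)
#     # f: (B, mid_dim, H', W'), one projected feature per spatial location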