Spaces:
Runtime error
Runtime error
File size: 2,904 Bytes
a2dba58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
from typing import List, Optional, Sequence, Tuple, Union

import torch
from einops.layers.torch import Rearrange
from torch import Tensor, nn

from models.unet_model import Unet, default


class GlobalCL(Unet):
    """U-Net encoder followed by a dense projection head ``g1``.

    ``forward`` runs the inherited ``init_conv``, every stage of ``downs`` and
    the bottleneck (``mid_block1``, ``mid_attn``, ``mid_block2``). It then
    flattens the bottleneck feature map and projects it through ``g1``. The
    inherited decoder (``ups``) is not used here.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]],
        dim: int = 64,
        init_dim: Optional[int] = None,
        dim_mults: Sequence[int] = (1, 2, 4, 8),
        *,
        g_emb: int = 1024,
        g_out: int = 128,
        **kwargs,
    ):
        """Build the U-Net backbone and the global projection head.

        Args:
            img_size: Spatial size of the input images. Either one int (square
                images) or a ``(height, width)`` pair. Used only to size the
                first ``Linear`` layer of ``g1``.
            dim: Base channel width, forwarded to ``Unet``.
            init_dim: Channel count after ``init_conv``. Defaults to ``dim``
                and is forwarded to ``Unet``.
            dim_mults: Per-stage channel multipliers, forwarded to ``Unet``.
            g_emb: Hidden width of ``g1`` (1024 in the original, "from the
                paper").
            g_out: Output width of ``g1`` (128 in the original, "from the
                paper").
            **kwargs: Any further ``Unet`` constructor arguments.
        """
        init_dim = default(init_dim, dim)
        # The architecture arguments must reach the backbone. Otherwise Unet
        # is built with its own defaults while the head below is sized for
        # the values given here.
        super().__init__(dim=dim, init_dim=init_dim, dim_mults=dim_mults, **kwargs)

        dims = [init_dim, *(dim * mult for mult in dim_mults)]
        bottleneck_channels = dims[-1]

        if isinstance(img_size, (tuple, list)):
            height, width = img_size
        else:
            height = width = img_size
        # Assumes the first len(dims) - 2 encoder stages each halve the
        # resolution and the last one keeps it (matches the original sizing
        # logic; confirm against Unet's Downsample).
        num_downsamples = len(dims) - 2
        mid_height = self._downsampled_size(height, num_downsamples)
        mid_width = self._downsampled_size(width, num_downsamples)

        self.g1 = nn.Sequential(
            Rearrange('b c h w -> b (c h w)'),
            nn.Linear(bottleneck_channels * mid_height * mid_width, g_emb, bias=False),
            nn.ReLU(),
            nn.Linear(g_emb, g_out, bias=False),
        )

    @staticmethod
    def _downsampled_size(size, num_downsamples: int):
        """Apply ``size -> floor((size - 1) / 2) + 1`` (ceil(size / 2)) repeatedly."""
        for _ in range(num_downsamples):
            size = (int(size) - 1) // 2 + 1
        return size

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` and return its global projection.

        Args:
            x: Input batch. Presumably (batch, channels, H, W), with H and W
                matching the ``img_size`` given at construction; TODO confirm.

        Returns:
            Output of ``g1``. With the head built in ``__init__`` this has
            shape (batch, g_out).
        """
        x = self.init_conv(x)
        t = None  # the U-Net blocks are called without a time embedding
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
        return self.g1(x)


class LocalCL(Unet):
    """U-Net encoder plus the first ``self.l`` decoder stages, then head ``g2``.

    ``forward`` runs the inherited ``init_conv``, every stage of ``downs`` and
    the bottleneck. It then runs the first ``self.l`` stages of ``ups``, which
    consume the matching skip connections. The resulting feature map goes
    through ``g2``, a small head of two 1x1 convolutions.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]],
        dim: int = 64,
        init_dim: Optional[int] = None,
        dim_mults: Sequence[int] = (1, 2, 4, 8),
        *,
        num_decoder_blocks: int = 2,
        **kwargs,
    ):
        """Build the U-Net backbone and the local projection head.

        Args:
            img_size: Not used by this class.
            dim: Base channel width, forwarded to ``Unet``.
            init_dim: Channel count after ``init_conv``. Defaults to ``dim``
                and is forwarded to ``Unet``.
            dim_mults: Per-stage channel multipliers, forwarded to ``Unet``.
            num_decoder_blocks: How many decoder (``ups``) stages to run before
                ``g2``. Stored as ``self.l``; default 2 as in the original.
            **kwargs: Any further ``Unet`` constructor arguments.

        Raises:
            ValueError: If ``num_decoder_blocks`` is outside
                ``[0, len(dim_mults)]``.
        """
        init_dim = default(init_dim, dim)
        # The architecture arguments must reach the backbone, otherwise the
        # channel count chosen for g2 below may not match what Unet produces.
        super().__init__(dim=dim, init_dim=init_dim, dim_mults=dim_mults, **kwargs)

        dims = [init_dim, *(dim * mult for mult in dim_mults)]
        if not 0 <= num_decoder_blocks <= len(dim_mults):
            raise ValueError(
                f"num_decoder_blocks must be in [0, {len(dim_mults)}], "
                f"got {num_decoder_blocks}"
            )
        self.l = num_decoder_blocks
        # Channel count after `self.l` decoder stages. Assumes each Unet up
        # stage maps back to the previous encoder width; confirm against Unet.
        feat_channels = dims[-self.l - 1]
        self.g2 = nn.Sequential(
            nn.Conv2d(feat_channels, feat_channels, 1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(feat_channels),
            nn.Conv2d(feat_channels, feat_channels, 1, bias=False),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x``, partially decode it, and return the ``g2`` projection.

        Args:
            x: Input batch. Presumably (batch, channels, H, W); TODO confirm.

        Returns:
            Output of ``g2`` applied to the feature map after ``self.l``
            decoder stages.
        """
        x = self.init_conv(x)
        t = None  # the U-Net blocks are called without a time embedding
        skips = []
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            skips.append(x)
            x = block2(x, t)
            x = attn(x)
            skips.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
        # Skips are popped in reverse push order, so the decoder stages that
        # run see the deepest encoder features first.
        for block1, block2, attn, upsample in self.ups[:self.l]:
            x = torch.cat((x, skips.pop()), dim=1)
            x = block1(x, t)
            x = torch.cat((x, skips.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)
        return self.g2(x)
|