File size: 2,904 Bytes
a2dba58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from typing import List, Optional, Sequence

import torch
from einops.layers.torch import Rearrange
from torch import Tensor, nn

from models.unet_model import Unet, default


class GlobalCL(Unet):
    """Global contrastive-learning network built on the parent ``Unet``.

    Runs the Unet's ``init_conv``, contracting path (``self.downs``) and
    bottleneck (``mid_block1`` / ``mid_attn`` / ``mid_block2``) with no time
    embedding. It then flattens the bottleneck feature map and projects it
    through the two-layer MLP ``self.g1`` to one 128-dimensional vector per
    sample.

    Args:
        img_size: Spatial side length of the input image. The input is
            treated as square (the flattened size uses ``mid_img_size ** 2``).
        dim: Base channel width, forwarded to ``Unet``.
        init_dim: Channel count after ``init_conv``; defaults to ``dim``.
            Forwarded to ``Unet``.
        dim_mults: Per-stage channel multipliers, forwarded to ``Unet``.
        **kwargs: Any further ``Unet`` constructor arguments.
    """

    def __init__(self,
                 img_size,
                 dim: int = 64,
                 init_dim: Optional[int] = None,
                 dim_mults: Sequence[int] = (1, 2, 4, 8),
                 **kwargs):
        # Forward the architecture arguments so the Unet's layers and the
        # ``g1`` head below agree on channel counts. Without this, the Unet
        # would use its own defaults while ``g1`` is sized from these values.
        super().__init__(dim=dim, init_dim=init_dim, dim_mults=dim_mults,
                         **kwargs)
        init_dim = default(init_dim, dim)
        # Projection-head sizes ("from the paper" per the original author).
        g_emb = 1024
        g_out = 128
        # Channel width at each stage. Assumed to match how Unet derives its
        # own channel list from dim / init_dim / dim_mults -- TODO confirm.
        dims = [init_dim, *(dim * m for m in dim_mults)]
        mid_dim = dims[-1]
        # Spatial side length at the bottleneck. Assumes len(dims) - 2
        # downsampling stages, each mapping H -> floor((H - 1) / 2) + 1
        # (e.g. a kernel-3 / stride-2 / padding-1 conv) -- TODO confirm
        # against Unet's Downsample.
        mid_img_size = int(img_size)
        for _ in range(len(dims) - 2):
            mid_img_size = (mid_img_size - 1) // 2 + 1
        # g1: flatten (b, c, h, w) -> (b, c*h*w), then a bias-free
        # Linear -> ReLU -> Linear MLP down to g_out features.
        self.g1 = nn.Sequential(
            Rearrange('b c h w -> b (c h w)'),
            nn.Linear(mid_dim * mid_img_size ** 2, g_emb, bias=False),
            nn.ReLU(),
            nn.Linear(g_emb, g_out, bias=False),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` and return its global projection.

        Args:
            x: Input batch. It must be 4-D (``Rearrange('b c h w ...')``
                requires it), with spatial size ``img_size``.

        Returns:
            Tensor of shape ``(batch, 128)``, the output of ``self.g1``.
        """
        x = self.init_conv(x)

        # No timestep conditioning. This assumes the Unet blocks treat a None
        # time embedding as "unconditioned" -- verify against Unet.
        t = None

        # Encoder only. Skip activations are not needed because the
        # decoder is never run.
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        return self.g1(x)


class LocalCL(Unet):
    """Local contrastive-learning network built on the parent ``Unet``.

    Runs the Unet's ``init_conv``, contracting path and bottleneck. It then
    runs the first ``self.l`` stages of the expanding path (``self.ups``),
    each concatenating the matching skip activations, and finally the
    per-pixel projection head ``self.g2``. The output is a feature map,
    not a vector.

    Args:
        img_size: Accepted for signature parity with ``GlobalCL``; not used
            by this class.
        dim: Base channel width, forwarded to ``Unet``.
        init_dim: Channel count after ``init_conv``; defaults to ``dim``.
            Forwarded to ``Unet``.
        dim_mults: Per-stage channel multipliers, forwarded to ``Unet``.
        **kwargs: Any further ``Unet`` constructor arguments.
    """

    def __init__(self,
                 img_size,
                 dim: int = 64,
                 init_dim: Optional[int] = None,
                 dim_mults: Sequence[int] = (1, 2, 4, 8),
                 **kwargs):
        # Forward the architecture arguments so the Unet's decoder output
        # width matches the ``g2`` head sized below.
        super().__init__(dim=dim, init_dim=init_dim, dim_mults=dim_mults,
                         **kwargs)
        init_dim = default(init_dim, dim)
        # Channel width at each stage. Assumed to match how Unet derives its
        # own channel list from dim / init_dim / dim_mults -- TODO confirm.
        dims = [init_dim, *(dim * m for m in dim_mults)]
        # Number of decoder (``self.ups``) stages run before the head.
        self.l = 2
        # Channel count after ``self.l`` decoder stages. Assumes each up stage
        # ends at the width of the preceding encoder stage -- TODO confirm
        # against Unet.ups.
        mid_dim = dims[-self.l - 1]
        # g2: small head of two bias-free 1x1 convolutions, with ReLU and
        # BatchNorm in between (channel count unchanged).
        self.g2 = nn.Sequential(
            nn.Conv2d(mid_dim, mid_dim, 1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(mid_dim),
            nn.Conv2d(mid_dim, mid_dim, 1, bias=False),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x``, partially decode it and return local projections.

        Args:
            x: Input batch (presumably ``(batch, channels, H, W)`` -- TODO
                confirm against Unet.init_conv).

        Returns:
            The output of ``self.g2`` applied to the activations after
            ``self.l`` decoder stages. It has ``dims[-self.l - 1]`` channels.
        """
        x = self.init_conv(x)

        # No timestep conditioning. This assumes the Unet blocks treat a None
        # time embedding as "unconditioned" -- verify against Unet.
        t = None

        # Skip activations: two per encoder stage, consumed LIFO by the
        # decoder stages below.
        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Only the first ``self.l`` decoder stages; the rest of the decoder
        # and Unet's final layers are not used by this head.
        for block1, block2, attn, upsample in self.ups[:self.l]:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        return self.g2(x)