import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange

# layer normalization without learnable affine parameters, applied over all
# non-batch dimensions of the input

class LayerNorm(nn.Module):
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, input):
        # normalize over every dimension except the batch dimension
        shape = tuple(input.size()[1:])
        return F.layer_norm(input, shape, eps=self.eps)

    def extra_repr(self):
        return f'eps={self.eps}'

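# a minimal usage sketch (shapes chosen for illustration, not taken from the
# original code): the module normalizes whatever tensor it receives over all
# dimensions after the batch dimension
#
#   norm = LayerNorm()
#   feat = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)
#   out = norm(feat)                     # same shape, normalized per sample
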
# helper functions
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# lambda layer

class LambdaLayer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_k,
        n = None,
        r = None,
        heads = 4,
        dim_out = None,
        dim_u = 1,
        normalization = "batch"
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.u = dim_u # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        dim_v = dim_out // heads

        self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
        self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)
        print(f"using {normalization} in lambda layer")
        if normalization=="none":
            self.norm_q = nn.Identity(dim_k * heads)
            self.norm_v = nn.Identity(dim_v * dim_u)
        elif normalization=="instance":
            self.norm_q = nn.InstanceNorm2d(dim_k * heads)
            self.norm_v = nn.InstanceNorm2d(dim_v * dim_u)
        elif normalization=="layer":
            self.norm_q = LayerNorm()
            self.norm_v = LayerNorm()
        else:
            self.norm_q = nn.BatchNorm2d(dim_k * heads)
            self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
            print(f"using BN in lambda layer?")

        self.local_contexts = exists(r)
        if exists(r):
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
        else:
            assert exists(n), 'You must specify the total sequence length (h x w)'
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))


    def forward(self, x):
        b, c, hh, ww, u, h = *x.shape, self.u, self.heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.norm_q(q)
        v = self.norm_v(v)

        q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
        k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
        v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)

        k = k.softmax(dim=-1)

        λc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b h v n', q, λc)

        if self.local_contexts:
            v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
            λp = self.pos_conv(v)
            Yp = einsum('b h k n, b k v n -> b h v n', q, 位p.flatten(3))
        else:
            λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b h v n', q, λp)

        Y = Yc + Yp
        out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
        return out

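# usage sketch (illustrative values, not from the original code): the layer maps
# (B, dim, H, W) -> (B, dim_out, H, W); pass n = H * W to use global position
# embeddings, or an odd r to use local positional convolution contexts instead
#
#   layer_global = LambdaLayer(dim=32, dim_out=32, n=64, dim_k=16, heads=4, dim_u=1)
#   layer_local  = LambdaLayer(dim=32, dim_out=32, r=7, dim_k=16, heads=4, dim_u=1)
#   x = torch.randn(1, 32, 8, 8)
#   y = layer_global(x)   # -> (1, 32, 8, 8)
#   z = layer_local(x)    # -> (1, 32, 8, 8)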

# experimental: recursive multi-scale refinement built from lambda layers (not fully tested)
class Recursion(nn.Module):
    def __init__(self, N: int, hidden_dim: int = 64):
        super().__init__()
        self.N = N
        # processes each NxN patch independently
        self.lambdaNxN_identity = LambdaLayer(dim=hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
        # merges upstream (coarser-scale) information back into each patch
        self.lambdaNxN_merge = LambdaLayer(dim=2 * hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
        # downscale by a factor of N to build the next recursion level
        self.downscale_conv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=N, stride=N)
        # upscale the feedback by a factor of N via pixel shuffle
        self.upscale_conv = nn.Conv2d(hidden_dim, hidden_dim * N * N, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(N)
    
    def forward(self, x: torch.Tensor):
        N = self.N

        def to_patch(blocks: torch.Tensor):
            # split (B, C, H, W) into non-overlapping NxN patches, flattened into the batch dimension
            shape = blocks.shape
            blocks_patch = F.unfold(blocks, kernel_size=N, stride=N)
            blocks_patch = blocks_patch.view(shape[0], shape[1], N, N, -1)
            num_patch = blocks_patch.shape[-1]
            blocks_patch = blocks_patch.permute(0, 4, 1, 2, 3).reshape(-1, shape[1], N, N).contiguous()
            return blocks_patch, num_patch
        
        def combine_patch(processed_patch, shape, num_patch):
            # inverse of to_patch: fold the flattened patches back into a (B, C, H, W) map
            processed_patch = processed_patch.reshape(shape[0], num_patch, shape[1], N, N)
            processed_patch = processed_patch.permute(0, 2, 3, 4, 1).reshape(shape[0], shape[1] * N * N, num_patch).contiguous()
            processed = F.fold(processed_patch, output_size=(shape[-2], shape[-1]), kernel_size=N, stride=N)
            return processed

        def process(blocks: torch.Tensor) -> torch.Tensor:
            shape = blocks.shape
            # base case: the input is already a single NxN block
            if blocks.shape[-1] == N:
                processed = self.lambdaNxN_identity(blocks)
                return processed
            # split into NxN patches
            blocks_patch, num_patch = to_patch(blocks)
            # process each patch independently
            processed_patch = self.lambdaNxN_identity(blocks_patch)
            # fold the patches back to the full HxW resolution
            processed = combine_patch(processed_patch, shape, num_patch)
            # recurse on a downscaled copy to obtain coarse-scale feedback
            feedback = process(self.downscale_conv(processed))
            # upscale the feedback back to the current resolution
            upscale_feedback = self.upscale_conv(feedback)
            upscale_feedback = self.pixel_shuffle(upscale_feedback)
            # concatenate the current result with the feedback and merge per patch
            combined = torch.cat([processed, upscale_feedback], dim=1)
            combined_shape = combined.shape
            combined_patch, num_patch = to_patch(combined)
            combined_patch_reduced = self.lambdaNxN_merge(combined_patch)
            ret_shape = (combined_shape[0], combined_shape[1] // 2, combined_shape[2], combined_shape[3])
            ret = combine_patch(combined_patch_reduced, ret_shape, num_patch)
            return ret

        return process(x)
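

# a minimal smoke test, assuming an input whose spatial size is divisible by N
# at every recursion level (the module downscales by N until a single NxN block
# remains); the sizes below are illustrative and not from the original setup
if __name__ == "__main__":
    model = Recursion(N=4, hidden_dim=64)
    x = torch.randn(1, 64, 16, 16)    # 16 = 4 * 4, so the recursion terminates after one downscale
    y = model(x)
    print(y.shape)                    # expected: torch.Size([1, 64, 16, 16])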