from typing import Tuple

import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# lambda layer

class LambdaLayer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_k,
        n = None,
        r = None,
        heads = 4,
        dim_out = None,
        dim_u = 1,
        norm = "batch"
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.u = dim_u # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        dim_v = dim_out // heads

        self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
        self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)
        # normalize queries and values, with either instance norm or batch norm (default)
        if norm == "instance":
            self.norm_q = nn.InstanceNorm2d(dim_k * heads)
            self.norm_v = nn.InstanceNorm2d(dim_v * dim_u)
        else:
            self.norm_q = nn.BatchNorm2d(dim_k * heads)
            self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
        self.local_contexts = exists(r)
        if exists(r):
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
        else:
            assert exists(n), 'You must specify the total sequence length (h x w)'
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))


    def forward(self, x):
        b, c, hh, ww, u, h = *x.shape, self.u, self.heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.norm_q(q)
        v = self.norm_v(v)

        q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
        k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
        v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)

        # softmax-normalize keys over positions
        k = k.softmax(dim=-1)

        # content lambda and its output
        λc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b h v n', q, λc)

        if self.local_contexts:
            # position lambdas from a local 3d convolution over the values
            v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
            λp = self.pos_conv(v)
            Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
        else:
            # position lambdas from learned relative position embeddings
            λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b h v n', q, λp)

        Y = Yc + Yp
        out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
        return out
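
# Example usage of LambdaLayer (a minimal sketch; the sizes below are illustrative
# assumptions, not prescribed values). With global position embeddings, n must equal
# h * w of the incoming feature map:
#
#   layer = LambdaLayer(dim = 32, dim_out = 32, n = 8 * 8, dim_k = 16, heads = 4, dim_u = 1)
#   x = torch.randn(2, 32, 8, 8)
#   y = layer(x)  # (2, 32, 8, 8)
#
# With local contexts, pass an odd receptive field r instead of n:
#
#   layer = LambdaLayer(dim = 32, dim_out = 32, r = 7, dim_k = 16, heads = 4, dim_u = 1)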


# i'm not sure whether this will work or not
class Recursion(nn.Module):
    def __init__(self, N: int, hidden_dim: int = 64):
        super().__init__()
        self.N = N
        # processes each NxN patch on its own
        self.lambdaNxN_identity = LambdaLayer(dim = hidden_dim, dim_out = hidden_dim, n = N * N, dim_k = 16, heads = 2, dim_u = 1)
        # merge upstream information here
        self.lambdaNxN_merge = LambdaLayer(dim = 2 * hidden_dim, dim_out = hidden_dim, n = N * N, dim_k = 16, heads = 2, dim_u = 1)
        # downscale by a factor of N before recursing; the feedback is upscaled again with pixel shuffle
        self.downscale_conv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size = N, stride = N)
        self.upscale_conv = nn.Conv2d(hidden_dim, hidden_dim * N * N, kernel_size = 3, padding = 1)
        self.pixel_shuffle = nn.PixelShuffle(N)
    
    def forward(self, x: torch.Tensor):
        N = self.N

        def to_patch(blocks: torch.Tensor) -> Tuple[torch.Tensor, int]:
            # split a (B, C, H, W) map into non-overlapping NxN patches folded into the batch dimension
            shape = blocks.shape
            blocks_patch = F.unfold(blocks, kernel_size = N, stride = N)
            blocks_patch = blocks_patch.view(shape[0], shape[1], N, N, -1)
            num_patch = blocks_patch.shape[-1]
            blocks_patch = blocks_patch.permute(0, 4, 1, 2, 3).reshape(-1, shape[1], N, N).contiguous()
            return blocks_patch, num_patch
        
        def combine_patch(processed_patch: torch.Tensor, shape, num_patch: int) -> torch.Tensor:
            # inverse of to_patch: fold the per-patch tensors back into a (B, C, H, W) map
            processed_patch = processed_patch.reshape(shape[0], num_patch, shape[1], N, N)
            processed_patch = processed_patch.permute(0, 2, 3, 4, 1).reshape(shape[0], shape[1] * N * N, num_patch).contiguous()
            processed = F.fold(processed_patch, output_size = (shape[-2], shape[-1]), kernel_size = N, stride = N)
            return processed

        def process(blocks: torch.Tensor) -> torch.Tensor:
            shape = blocks.shape
            # base case: the block is already NxN, process it directly
            if blocks.shape[-1] == N:
                processed = self.lambdaNxN_identity(blocks)
                return processed
            # split into NxN patches
            blocks_patch, num_patch = to_patch(blocks)
            # pass each patch through the identity lambda layer
            processed_patch = self.lambdaNxN_identity(blocks_patch)
            # back to H x W
            processed = combine_patch(processed_patch, shape, num_patch)
            # recurse on a downscaled copy to obtain feedback from the coarser scale
            feedback = process(self.downscale_conv(processed))
            # upscale the feedback back to the current resolution
            upscale_feedback = self.upscale_conv(feedback)
            upscale_feedback = self.pixel_shuffle(upscale_feedback)
            # concatenate along channels and merge back down to hidden_dim
            combined = torch.cat([processed, upscale_feedback], dim = 1)
            combined_shape = combined.shape
            combined_patch, num_patch = to_patch(combined)
            combined_patch_reduced = self.lambdaNxN_merge(combined_patch)
            ret_shape = (combined_shape[0], combined_shape[1] // 2, combined_shape[2], combined_shape[3])
            ret = combine_patch(combined_patch_reduced, ret_shape, num_patch)
            return ret

        return process(x)
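
# Quick smoke test (a minimal sketch; the batch size, channel count and spatial sizes
# below are illustrative assumptions, not values required by the modules). Note that
# Recursion assumes the spatial size is a power of N, so the recursion bottoms out at
# an NxN block.
if __name__ == '__main__':
    # standalone lambda layer with global position embeddings (n = h * w)
    layer = LambdaLayer(dim = 32, dim_out = 32, n = 8 * 8, dim_k = 16, heads = 4, dim_u = 1)
    x = torch.randn(2, 32, 8, 8)
    print(layer(x).shape)    # torch.Size([2, 32, 8, 8])

    # recursive module: 16 = 4 ** 2, so process() recurses once before hitting the base case
    recursion = Recursion(N = 4, hidden_dim = 64)
    x = torch.randn(2, 64, 16, 16)
    print(recursion(x).shape)    # torch.Size([2, 64, 16, 16])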