import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
# layer normalization over all non-batch dimensions (no learnable affine parameters)
class LayerNorm(nn.Module):
    def __init__(self, eps=1e-5):
        super(LayerNorm, self).__init__()
        self.eps = eps

    def forward(self, input):
        # normalize each sample over its full (C, H, W) shape
        shape = tuple(input.size()[1:])
        return F.layer_norm(input, shape, eps=self.eps)

    def extra_repr(self):
        return f'eps={self.eps}'
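# A small sanity check I added (not part of the original module): this LayerNorm
# normalizes each sample over all non-batch dimensions, so the per-sample mean
# over (C, H, W) should come out close to zero.
def _layer_norm_check():
    x = torch.randn(2, 8, 4, 4)
    y = LayerNorm()(x)
    assert torch.allclose(y.mean(dim=(1, 2, 3)), torch.zeros(2), atol=1e-5)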
# helper functions
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d
# lambda layer
class LambdaLayer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_k,
        n = None,
        r = None,
        heads = 4,
        dim_out = None,
        dim_u = 1,
        normalization = "batch"):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.u = dim_u # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        dim_v = dim_out // heads

        self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
        self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)

        print(f"using {normalization} normalization in lambda layer")
        if normalization == "none":
            self.norm_q = nn.Identity()
            self.norm_v = nn.Identity()
        elif normalization == "instance":
            self.norm_q = nn.InstanceNorm2d(dim_k * heads)
            self.norm_v = nn.InstanceNorm2d(dim_v * dim_u)
        elif normalization == "layer":
            self.norm_q = LayerNorm()
            self.norm_v = LayerNorm()
        else:
            # fall back to batch normalization for "batch" or any unrecognized value
            self.norm_q = nn.BatchNorm2d(dim_k * heads)
            self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
            print("using BN in lambda layer")

        self.local_contexts = exists(r)
        if exists(r):
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
        else:
            assert exists(n), 'You must specify the total sequence length (h x w)'
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
    def forward(self, x):
        b, c, hh, ww, u, h = *x.shape, self.u, self.heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.norm_q(q)
        v = self.norm_v(v)

        q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
        k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
        v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)

        k = k.softmax(dim=-1)

        # content lambda: aggregates value information over all positions
        λc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b h v n', q, λc)

        # position lambdas: local contexts via 3d conv, or a global positional embedding
        if self.local_contexts:
            v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
            λp = self.pos_conv(v)
            Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
        else:
            λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b h v n', q, λp)

        Y = Yc + Yp
        out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
        return out
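# A minimal shape sanity check I added (my own sketch, not part of the original
# module): with the global positional embedding (n = h*w given, r = None), the
# layer should map (b, dim, h, w) -> (b, dim_out, h, w).
def _lambda_layer_shape_check():
    layer = LambdaLayer(dim=32, dim_out=32, n=8 * 8, dim_k=16, heads=2, dim_u=1)
    x = torch.randn(2, 32, 8, 8)
    assert layer(x).shape == (2, 32, 8, 8)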
# recursively processes NxN patches with downscaled feedback
# (i'm not sure whether this will work or not)
class Recursion(nn.Module):
    def __init__(self, N: int, hidden_dim: int = 64):
        super(Recursion, self).__init__()
        self.N = N
        self.lambdaNxN_identity = LambdaLayer(dim=hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
        # merge upstream information here
        self.lambdaNxN_merge = LambdaLayer(dim=2 * hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
        self.downscale_conv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=N, stride=N)
        self.upscale_conv = nn.Conv2d(hidden_dim, hidden_dim * N * N, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(N)
    def forward(self, x: torch.Tensor):
        N = self.N

        def to_patch(blocks: torch.Tensor):
            # split (b, c, H, W) into non-overlapping NxN patches, flattened into the batch dimension
            shape = blocks.shape
            blocks_patch = F.unfold(blocks, kernel_size=N, stride=N)
            blocks_patch = blocks_patch.view(shape[0], shape[1], N, N, -1)
            num_patch = blocks_patch.shape[-1]
            blocks_patch = blocks_patch.permute(0, 4, 1, 2, 3).reshape(-1, shape[1], N, N).contiguous()
            return blocks_patch, num_patch

        def combine_patch(processed_patch, shape, num_patch):
            # inverse of to_patch: fold the per-patch results back to (b, c, H, W)
            processed_patch = processed_patch.reshape(shape[0], num_patch, shape[1], N, N)
            processed_patch = processed_patch.permute(0, 2, 3, 4, 1).reshape(shape[0], shape[1] * N * N, num_patch).contiguous()
            processed = F.fold(processed_patch, output_size=(shape[-2], shape[-1]), kernel_size=N, stride=N)
            return processed

        def process(blocks: torch.Tensor) -> torch.Tensor:
            shape = blocks.shape
            # base case: the input is already a single NxN block
            if blocks.shape[-1] == N:
                processed = self.lambdaNxN_identity(blocks)
                return processed
            # to NxN patches
            blocks_patch, num_patch = to_patch(blocks)
            # pass through identity
            processed_patch = self.lambdaNxN_identity(blocks_patch)
            # back to HxW
            processed = combine_patch(processed_patch, shape, num_patch)
            # get feedback by recursively processing the downscaled map
            feedback = process(self.downscale_conv(processed))
            # upscale feedback
            upscale_feedback = self.upscale_conv(feedback)
            upscale_feedback = self.pixel_shuffle(upscale_feedback)
            # combine results
            combined = torch.cat([processed, upscale_feedback], dim=1)
            combined_shape = combined.shape
            combined_patch, num_patch = to_patch(combined)
            combined_patch_reduced = self.lambdaNxN_merge(combined_patch)
            ret_shape = (combined_shape[0], combined_shape[1] // 2, combined_shape[2], combined_shape[3])
            ret = combine_patch(combined_patch_reduced, ret_shape, num_patch)
            return ret

        return process(x)
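# Minimal usage sketch I added (my assumption of how Recursion is meant to be
# driven, not part of the original module): the spatial size should be a power
# of N so that the recursive downscale-by-N eventually bottoms out at an NxN block.
if __name__ == "__main__":
    model = Recursion(N=4, hidden_dim=64)
    x = torch.randn(1, 64, 16, 16)  # 16 = 4**2, so recursion bottoms out at 4x4
    y = model(x)
    print(y.shape)  # expected: torch.Size([1, 64, 16, 16])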