import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d
# lambda layer

class LambdaLayer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_k,
        n = None,
        r = None,
        heads = 4,
        dim_out = None,
        dim_u = 1,
        norm = 'batch'
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.u = dim_u # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        dim_v = dim_out // heads

        self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
        self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)

        # normalize queries and values, with either instance or batch norm
        if norm == 'instance':
            self.norm_q = nn.InstanceNorm2d(dim_k * heads)
            self.norm_v = nn.InstanceNorm2d(dim_v * dim_u)
        else:
            self.norm_q = nn.BatchNorm2d(dim_k * heads)
            self.norm_v = nn.BatchNorm2d(dim_v * dim_u)

        self.local_contexts = exists(r)
        if exists(r):
            # local positional lambdas, computed with a 3d convolution over an r x r neighborhood
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
        else:
            # global positional lambdas, computed with learned embeddings over all n = h x w positions
            assert exists(n), 'You must specify the total sequence length (h x w)'
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
    def forward(self, x):
        b, c, hh, ww, u, h = *x.shape, self.u, self.heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.norm_q(q)
        v = self.norm_v(v)

        q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
        k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
        v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)

        k = k.softmax(dim = -1)

        # content lambda and content output
        λc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b h v n', q, λc)

        # position lambda and position output
        if self.local_contexts:
            v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
            λp = self.pos_conv(v)
            Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
        else:
            λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b h v n', q, λp)

        Y = Yc + Yp
        out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
        return out
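
# Usage sketch (illustrative, not part of the original module): with local contexts
# enabled via `r`, LambdaLayer accepts any spatial size; with global positional
# embeddings, `n` must equal hh * ww of the input. Hyperparameters below are assumed.

def _example_lambda_layer():
    layer = LambdaLayer(dim = 32, dim_out = 32, r = 7, dim_k = 16, heads = 4, dim_u = 1)
    x = torch.randn(1, 32, 16, 16)
    return layer(x)  # -> (1, 32, 16, 16)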
# i'm not sure whether this will work or not

class Recursion(nn.Module):
    def __init__(self, N: int, hidden_dim: int = 64):
        super().__init__()
        self.N = N
        self.lambdaNxN_identity = LambdaLayer(dim = hidden_dim, dim_out = hidden_dim, n = N * N, dim_k = 16, heads = 2, dim_u = 1)
        # merge upstream information here
        self.lambdaNxN_merge = LambdaLayer(dim = 2 * hidden_dim, dim_out = hidden_dim, n = N * N, dim_k = 16, heads = 2, dim_u = 1)
        self.downscale_conv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size = N, stride = N)
        self.upscale_conv = nn.Conv2d(hidden_dim, hidden_dim * N * N, kernel_size = 3, padding = 1)
        self.pixel_shuffle = nn.PixelShuffle(N)
    def forward(self, x: torch.Tensor):
        N = self.N

        def to_patch(blocks: torch.Tensor):
            # split a (b, c, H, W) map into non-overlapping N x N patches, folded into the batch dimension
            shape = blocks.shape
            blocks_patch = F.unfold(blocks, kernel_size = N, stride = N)
            blocks_patch = blocks_patch.view(shape[0], shape[1], N, N, -1)
            num_patch = blocks_patch.shape[-1]
            blocks_patch = blocks_patch.permute(0, 4, 1, 2, 3).reshape(-1, shape[1], N, N).contiguous()
            return blocks_patch, num_patch

        def combine_patch(processed_patch, shape, num_patch):
            # inverse of to_patch: fold the per-patch results back into a (b, c, H, W) map
            processed_patch = processed_patch.reshape(shape[0], num_patch, shape[1], N, N)
            processed_patch = processed_patch.permute(0, 2, 3, 4, 1).reshape(shape[0], shape[1] * N * N, num_patch).contiguous()
            processed = F.fold(processed_patch, output_size = (shape[-2], shape[-1]), kernel_size = N, stride = N)
            return processed

        def process(blocks: torch.Tensor) -> torch.Tensor:
            shape = blocks.shape
            # base case: the map is already a single N x N block
            if blocks.shape[-1] == N:
                processed = self.lambdaNxN_identity(blocks)
                return processed
            # to N x N patches
            blocks_patch, num_patch = to_patch(blocks)
            # pass through identity
            processed_patch = self.lambdaNxN_identity(blocks_patch)
            # back to H x W
            processed = combine_patch(processed_patch, shape, num_patch)
            # get feedback by recursing on a downscaled map
            feedback = process(self.downscale_conv(processed))
            # upscale feedback
            upscale_feedback = self.upscale_conv(feedback)
            upscale_feedback = self.pixel_shuffle(upscale_feedback)
            # combine results
            combined = torch.cat([processed, upscale_feedback], dim = 1)
            combined_shape = combined.shape
            combined_patch, num_patch = to_patch(combined)
            combined_patch_reduced = self.lambdaNxN_merge(combined_patch)
            ret_shape = (combined_shape[0], combined_shape[1] // 2, combined_shape[2], combined_shape[3])
            ret = combine_patch(combined_patch_reduced, ret_shape, num_patch)
            return ret

        return process(x)
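
# Minimal sanity check (illustrative, with assumed shapes): the spatial size should be
# a power of N so that the recursive downscaling always bottoms out at an N x N block.
if __name__ == '__main__':
    model = Recursion(N = 4, hidden_dim = 64)
    x = torch.randn(2, 64, 16, 16)
    print(model(x).shape)                 # expected: torch.Size([2, 64, 16, 16])
    print(_example_lambda_layer().shape)  # expected: torch.Size([1, 32, 16, 16])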