import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# lambda layer
class LambdaLayer(nn.Module):
def __init__(
self,
dim,
*,
dim_k,
n = None,
r = None,
heads = 4,
dim_out = None,
dim_u = 1,
norm="batch"):
super().__init__()
dim_out = default(dim_out, dim)
self.u = dim_u # intra-depth dimension
self.heads = heads
assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
dim_v = dim_out // heads
self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)
        # queries and values are normalized; keys are softmax-normalized in forward()
        if norm == "instance":
            self.norm_q = nn.InstanceNorm2d(dim_k * heads)
            self.norm_v = nn.InstanceNorm2d(dim_v * dim_u)
        else:
            self.norm_q = nn.BatchNorm2d(dim_k * heads)
            self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
        self.local_contexts = exists(r)
        if exists(r):
            # local position lambdas, computed with a 3d convolution over the values
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
        else:
            # global position lambdas, computed with a learned position embedding
            assert exists(n), 'You must specify the total sequence length (h x w)'
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
def forward(self, x):
b, c, hh, ww, u, h = *x.shape, self.u, self.heads
q = self.to_q(x)
k = self.to_k(x)
v = self.to_v(x)
q = self.norm_q(q)
v = self.norm_v(v)
q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)
        k = k.softmax(dim=-1)
        # content lambda: softmax-normalized keys aggregate the values globally
        λc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b h v n', q, λc)
        if self.local_contexts:
            # position lambdas from local contexts (3d convolution over the values)
            v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
            λp = self.pos_conv(v)
            Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
        else:
            # position lambdas from the learned global position embedding
            λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b h v n', q, λp)
Y = Yc + Yp
out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
return out
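
# Usage sketch for LambdaLayer (an illustration added here, not part of the
# original file; all shapes and hyperparameters below are assumptions):
#
#   layer = LambdaLayer(dim=32, dim_out=32, r=7, dim_k=16, heads=4, dim_u=1)
#   x = torch.randn(1, 32, 64, 64)
#   y = layer(x)  # (1, 32, 64, 64); with r given, no sequence length n is needed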
# I'm not sure whether this will work or not
class Recursion(nn.Module):
    def __init__(self, N: int, hidden_dim: int = 64):
        super().__init__()
        self.N = N
        self.lambdaNxN_identity = LambdaLayer(dim=hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
        # merge upstream information here
        self.lambdaNxN_merge = LambdaLayer(dim=2 * hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
        self.downscale_conv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=N, stride=N)
        self.upscale_conv = nn.Conv2d(hidden_dim, hidden_dim * N * N, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(N)
def forward(self, x: torch.Tensor):
N = self.N
        def to_patch(blocks: torch.Tensor) -> tuple:
            # split the feature map into non-overlapping N x N patches,
            # folding the patch index into the batch dimension
            shape = blocks.shape
            blocks_patch = F.unfold(blocks, kernel_size=N, stride=N)
            blocks_patch = blocks_patch.view(shape[0], shape[1], N, N, -1)
            num_patch = blocks_patch.shape[-1]
            blocks_patch = blocks_patch.permute(0, 4, 1, 2, 3).reshape(-1, shape[1], N, N).contiguous()
            return blocks_patch, num_patch

        def combine_patch(processed_patch, shape, num_patch):
            # inverse of to_patch: fold the per-patch results back into a feature map
            processed_patch = processed_patch.reshape(shape[0], num_patch, shape[1], N, N)
            processed_patch = processed_patch.permute(0, 2, 3, 4, 1).reshape(shape[0], shape[1] * N * N, num_patch).contiguous()
            processed = F.fold(processed_patch, output_size=(shape[-2], shape[-1]), kernel_size=N, stride=N)
            return processed
        def process(blocks: torch.Tensor) -> torch.Tensor:
            shape = blocks.shape
            if blocks.shape[-1] == N:
                # base case: the block is already N x N
                processed = self.lambdaNxN_identity(blocks)
                return processed
            # to N x N patches
            blocks_patch, num_patch = to_patch(blocks)
            # pass through identity
            processed_patch = self.lambdaNxN_identity(blocks_patch)
            # back to H x W
            processed = combine_patch(processed_patch, shape, num_patch)
            # get feedback by recursing on a downscaled map
            feedback = process(self.downscale_conv(processed))
            # upscale feedback
            upscale_feedback = self.upscale_conv(feedback)
            upscale_feedback = self.pixel_shuffle(upscale_feedback)
            # combine results
            combined = torch.cat([processed, upscale_feedback], dim=1)
            combined_shape = combined.shape
            combined_patch, num_patch = to_patch(combined)
            combined_patch_reduced = self.lambdaNxN_merge(combined_patch)
            ret_shape = (combined_shape[0], combined_shape[1] // 2, combined_shape[2], combined_shape[3])
            ret = combine_patch(combined_patch_reduced, ret_shape, num_patch)
            return ret
return process(x)
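
if __name__ == "__main__":
    # Minimal smoke test (an addition for illustration, not part of the original
    # upload). The recursion only terminates when the spatial size is a power of N,
    # since process() downscales by a factor of N until an N x N base case remains.
    model = Recursion(N=4, hidden_dim=64)
    x = torch.randn(2, 64, 16, 16)  # 16 = 4 ** 2, so the recursion is two levels deep
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 64, 16, 16])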