EmaadKhwaja
file upload
5d2263b
raw
history blame
7.09 kB
import torch
import torch.nn as nn
import numpy as np
class AbstractPermuter(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, x, reverse=False):
raise NotImplementedError
class Identity(AbstractPermuter):
def __init__(self):
super().__init__()
def forward(self, x, reverse=False):
return x
class Subsample(AbstractPermuter):
def __init__(self, H, W):
super().__init__()
C = 1
indices = np.arange(H*W).reshape(C,H,W)
while min(H, W) > 1:
indices = indices.reshape(C,H//2,2,W//2,2)
indices = indices.transpose(0,2,4,1,3)
indices = indices.reshape(C*4,H//2, W//2)
H = H//2
W = W//2
C = C*4
assert H == W == 1
idx = torch.tensor(indices.ravel())
self.register_buffer('forward_shuffle_idx',
nn.Parameter(idx, requires_grad=False))
self.register_buffer('backward_shuffle_idx',
nn.Parameter(torch.argsort(idx), requires_grad=False))
def forward(self, x, reverse=False):
if not reverse:
return x[:, self.forward_shuffle_idx]
else:
return x[:, self.backward_shuffle_idx]
def mortonify(i, j):
"""(i,j) index to linear morton code"""
i = np.uint64(i)
j = np.uint64(j)
z = np.uint(0)
for pos in range(32):
z = (z |
((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
)
return z
class ZCurve(AbstractPermuter):
def __init__(self, H, W):
super().__init__()
reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
idx = np.argsort(reverseidx)
idx = torch.tensor(idx)
reverseidx = torch.tensor(reverseidx)
self.register_buffer('forward_shuffle_idx',
idx)
self.register_buffer('backward_shuffle_idx',
reverseidx)
def forward(self, x, reverse=False):
if not reverse:
return x[:, self.forward_shuffle_idx]
else:
return x[:, self.backward_shuffle_idx]
class SpiralOut(AbstractPermuter):
def __init__(self, H, W):
super().__init__()
assert H == W
size = W
indices = np.arange(size*size).reshape(size,size)
i0 = size//2
j0 = size//2-1
i = i0
j = j0
idx = [indices[i0, j0]]
step_mult = 0
for c in range(1, size//2+1):
step_mult += 1
# steps left
for k in range(step_mult):
i = i - 1
j = j
idx.append(indices[i, j])
# step down
for k in range(step_mult):
i = i
j = j + 1
idx.append(indices[i, j])
step_mult += 1
if c < size//2:
# step right
for k in range(step_mult):
i = i + 1
j = j
idx.append(indices[i, j])
# step up
for k in range(step_mult):
i = i
j = j - 1
idx.append(indices[i, j])
else:
# end reached
for k in range(step_mult-1):
i = i + 1
idx.append(indices[i, j])
assert len(idx) == size*size
idx = torch.tensor(idx)
self.register_buffer('forward_shuffle_idx', idx)
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
def forward(self, x, reverse=False):
if not reverse:
return x[:, self.forward_shuffle_idx]
else:
return x[:, self.backward_shuffle_idx]
class SpiralIn(AbstractPermuter):
def __init__(self, H, W):
super().__init__()
assert H == W
size = W
indices = np.arange(size*size).reshape(size,size)
i0 = size//2
j0 = size//2-1
i = i0
j = j0
idx = [indices[i0, j0]]
step_mult = 0
for c in range(1, size//2+1):
step_mult += 1
# steps left
for k in range(step_mult):
i = i - 1
j = j
idx.append(indices[i, j])
# step down
for k in range(step_mult):
i = i
j = j + 1
idx.append(indices[i, j])
step_mult += 1
if c < size//2:
# step right
for k in range(step_mult):
i = i + 1
j = j
idx.append(indices[i, j])
# step up
for k in range(step_mult):
i = i
j = j - 1
idx.append(indices[i, j])
else:
# end reached
for k in range(step_mult-1):
i = i + 1
idx.append(indices[i, j])
assert len(idx) == size*size
idx = idx[::-1]
idx = torch.tensor(idx)
self.register_buffer('forward_shuffle_idx', idx)
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
def forward(self, x, reverse=False):
if not reverse:
return x[:, self.forward_shuffle_idx]
else:
return x[:, self.backward_shuffle_idx]
class Random(nn.Module):
def __init__(self, H, W):
super().__init__()
indices = np.random.RandomState(1).permutation(H*W)
idx = torch.tensor(indices.ravel())
self.register_buffer('forward_shuffle_idx', idx)
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
def forward(self, x, reverse=False):
if not reverse:
return x[:, self.forward_shuffle_idx]
else:
return x[:, self.backward_shuffle_idx]
class AlternateParsing(AbstractPermuter):
def __init__(self, H, W):
super().__init__()
indices = np.arange(W*H).reshape(H,W)
for i in range(1, H, 2):
indices[i, :] = indices[i, ::-1]
idx = indices.flatten()
assert len(idx) == H*W
idx = torch.tensor(idx)
self.register_buffer('forward_shuffle_idx', idx)
self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
def forward(self, x, reverse=False):
if not reverse:
return x[:, self.forward_shuffle_idx]
else:
return x[:, self.backward_shuffle_idx]
if __name__ == "__main__":
p0 = AlternateParsing(16, 16)
print(p0.forward_shuffle_idx)
print(p0.backward_shuffle_idx)
x = torch.randint(0, 768, size=(11, 256))
y = p0(x)
xre = p0(y, reverse=True)
assert torch.equal(x, xre)
p1 = SpiralOut(2, 2)
print(p1.forward_shuffle_idx)
print(p1.backward_shuffle_idx)