|
|
|
|
|
|
|
|
|
from pdb import set_trace as bb |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
""" From the ICLR22 paper: Patches are all you need |
|
https://openreview.net/pdf?id=TVHS5Y4dNvM |
|
""" |
|
|
|
class Residual(nn.Module):
    """Skip connection around ``fn`` with optional spatial subsampling.

    Computes ``x + fn(x)``, keeping only every ``stride``-th position along
    the last two axes of both terms. ``fn`` must therefore return a tensor
    whose subsampled shape is compatible with that of ``x``.

    Args:
        fn: module (or callable) applied to the input.
        stride: step used to subsample the last two axes (1 keeps everything).
    """

    def __init__(self, fn, stride=1):
        super().__init__()
        self.stride = stride
        self.fn = fn

    def forward(self, x):
        """Return the subsampled sum of ``x`` and ``fn(x)`` (4-D indexing)."""
        step = self.stride
        shortcut = x[:, :, ::step, ::step]
        branch = self.fn(x)[:, :, ::step, ::step]
        return shortcut + branch
|
|
|
|
|
class ConvMixer(nn.Sequential):
    """Modified ConvMixer with convolutional layers at the bottom.

    From the ICLR22 paper: Patches are all you need,
    https://openreview.net/pdf?id=TVHS5Y4dNvM

    The network is a plain ``nn.Sequential`` built from, in order:

    1. a 5x5 stem convolution on the 3-channel input, followed by
       ``preconv - 1`` additional 3x3 convolutions;
    2. a patch embedding: a convolution whose kernel and stride are both
       ``patch_size``;
    3. the mixer blocks: a residual grouped ``kernel_size`` convolution
       followed by a 1x1 convolution;
    4. a 1x1 projection, a ``PixelShuffle`` and a bilinear ``Upsample``.

    Every convolution except the final projection is followed by the
    activation and a ``BatchNorm2d``.

    Args:
        output_dim: number of channels of the returned descriptor map.
        hidden_dim: channel width after the patch embedding; the layers
            before it use ``hidden_dim // 4`` channels.
        depth: number of mixer blocks (for ``preconv >= 1``). It has no usable
            default: leaving it to ``None`` raises ``TypeError``.
        kernel_size: spatial size of the grouped convolution in each mixer
            block; must be odd so that ``kernel_size // 2`` padding keeps the
            resolution.
        patch_size: kernel size and stride of the patch embedding; must be a
            multiple of the output step (2 if ``faster`` else 1).
        group_size: channels per group in the mixer convolution
            (``groups = max(1, channels // group_size)``).
        preconv: number of convolutions applied before the patch embedding.
        faster: if true, ``PixelShuffle`` only upscales by
            ``patch_size // 2`` and the remaining factor 2 is done by
            bilinear upsampling; otherwise ``PixelShuffle`` upscales by
            ``patch_size`` and the upsampling factor is 1.
        relu: zero-argument callable returning the activation module.

    Raises:
        AssertionError: if ``kernel_size`` is even or ``patch_size`` is not a
            multiple of the output step.
        TypeError: if ``depth`` is ``None``.
        ValueError: if ``depth`` is negative.
    """

    def __init__(self, output_dim, hidden_dim,
                 depth=None, kernel_size=5, patch_size=8, group_size=1,
                 preconv=1, faster=True, relu=nn.ReLU):
        assert kernel_size % 2 == 1, 'kernel_size must be odd'
        output_step = 1 + faster
        assert patch_size % output_step == 0, f'patch_size must be multiple of {output_step}'

        # Fail early with a readable message instead of the opaque
        # "NoneType + int" error the default value used to trigger.
        if depth is None:
            raise TypeError('depth must be an integer, got None')
        if depth < 0:
            raise ValueError(f'depth must be >= 0, got {depth}')

        # Queue of channel widths; it is consumed from the front as the
        # layers below the mixer blocks are created.
        hidden_dims = [hidden_dim // 4] * preconv + [hidden_dim] * (depth + 1)

        # Stem: full-resolution 5x5 convolution on the 3-channel image.
        ops = [
            nn.Conv2d(3, hidden_dims[0], kernel_size=5, padding=2),
            relu(),
            nn.BatchNorm2d(hidden_dims[0])]

        # Remaining full-resolution 3x3 convolutions.
        for _ in range(1, preconv):
            in_dim = hidden_dims.pop(0)
            ops += [
                nn.Conv2d(in_dim, hidden_dims[0], kernel_size=3, padding=1),
                relu(),
                nn.BatchNorm2d(hidden_dims[0])]

        # Patch embedding: non-overlapping patches (stride == kernel size).
        in_dim = hidden_dims.pop(0)
        ops += [
            nn.Conv2d(in_dim, hidden_dims[0], kernel_size=patch_size, stride=patch_size),
            relu(),
            nn.BatchNorm2d(hidden_dims[0])]

        # Mixer blocks: residual grouped spatial conv, then 1x1 channel conv.
        for idim, odim in zip(hidden_dims, hidden_dims[1:]):
            ops += [
                Residual(nn.Sequential(
                    nn.Conv2d(idim, idim, kernel_size,
                              groups=max(1, idim // group_size),
                              padding=kernel_size // 2),
                    relu(),
                    nn.BatchNorm2d(idim),
                )),
                nn.Conv2d(idim, odim, kernel_size=1),
                relu(),
                nn.BatchNorm2d(odim)]

        # Head. The input width is read from the queue rather than from the
        # loop variable so that it is defined even when no mixer block exists.
        upscale = patch_size // output_step
        ops += [
            nn.Conv2d(hidden_dims[-1], output_dim * upscale ** 2, kernel_size=1),
            nn.PixelShuffle(upscale),
            nn.Upsample(scale_factor=output_step, mode='bilinear', align_corners=False)]

        super().__init__(*ops)
        # Set only once nn.Module is initialised, so the attribute goes
        # through the regular Module attribute machinery.
        self.patch_size = patch_size

    def forward(self, img):
        """Run the network and L2-normalize the result along the channel axis.

        Args:
            img: 4-D tensor; the first convolution expects 3 channels on
                axis 1.

        Returns:
            Tensor with ``output_dim`` channels, normalized along ``dim=-3``.
        """
        assert img.ndim == 4, 'img must be a 4-D tensor'
        desc = super().forward(img)
        return F.normalize(desc, dim=-3)
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build a network, print its layers and check that a random
    # batch goes through. `ConvMixer` is the only network class in this file.
    net = ConvMixer(128, 512, 7, patch_size=4, kernel_size=9)
    print(net)

    img = torch.rand(2, 3, 256, 256)
    print('input.shape =', img.shape)
    desc = net(img)
    print('desc.shape =', desc.shape)
|
|