# FunSR/models/swin_backbone.py
from models import register
from einops import rearrange
from mmcv.cnn import build_norm_layer
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
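
# NOTE: `PatchEmbed` and `SwinBlockSequence` are referenced below but are not
# defined in this snippet; they are assumed to be provided elsewhere in this
# file or imported from the mmcv/mmseg Swin Transformer implementation.
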
@register('Swin_backbone')
class SwinTransformerBackbone(nn.Module):
    """Single-stage Swin Transformer backbone that keeps the input resolution.

    The patch embedding uses a stride-1 convolution with 'same' padding, so the
    output feature map has the same spatial size as the input image.
    """

    def __init__(self, in_channels=3, embed_dims=256, depth=4, drop_path_rate=0.1):
        super().__init__()
        # Overlapping 5x5 patch embedding; stride 1 preserves spatial resolution.
        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=5,
            stride=1,
            padding='same',
            norm_cfg=dict(type='LN')
        )
        # Stochastic-depth rates, increasing linearly from 0 to drop_path_rate.
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]
        # A single sequence of `depth` Swin blocks, with no downsampling.
        self.stage = SwinBlockSequence(
            embed_dims=embed_dims,
            num_heads=8,
            feedforward_channels=embed_dims * 2,
            depth=depth,
            window_size=4,
            drop_path_rate=dpr,
            downsample=None
        )
        self.norm_layer = build_norm_layer(dict(type='LN'), embed_dims)[1]
        self.out_dim = embed_dims
    def forward(self, x):
        # (B, C, H, W) -> (B, H*W, embed_dims) token sequence plus the (H, W) grid shape.
        x, hw_shape = self.patch_embed(x)
        x, hw_shape, out, out_hw_shape = self.stage(x, hw_shape)
        out = self.norm_layer(out)
        # Reshape the token sequence back into a (B, embed_dims, H, W) feature map.
        x = out.view(-1, *out_hw_shape, self.out_dim).permute(0, 3, 1, 2).contiguous()
        return x
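

# Minimal usage sketch (an assumption for illustration, relying on PatchEmbed and
# SwinBlockSequence being available as noted above); it checks that the backbone
# preserves the input spatial resolution.
if __name__ == '__main__':
    backbone = SwinTransformerBackbone(in_channels=3, embed_dims=256, depth=4)
    dummy = torch.randn(2, 3, 48, 48)  # 48 is divisible by the window size of 4
    feats = backbone(dummy)
    print(feats.shape)  # expected: torch.Size([2, 256, 48, 48])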