# pylint: skip-file
# -----------------------------------------------------------------------------------
# SCUNet: Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis, https://arxiv.org/abs/2203.13278
# Zhang, Kai and Li, Yawei and Liang, Jingyun and Cao, Jiezhang and Zhang, Yulun and Tang, Hao and Timofte, Radu and Van Gool, Luc
# -----------------------------------------------------------------------------------

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange

from .timm.drop import DropPath
from .timm.weight_init import trunc_normal_


# Borrowed from https://github.com/cszn/SCUNet/blob/main/models/network_scunet.py
class WMSA(nn.Module):
    """Self-attention module in Swin Transformer.

    Multi-head self-attention computed independently inside non-overlapping
    ``window_size x window_size`` windows of a channels-last feature map.
    When ``type`` is not ``"W"`` the map is cyclically rolled by half a window
    before partitioning (shifted windows) and rolled back afterwards; a mask
    produced by ``generate_mask`` blocks attention across the wrap-around seam.
    """

    def __init__(self, input_dim, output_dim, head_dim, window_size, type):
        super(WMSA, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.head_dim = head_dim
        self.scale = self.head_dim**-0.5
        self.n_heads = input_dim // head_dim
        self.window_size = window_size
        self.type = type
        # Single projection that yields q, k and v for every head at once.
        self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)

        # Relative position bias table with one entry per (dy, dx) offset per
        # head. It is created flat, initialised, and then re-wrapped below in
        # the (n_heads, 2*ws-1, 2*ws-1) layout that relative_embedding() indexes.
        self.relative_position_params = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
        )
        self.linear = nn.Linear(self.input_dim, self.output_dim)

        trunc_normal_(self.relative_position_params, std=0.02)
        self.relative_position_params = torch.nn.Parameter(
            self.relative_position_params.view(
                2 * window_size - 1, 2 * window_size - 1, self.n_heads
            )
            .transpose(1, 2)
            .transpose(0, 1)
        )

    def generate_mask(self, h, w, p, shift):
        """Generate the attention mask for shifted-window attention (SW-MSA).

        Args:
            h: number of windows along the height axis.
            w: number of windows along the width axis (need not equal ``h``).
            p: window size in tokens per side.
            shift: shift used by the cyclic roll in ``forward``.

        Returns:
            attn_mask: boolean tensor of shape (1, 1, h*w, p*p, p*p); ``True``
            marks query/key pairs that must not attend to each other. For
            ``self.type == "W"`` the mask has the same shape and is all ``False``.
        """
        attn_mask = torch.zeros(
            h,
            w,
            p,
            p,
            p,
            p,
            dtype=torch.bool,
            device=self.relative_position_params.device,
        )
        if self.type != "W":
            s = p - shift
            # The last row / column of windows holds content that the roll
            # wrapped around from the opposite border; forbid attention
            # between the two sides of that seam.
            attn_mask[-1, :, :s, :, s:, :] = True
            attn_mask[-1, :, s:, :, :s, :] = True
            attn_mask[:, -1, :, :s, :, s:] = True
            attn_mask[:, -1, :, s:, :, :s] = True
        return rearrange(
            attn_mask, "w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)"
        )

    def forward(self, x):
        """Forward pass of Window Multi-head Self-attention module.

        Args:
            x: input tensor of shape [b, h, w, c]; ``h`` and ``w`` must be
                multiples of ``window_size`` (required by the window
                partition rearrange below).

        Returns:
            output: tensor of shape [b, h, w, output_dim].
        """
        half = self.window_size // 2
        if self.type != "W":
            x = torch.roll(x, shifts=(-half, -half), dims=(1, 2))

        # Partition into windows, remembering the window grid size.
        x = rearrange(
            x,
            "b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c",
            p1=self.window_size,
            p2=self.window_size,
        )
        h_windows = x.size(1)
        w_windows = x.size(2)
        x = rearrange(
            x,
            "b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c",
            p1=self.window_size,
            p2=self.window_size,
        )

        # Split the fused projection into q, k, v, each (heads, b, nw, np, c).
        qkv = self.embedding_layer(x)
        q, k, v = rearrange(
            qkv, "b nw np (threeh c) -> threeh b nw np c", c=self.head_dim
        ).chunk(3, dim=0)
        sim = torch.einsum("hbwpc,hbwqc->hbwpq", q, k) * self.scale
        # Add the learnable relative position bias (shared by all windows).
        sim = sim + rearrange(self.relative_embedding(), "h p q -> h 1 1 p q")
        # Using Attn Mask to distinguish different subwindows.
        if self.type != "W":
            attn_mask = self.generate_mask(
                h_windows, w_windows, self.window_size, shift=half
            )
            sim = sim.masked_fill_(attn_mask, float("-inf"))

        probs = nn.functional.softmax(sim, dim=-1)
        output = torch.einsum("hbwij,hbwjc->hbwic", probs, v)
        output = rearrange(output, "h b w p c -> b w p (h c)")
        output = self.linear(output)
        # Merge windows back into the [b, h, w, c] map.
        output = rearrange(
            output,
            "b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c",
            w1=h_windows,
            p1=self.window_size,
        )

        if self.type != "W":
            output = torch.roll(output, shifts=(half, half), dims=(1, 2))
        return output

    def relative_embedding(self):
        """Gather the relative position bias for every token pair of a window.

        Returns:
            Tensor of shape (n_heads, ws*ws, ws*ws), ``ws`` being ``window_size``.
        """
        cord = torch.tensor(
            np.array(
                [
                    [i, j]
                    for i in range(self.window_size)
                    for j in range(self.window_size)
                ]
            )
        )
        # Raw offsets lie in [-(ws-1), ws-1]; adding ws-1 maps them onto the
        # non-negative index range [0, 2*ws-2] of the bias table.
        relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
        return self.relative_position_params[
            :, relation[:, :, 0].long(), relation[:, :, 1].long()
        ]


class Block(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        head_dim,
        window_size,
        drop_path,
        type="W",
        input_resolution=None,
    ):
        """SwinTransformer Block.

        Pre-norm residual block operating on channels-last input [b, h, w, c]:
        ``x + DropPath(WMSA(LN(x)))`` followed by ``x + DropPath(MLP(LN(x)))``.

        Args:
            input_dim: channel count of the input.
            output_dim: channel count produced by the MLP (the residual add
                requires it to equal ``input_dim``).
            head_dim: channels per attention head.
            window_size: attention window size.
            drop_path: DropPath rate; ``0`` disables it.
            type: ``"W"`` (regular windows) or ``"SW"`` (shifted windows).
            input_resolution: spatial size this block is built for. When it is
                not larger than ``window_size`` the block is forced to ``"W"``.
                ``None`` skips that check.
        """
        super(Block, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        assert type in ["W", "SW"]
        self.type = type
        if input_resolution is not None and input_resolution <= window_size:
            self.type = "W"

        self.ln1 = nn.LayerNorm(input_dim)
        self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.ln2 = nn.LayerNorm(input_dim)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 4 * input_dim),
            nn.GELU(),
            nn.Linear(4 * input_dim, output_dim),
        )

    def forward(self, x):
        x = x + self.drop_path(self.msa(self.ln1(x)))
        x = x + self.drop_path(self.mlp(self.ln2(x)))
        return x


class ConvTransBlock(nn.Module):
    def __init__(
        self,
        conv_dim,
        trans_dim,
        head_dim,
        window_size,
        drop_path,
        type="W",
        input_resolution=None,
    ):
        """SwinTransformer and Conv Block.

        A 1x1 conv mixes the ``conv_dim + trans_dim`` input channels, which are
        then split into a residual 3x3-conv branch (``conv_dim`` channels) and
        a Swin ``Block`` branch (``trans_dim`` channels). The branches are
        concatenated, fused by another 1x1 conv and added to the input.

        ``type`` and ``input_resolution`` behave as in ``Block``; ``None`` for
        ``input_resolution`` skips the forced fallback to ``"W"``.
        """
        super(ConvTransBlock, self).__init__()
        self.conv_dim = conv_dim
        self.trans_dim = trans_dim
        self.head_dim = head_dim
        self.window_size = window_size
        self.drop_path = drop_path
        self.type = type
        self.input_resolution = input_resolution

        assert self.type in ["W", "SW"]
        if (
            self.input_resolution is not None
            and self.input_resolution <= self.window_size
        ):
            self.type = "W"

        self.trans_block = Block(
            self.trans_dim,
            self.trans_dim,
            self.head_dim,
            self.window_size,
            self.drop_path,
            self.type,
            self.input_resolution,
        )
        self.conv1_1 = nn.Conv2d(
            self.conv_dim + self.trans_dim,
            self.conv_dim + self.trans_dim,
            1,
            1,
            0,
            bias=True,
        )
        self.conv1_2 = nn.Conv2d(
            self.conv_dim + self.trans_dim,
            self.conv_dim + self.trans_dim,
            1,
            1,
            0,
            bias=True,
        )

        self.conv_block = nn.Sequential(
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
        )

    def forward(self, x):
        """Apply the block to a channels-first tensor [b, c, h, w]."""
        conv_x, trans_x = torch.split(
            self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1
        )
        conv_x = self.conv_block(conv_x) + conv_x
        # The transformer branch works channels-last. Use the functional
        # rearrange rather than building new Rearrange modules on every call.
        trans_x = rearrange(trans_x, "b c h w -> b h w c")
        trans_x = self.trans_block(trans_x)
        trans_x = rearrange(trans_x, "b h w c -> b c h w")
        res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
        return x + res


class SCUNet(nn.Module):
    """SCUNet denoiser built from, and immediately loaded with, ``state_dict``.

    A U-Net of ``ConvTransBlock`` stages: three stride-2 downsampling stages,
    a body, and three transposed-conv upsampling stages with additive skip
    connections. ``config`` gives the number of blocks in each of the seven
    stages. Weights are loaded with ``strict=True``, so the hyper-parameters
    must match the checkpoint.
    """

    def __init__(
        self,
        state_dict,
        in_nc=3,
        config=None,
        dim=64,
        drop_path_rate=0.0,
        input_resolution=256,
    ):
        super(SCUNet, self).__init__()
        if config is None:
            # Avoid a shared mutable default argument.
            config = [4, 4, 4, 4, 4, 4, 4]
        self.model_arch = "SCUNet"
        self.sub_type = "SR"

        self.num_filters: int = 0

        self.state = state_dict
        self.config = config
        self.dim = dim
        self.head_dim = 32
        self.window_size = 8

        self.in_nc = in_nc
        self.out_nc = self.in_nc
        self.scale = 1
        self.supports_fp16 = True

        # drop path rate for each layer
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
        # Index into dpr of the first block of each stage.
        begins = [sum(config[:k]) for k in range(len(config))]

        def stage_blocks(stage, half_dim, resolution):
            """Build stage ``stage``'s ConvTransBlocks, alternating W / SW."""
            return [
                ConvTransBlock(
                    half_dim,
                    half_dim,
                    self.head_dim,
                    self.window_size,
                    dpr[i + begins[stage]],
                    "W" if not i % 2 else "SW",
                    resolution,
                )
                for i in range(config[stage])
            ]

        # Module creation order and Sequential indices match the original
        # layout so state_dict keys line up.
        self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]

        self.m_down1 = stage_blocks(0, dim // 2, input_resolution) + [
            nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)
        ]
        self.m_down2 = stage_blocks(1, dim, input_resolution // 2) + [
            nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)
        ]
        self.m_down3 = stage_blocks(2, 2 * dim, input_resolution // 4) + [
            nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)
        ]

        self.m_body = stage_blocks(3, 4 * dim, input_resolution // 8)

        self.m_up3 = [
            nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False),
        ] + stage_blocks(4, 2 * dim, input_resolution // 4)
        self.m_up2 = [
            nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False),
        ] + stage_blocks(5, dim, input_resolution // 2)
        self.m_up1 = [
            nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False),
        ] + stage_blocks(6, dim // 2, input_resolution)

        self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]

        self.m_head = nn.Sequential(*self.m_head)
        self.m_down1 = nn.Sequential(*self.m_down1)
        self.m_down2 = nn.Sequential(*self.m_down2)
        self.m_down3 = nn.Sequential(*self.m_down3)
        self.m_body = nn.Sequential(*self.m_body)
        self.m_up3 = nn.Sequential(*self.m_up3)
        self.m_up2 = nn.Sequential(*self.m_up2)
        self.m_up1 = nn.Sequential(*self.m_up1)
        self.m_tail = nn.Sequential(*self.m_tail)

        # _init_weights is intentionally not applied: every weight is
        # overwritten by the strict state_dict load below.
        self.load_state_dict(state_dict, strict=True)

    def check_image_size(self, x):
        """Pad [b, c, h, w] on the bottom/right so h and w are multiples of 64.

        64 = 8x spatial reduction from the three stride-2 downsampling convs
        times the window size of 8 used at the deepest stage.

        Reflect padding is used when possible. ``F.pad`` rejects reflect
        padding that is not smaller than the padded dimension (small inputs),
        so replicate padding is used in that case instead of raising.
        """
        _, _, h, w = x.size()
        mod_pad_h = (64 - h % 64) % 64
        mod_pad_w = (64 - w % 64) % 64
        mode = "reflect" if mod_pad_h < h and mod_pad_w < w else "replicate"
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), mode)
        return x

    def forward(self, x0):
        """Run the network on a [b, c, h, w] tensor and crop back to (h, w)."""
        h, w = x0.size()[-2:]
        x0 = self.check_image_size(x0)

        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        # Additive skip connections between matching encoder/decoder scales.
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)

        x = x[:, :, :h, :w]
        return x

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, unit LayerNorm scale."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)