Spaces:
Runtime error
Runtime error
| # pylint: skip-file | |
| # ----------------------------------------------------------------------------------- | |
| # SCUNet: Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis, https://arxiv.org/abs/2203.13278 | |
| # Zhang, Kai and Li, Yawei and Liang, Jingyun and Cao, Jiezhang and Zhang, Yulun and Tang, Hao and Timofte, Radu and Van Gool, Luc | |
| # ----------------------------------------------------------------------------------- | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| from einops.layers.torch import Rearrange | |
| from .timm.drop import DropPath | |
| from .timm.weight_init import trunc_normal_ | |
| # Borrowed from https://github.com/cszn/SCUNet/blob/main/models/network_scunet.py | |
class WMSA(nn.Module):
    """Window-based multi-head self-attention as used in Swin Transformer.

    The input feature map is cut into non-overlapping square windows of
    ``window_size`` pixels and attention is computed independently inside each
    window. With ``type == "W"`` the windows are aligned to the feature map;
    with any other value (``"SW"`` in this file) the map is cyclically shifted
    by half a window first and an attention mask keeps pixels that were
    wrapped around from attending to each other.

    Args:
        input_dim: number of input channels.
        output_dim: number of channels produced by the output projection.
        head_dim: channels per attention head; the head count is
            ``input_dim // head_dim``.
        window_size: side length of the square attention window.
        type: ``"W"`` for regular windows, otherwise shifted windows. (The
            name shadows the builtin but is kept for interface compatibility.)
    """

    def __init__(self, input_dim, output_dim, head_dim, window_size, type):
        super(WMSA, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.head_dim = head_dim
        self.scale = self.head_dim**-0.5
        self.n_heads = input_dim // head_dim
        self.window_size = window_size
        self.type = type

        # Joint projection producing queries, keys and values in one matmul.
        self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
        self.linear = nn.Linear(self.input_dim, self.output_dim)

        # Learnable relative position bias. It is initialised as a flat
        # (offsets, heads) table and then stored as (heads, 2w-1, 2w-1) so it
        # can be indexed by (row offset, column offset).
        span = 2 * window_size - 1
        bias_table = torch.zeros(span * span, self.n_heads)
        trunc_normal_(bias_table, std=0.02)
        self.relative_position_params = nn.Parameter(
            bias_table.view(span, span, self.n_heads).transpose(1, 2).transpose(0, 1)
        )

    def generate_mask(self, h, w, p, shift):
        """Build the boolean attention mask for shifted-window attention.

        Args:
            h: number of windows along the first spatial axis.
            w: number of windows along the second spatial axis.
            p: window side length (windows are square).
            shift: cyclic shift that was applied to the feature map.

        Returns:
            For shifted windows, a bool tensor of shape
            ``(1, 1, h * w, p * p, p * p)`` where ``True`` marks query/key
            pairs that must not attend to each other. For ``type == "W"`` an
            all-``False`` tensor of shape ``(h, w, p, p, p, p)`` is returned
            without being flattened.
        """
        attn_mask = torch.zeros(
            h,
            w,
            p,
            p,
            p,
            p,
            dtype=torch.bool,
            device=self.relative_position_params.device,
        )
        if self.type == "W":
            return attn_mask

        # Only the last row / last column of windows contain wrapped-around
        # pixels; inside them, block attention across the seam at offset `s`.
        s = p - shift
        attn_mask[-1, :, :s, :, s:, :] = True
        attn_mask[-1, :, s:, :, :s, :] = True
        attn_mask[:, -1, :, :s, :, s:] = True
        attn_mask[:, -1, :, s:, :, :s] = True
        return rearrange(
            attn_mask, "w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)"
        )

    def forward(self, x):
        """Apply windowed multi-head self-attention.

        Args:
            x: input tensor laid out as ``[b, h, w, c]``; ``h`` and ``w`` must
                be divisible by ``window_size``.

        Returns:
            Tensor laid out as ``[b, h, w, c]`` with ``output_dim`` channels.
        """
        shifted = self.type != "W"
        half = self.window_size // 2
        if shifted:
            x = torch.roll(x, shifts=(-half, -half), dims=(1, 2))

        h_windows = x.size(1) // self.window_size
        w_windows = x.size(2) // self.window_size
        # Partition into windows and flatten both the window grid and the
        # pixels inside each window: [b, n_windows, pixels_per_window, c].
        x = rearrange(
            x,
            "b (w1 p1) (w2 p2) c -> b (w1 w2) (p1 p2) c",
            p1=self.window_size,
            p2=self.window_size,
        )

        qkv = self.embedding_layer(x)
        # Split channels into (3 * heads, head_dim) and move heads to the
        # front; chunking then yields q, k, v of shape [heads, b, nw, np, c].
        q, k, v = rearrange(
            qkv, "b nw np (threeh c) -> threeh b nw np c", c=self.head_dim
        ).chunk(3, dim=0)
        sim = torch.einsum("hbwpc,hbwqc->hbwpq", q, k) * self.scale
        # Add the learnable relative position bias, broadcast over batch and
        # windows.
        sim = sim + rearrange(self.relative_embedding(), "h p q -> h 1 1 p q")
        if shifted:
            # Masked pairs get -inf so softmax assigns them zero weight.
            attn_mask = self.generate_mask(
                h_windows, w_windows, self.window_size, shift=half
            )
            sim = sim.masked_fill_(attn_mask, float("-inf"))

        probs = nn.functional.softmax(sim, dim=-1)
        output = torch.einsum("hbwij,hbwjc->hbwic", probs, v)
        output = rearrange(output, "h b w p c -> b w p (h c)")
        output = self.linear(output)
        # Undo the window partition.
        output = rearrange(
            output,
            "b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c",
            w1=h_windows,
            p1=self.window_size,
        )
        if shifted:
            # Undo the cyclic shift.
            output = torch.roll(output, shifts=(half, half), dims=(1, 2))
        return output

    def relative_embedding(self):
        """Gather the relative position bias for every pixel pair of a window.

        Returns:
            Tensor of shape ``(n_heads, window_size**2, window_size**2)`` whose
            entry ``[h, p, q]`` is the bias of head ``h`` for the offset
            between window pixels ``p`` and ``q`` (row-major pixel order).
        """
        size = self.window_size
        # Build the index table directly on the parameter's device so no
        # NumPy round trip or host-to-device copy happens per forward pass.
        cord = torch.tensor(
            [[i, j] for i in range(size) for j in range(size)],
            device=self.relative_position_params.device,
        )
        # Offsets lie in [-(size - 1), size - 1]; shift them to [0, 2*size - 2].
        relation = cord[:, None, :] - cord[None, :, :] + size - 1
        return self.relative_position_params[
            :, relation[:, :, 0], relation[:, :, 1]
        ]
class Block(nn.Module):
    """Swin Transformer block: windowed attention followed by an MLP.

    Both sub-layers are pre-normalised with LayerNorm and wrapped in a
    residual connection with optional stochastic depth (DropPath).

    Args:
        input_dim: channel count of the incoming tokens.
        output_dim: channel count produced by the MLP. Because the MLP output
            is added to its input, this presumably has to equal ``input_dim``.
        head_dim: channels per attention head.
        window_size: side length of the square attention window.
        drop_path: stochastic-depth probability; ``0`` disables DropPath.
        type: ``"W"`` for regular windows or ``"SW"`` for shifted windows.
        input_resolution: spatial size of the feature map, or ``None`` when
            unknown. When it does not exceed ``window_size`` the block falls
            back to regular windows.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        head_dim,
        window_size,
        drop_path,
        type="W",
        input_resolution=None,
    ):
        super(Block, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        assert type in ["W", "SW"]
        self.type = type
        # `None` means "resolution unknown"; comparing it with an int would
        # raise TypeError, so only apply the fallback when a value was given.
        if input_resolution is not None and input_resolution <= window_size:
            self.type = "W"

        self.ln1 = nn.LayerNorm(input_dim)
        self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.ln2 = nn.LayerNorm(input_dim)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 4 * input_dim),
            nn.GELU(),
            nn.Linear(4 * input_dim, output_dim),
        )

    def forward(self, x):
        """Run attention and MLP sub-layers, each with a residual connection.

        Args:
            x: tensor passed unchanged to LayerNorm and ``WMSA``; channels are
                expected on the last axis.

        Returns:
            Tensor of the same layout as ``x``.
        """
        x = x + self.drop_path(self.msa(self.ln1(x)))
        x = x + self.drop_path(self.mlp(self.ln2(x)))
        return x
class ConvTransBlock(nn.Module):
    """Hybrid block that processes channels with a conv branch and a Swin branch.

    A 1x1 convolution mixes the input, which is then split along the channel
    axis into ``conv_dim`` channels for a residual 3x3 conv branch and
    ``trans_dim`` channels for a Swin Transformer ``Block``. The two results
    are concatenated, mixed by a second 1x1 convolution and added to the
    input.

    Args:
        conv_dim: channels routed through the convolutional branch.
        trans_dim: channels routed through the transformer branch.
        head_dim: channels per attention head.
        window_size: side length of the square attention window.
        drop_path: stochastic-depth probability for the transformer branch.
        type: ``"W"`` for regular windows or ``"SW"`` for shifted windows.
        input_resolution: spatial size of the feature map, or ``None`` when
            unknown. When it does not exceed ``window_size`` the block falls
            back to regular windows.
    """

    def __init__(
        self,
        conv_dim,
        trans_dim,
        head_dim,
        window_size,
        drop_path,
        type="W",
        input_resolution=None,
    ):
        super(ConvTransBlock, self).__init__()
        self.conv_dim = conv_dim
        self.trans_dim = trans_dim
        self.head_dim = head_dim
        self.window_size = window_size
        self.drop_path = drop_path
        self.type = type
        self.input_resolution = input_resolution

        assert self.type in ["W", "SW"]
        # `None` means "resolution unknown"; comparing it with an int would
        # raise TypeError, so only apply the fallback when a value was given.
        if (
            self.input_resolution is not None
            and self.input_resolution <= self.window_size
        ):
            self.type = "W"

        self.trans_block = Block(
            self.trans_dim,
            self.trans_dim,
            self.head_dim,
            self.window_size,
            self.drop_path,
            self.type,
            self.input_resolution,
        )
        total_dim = self.conv_dim + self.trans_dim
        # 1x1 convolutions that mix channels before the split and after the
        # concatenation.
        self.conv1_1 = nn.Conv2d(total_dim, total_dim, 1, 1, 0, bias=True)
        self.conv1_2 = nn.Conv2d(total_dim, total_dim, 1, 1, 0, bias=True)

        self.conv_block = nn.Sequential(
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
        )

    def forward(self, x):
        """Apply both branches and add the fused result to the input.

        Args:
            x: tensor laid out as ``[b, c, h, w]`` with
                ``c == conv_dim + trans_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        conv_x, trans_x = torch.split(
            self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1
        )
        conv_x = self.conv_block(conv_x) + conv_x

        # The transformer block works channels-last; a plain permute does the
        # layout change without building a layer object on every call.
        trans_x = trans_x.permute(0, 2, 3, 1)
        trans_x = self.trans_block(trans_x)
        trans_x = trans_x.permute(0, 3, 1, 2)

        res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
        return x + res
class SCUNet(nn.Module):
    """Swin-Conv-UNet (SCUNet) image restoration network.

    The network is a U-shaped encoder/decoder with three strided-conv
    downsampling stages, a bottleneck and three transposed-conv upsampling
    stages. Every stage is a stack of ``ConvTransBlock`` modules that alternate
    between regular and shifted windows, and encoder features are added to
    the matching decoder inputs as skip connections. Weights are loaded from
    ``state_dict`` at construction time.

    Args:
        state_dict: checkpoint loaded with ``strict=True``; its keys must match
            this module's parameters exactly.
        in_nc: number of image channels (input and output).
        config: number of ``ConvTransBlock`` modules in each of the seven
            stages (down1, down2, down3, body, up3, up2, up1). The default
            list is never mutated here.
        dim: base channel width; deeper stages use 2x, 4x and 8x this value.
        drop_path_rate: maximum stochastic-depth rate; per-block rates grow
            linearly from 0 to this value across all blocks.
        input_resolution: nominal input size, used to decide whether a stage
            falls back to regular windows.
    """

    def __init__(
        self,
        state_dict,
        in_nc=3,
        config=[4, 4, 4, 4, 4, 4, 4],
        dim=64,
        drop_path_rate=0.0,
        input_resolution=256,
    ):
        super(SCUNet, self).__init__()
        # Descriptive attributes; presumably read by the hosting application.
        self.model_arch = "SCUNet"
        self.sub_type = "SR"
        self.num_filters: int = 0

        self.state = state_dict
        self.config = config
        self.dim = dim
        self.head_dim = 32
        self.window_size = 8

        self.in_nc = in_nc
        self.out_nc = self.in_nc
        self.scale = 1
        self.supports_fp16 = True

        # Stochastic-depth rate for every ConvTransBlock, in network order.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
        # Start offset of each stage's slice of `dpr`.
        starts = [sum(config[:stage]) for stage in range(7)]

        def rates(stage):
            return dpr[starts[stage] : starts[stage] + config[stage]]

        self.m_head = nn.Sequential(nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False))

        # Encoder: blocks first, then a stride-2 conv that doubles channels.
        self.m_down1 = nn.Sequential(
            *self._make_stage(dim // 2, rates(0), input_resolution),
            nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False),
        )
        self.m_down2 = nn.Sequential(
            *self._make_stage(dim, rates(1), input_resolution // 2),
            nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False),
        )
        self.m_down3 = nn.Sequential(
            *self._make_stage(2 * dim, rates(2), input_resolution // 4),
            nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False),
        )

        self.m_body = nn.Sequential(
            *self._make_stage(4 * dim, rates(3), input_resolution // 8)
        )

        # Decoder: a stride-2 transposed conv that halves channels, then blocks.
        self.m_up3 = nn.Sequential(
            nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False),
            *self._make_stage(2 * dim, rates(4), input_resolution // 4),
        )
        self.m_up2 = nn.Sequential(
            nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False),
            *self._make_stage(dim, rates(5), input_resolution // 2),
        )
        self.m_up1 = nn.Sequential(
            nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False),
            *self._make_stage(dim // 2, rates(6), input_resolution),
        )

        self.m_tail = nn.Sequential(nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False))

        # self.apply(self._init_weights)
        self.load_state_dict(state_dict, strict=True)

    def _make_stage(self, half_dim, drop_rates, resolution):
        """Build one stage: a list of ConvTransBlocks alternating W / SW windows.

        ``half_dim`` is used for both the conv and the transformer branch, so
        each block processes ``2 * half_dim`` channels. One block is created
        per entry of ``drop_rates``.
        """
        return [
            ConvTransBlock(
                half_dim,
                half_dim,
                self.head_dim,
                self.window_size,
                rate,
                "W" if not i % 2 else "SW",
                resolution,
            )
            for i, rate in enumerate(drop_rates)
        ]

    def check_image_size(self, x):
        """Pad the bottom/right of ``x`` so height and width are multiples of 64.

        Reflection padding is used. A single reflect pad must be smaller than
        the dimension it extends, so inputs smaller than the required padding
        are padded in several steps instead of raising. A dimension of size 1
        cannot be reflected at all and is extended by edge replication.

        Args:
            x: tensor laid out as ``[b, c, h, w]``.

        Returns:
            The padded tensor; ``x`` itself when no padding is needed.
        """
        _, _, h, w = x.size()
        mod_pad_h = (64 - h % 64) % 64
        mod_pad_w = (64 - w % 64) % 64
        while mod_pad_h > 0 or mod_pad_w > 0:
            step_h = min(mod_pad_h, max(x.size(-2) - 1, 0))
            step_w = min(mod_pad_w, max(x.size(-1) - 1, 0))
            if step_h == 0 and step_w == 0:
                # Remaining axes have a single pixel: nothing to reflect.
                x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "replicate")
                break
            x = F.pad(x, (0, step_w, 0, step_h), "reflect")
            mod_pad_h -= step_h
            mod_pad_w -= step_w
        return x

    def forward(self, x0):
        """Restore an image batch.

        Args:
            x0: tensor laid out as ``[b, in_nc, h, w]``.

        Returns:
            Tensor of shape ``[b, in_nc, h, w]``; the padding added internally
            is cropped away again.
        """
        h, w = x0.size()[-2:]
        x0 = self.check_image_size(x0)

        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        # Additive skip connections from the encoder.
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)

        return x[:, :, :h, :w]

    def _init_weights(self, m):
        """Initialise Linear and LayerNorm layers (currently not applied)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)