from typing import Tuple, List
from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # Kept importable for backward compatibility; Router no longer needs it.
    from einops.layers.torch import Rearrange  # noqa: F401
except ImportError:  # einops is optional now
    Rearrange = None


######################
# Meta Architecture
######################


class SeemoRe(nn.Module):
    """Image super-resolution network built from stacked residual groups.

    Pipeline:
        1. Subtract a per-channel mean from the input and multiply by ``img_range``.
        2. Apply a 3x3 shallow convolution.
        3. Run ``num_layers`` :class:`ResGroup` blocks, then a channel LayerNorm.
        4. Apply a 3x3 conv plus a long skip connection from the shallow features.
        5. Upsample with a conv + ``PixelShuffle(scale)`` upsampler.
        6. Undo the mean shift and scaling on the output.

    Note: :meth:`MoEBlock.calibrate` downsamples by 4, ``recursive`` times, with
    unpadded stride-4 convs. Input height and width must therefore be at least
    ``4 ** recursive``.

    Args:
        scale: Upscaling factor of the pixel-shuffle upsampler.
        in_chans: Number of input (and output) image channels.
        num_experts: Number of experts per MoE layer.
        num_layers: Number of :class:`ResGroup` blocks.
        embedding_dim: Feature width of the network body.
        img_range: Multiplier applied to the mean-shifted input; divided out at the end.
        use_shuffle: Channel-shuffle the MoE block's projection before splitting it.
        global_kernel_size: Kernel size of the striped depthwise conv in :class:`SME`.
        recursive: Number of downsample steps in :meth:`MoEBlock.calibrate`.
        lr_space: Expert low-rank growth schedule: ``"linear"``, ``"exp"`` or ``"double"``.
        topk: Number of experts selected per sample by the router.
    """

    def __init__(self,
                 scale: int = 4,
                 in_chans: int = 3,
                 num_experts: int = 6,
                 num_layers: int = 6,
                 embedding_dim: int = 64,
                 img_range: float = 1.0,
                 use_shuffle: bool = False,
                 global_kernel_size: int = 11,
                 recursive: int = 2,
                 lr_space: str = "linear",
                 topk: int = 2,):
        super().__init__()
        self.scale = scale
        self.num_in_channels = in_chans
        self.num_out_channels = in_chans
        self.img_range = img_range

        # The fixed RGB mean only applies to 3-channel input. Any other channel
        # count uses a zero mean; previously a (1, 3, 1, 1) mean was broadcast
        # onto it and broke conv_1.
        if in_chans == 3:
            mean = torch.tensor((0.4488, 0.4371, 0.4040), dtype=torch.float32).view(1, 3, 1, 1)
        else:
            mean = torch.zeros(1, 1, 1, 1)
        # Non-persistent buffer: moves with .to()/.cuda() but is not written
        # to state_dict, so existing checkpoints load unchanged.
        self.register_buffer("mean", mean, persistent=False)

        # -- SHALLOW FEATURES --
        self.conv_1 = nn.Conv2d(self.num_in_channels, embedding_dim, kernel_size=3, padding=1)

        # -- DEEP FEATURES --
        self.body = nn.ModuleList(
            [ResGroup(in_ch=embedding_dim,
                      num_experts=num_experts,
                      use_shuffle=use_shuffle,
                      topk=topk,
                      lr_space=lr_space,
                      recursive=recursive,
                      global_kernel_size=global_kernel_size)
             for _ in range(num_layers)]
        )

        # -- UPSCALE --
        self.norm = LayerNorm(embedding_dim, data_format='channels_first')
        self.conv_2 = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=3, padding=1)
        self.upsampler = nn.Sequential(
            nn.Conv2d(embedding_dim, (scale**2) * self.num_out_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(scale)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upscale ``x`` of shape (B, in_chans, H, W) to (B, in_chans, H*scale, W*scale)."""
        # Local cast only; the module's buffer is not mutated during forward.
        mean = self.mean.type_as(x)
        x = (x - mean) * self.img_range

        # -- SHALLOW FEATURES --
        x = self.conv_1(x)
        res = x

        # -- DEEP FEATURES --
        for layer in self.body:
            x = layer(x)
        x = self.norm(x)

        # -- HR IMAGE RECONSTRUCTION --
        x = self.conv_2(x) + res
        x = self.upsampler(x)

        x = x / self.img_range + mean
        return x


#############################
# Components
#############################


class ResGroup(nn.Module):
    """One body stage: a local MoE block (:class:`RME`) followed by a global block (:class:`SME`)."""

    def __init__(self,
                 in_ch: int,
                 num_experts: int,
                 global_kernel_size: int = 11,
                 lr_space: str = "linear",
                 topk: int = 2,
                 recursive: int = 2,
                 use_shuffle: bool = False):
        super().__init__()

        self.local_block = RME(in_ch=in_ch,
                               num_experts=num_experts,
                               use_shuffle=use_shuffle,
                               lr_space=lr_space,
                               topk=topk,
                               recursive=recursive)
        self.global_block = SME(in_ch=in_ch,
                                kernel_size=global_kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.local_block(x)
        x = self.global_block(x)
        return x


#############################
# Global Block
#############################


class SME(nn.Module):
    """Pre-norm residual block: striped-conv mixer, then a gated FFN.

    Computes ``x = x + block(norm_1(x))`` followed by ``x = x + ffn(norm_2(x))``.
    """

    def __init__(self, in_ch: int, kernel_size: int = 11):
        super().__init__()

        self.norm_1 = LayerNorm(in_ch, data_format='channels_first')
        self.block = StripedConvFormer(in_ch=in_ch, kernel_size=kernel_size)

        self.norm_2 = LayerNorm(in_ch, data_format='channels_first')
        self.ffn = GatedFFN(in_ch, mlp_ratio=2, kernel_size=3, act_layer=nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block(self.norm_1(x)) + x
        x = self.ffn(self.norm_2(x)) + x
        return x


class StripedConvFormer(nn.Module):
    """Gated large-kernel mixer.

    A pointwise conv + GELU produces two halves ``q`` and ``v``. ``q`` passes
    through a depthwise striped (1xk then kx1) convolution and multiplies ``v``
    elementwise. The product is projected with a 1x1 conv.
    """

    def __init__(self, in_ch: int, kernel_size: int):
        super().__init__()
        self.in_ch = in_ch
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2

        self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)
        self.to_qv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch * 2, kernel_size=1, padding=0),
            nn.GELU(),
        )

        self.attn = StripedConv2d(in_ch, kernel_size=kernel_size, depthwise=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, v = self.to_qv(x).chunk(2, dim=1)
        q = self.attn(q)
        x = self.proj(q * v)
        return x


#############################
# Local Blocks
#############################


class RME(nn.Module):
    """Pre-norm residual block: MoE mixer (:class:`MoEBlock`), then a gated FFN.

    Computes ``x = x + block(norm_1(x))`` followed by ``x = x + ffn(norm_2(x))``.
    """

    def __init__(self,
                 in_ch: int,
                 num_experts: int,
                 topk: int,
                 lr_space: str = "linear",
                 recursive: int = 2,
                 use_shuffle: bool = False,):
        super().__init__()

        self.norm_1 = LayerNorm(in_ch, data_format='channels_first')
        self.block = MoEBlock(in_ch=in_ch, num_experts=num_experts, topk=topk, use_shuffle=use_shuffle, recursive=recursive, lr_space=lr_space,)

        self.norm_2 = LayerNorm(in_ch, data_format='channels_first')
        self.ffn = GatedFFN(in_ch, mlp_ratio=2, kernel_size=3, act_layer=nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block(self.norm_1(x)) + x
        x = self.ffn(self.norm_2(x)) + x
        return x


#################
# MoE Layer
#################


class MoEBlock(nn.Module):
    """Mixture-of-low-rank-experts mixer.

    The block works in four steps:
        1. ``conv_1`` expands the input to ``2 * in_ch`` channels (optionally
           channel-shuffled) and splits it into a feature half ``x`` and a
           context half ``k``.
        2. ``x`` is refined by a depthwise striped 3x3 conv + GELU.
        3. ``k`` is enriched by :meth:`calibrate`.
        4. Both go through the routed :class:`MoELayer`, and the result is
           projected with a 1x1 conv.

    Args:
        lr_space: Growth schedule of expert ``i``'s low-rank width. ``"linear"`` gives
            ``i + 2``, ``"exp"`` gives ``2 ** (i + 1)`` and ``"double"`` gives ``2 * i + 2``.
            Any other value raises ``NotImplementedError``.
    """

    def __init__(self,
                 in_ch: int,
                 num_experts: int,
                 topk: int,
                 use_shuffle: bool = False,
                 lr_space: str = "linear",
                 recursive: int = 2):
        super().__init__()
        self.use_shuffle = use_shuffle
        self.recursive = recursive

        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(in_ch, 2*in_ch, kernel_size=1, padding=0)
        )

        # Depthwise, unpadded 4x4 stride-4 conv: each application shrinks H and W by 4x.
        self.agg_conv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=4, stride=4, groups=in_ch),
            nn.GELU())

        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1, groups=in_ch),
            nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)
        )

        self.conv_2 = nn.Sequential(
            StripedConv2d(in_ch, kernel_size=3, depthwise=True),
            nn.GELU())

        if lr_space == "linear":
            grow_func = lambda i: i+2
        elif lr_space == "exp":
            grow_func = lambda i: 2**(i+1)
        elif lr_space == "double":
            grow_func = lambda i: 2*i+2
        else:
            raise NotImplementedError(f"lr_space {lr_space} not implemented")

        self.moe_layer = MoELayer(
            experts=[Expert(in_ch=in_ch, low_dim=grow_func(i)) for i in range(num_experts)],
            gate=Router(in_ch=in_ch, num_experts=num_experts),
            num_expert=topk,
        )

        self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)

    def calibrate(self, x: torch.Tensor) -> torch.Tensor:
        """Add multi-scale context to ``x``.

        The input is downsampled by 4 and refined ``self.recursive`` times. The
        result is resized bilinearly back to the input size and added to ``x``.
        Input height and width must be at least ``4 ** self.recursive``;
        otherwise the unpadded stride-4 conv has no valid output positions.
        """
        h, w = x.shape[-2:]
        res = x

        for _ in range(self.recursive):
            x = self.agg_conv(x)
            x = self.conv(x)
        x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)
        return res + x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_1(x)

        if self.use_shuffle:
            x = channel_shuffle(x, groups=2)
        x, k = torch.chunk(x, chunks=2, dim=1)

        x = self.conv_2(x)
        k = self.calibrate(k)

        x = self.moe_layer(x, k)
        x = self.proj(x)
        return x


class MoELayer(nn.Module):
    """Top-k routed sum of experts with an identity residual.

    For each sample, the gate's logits are softmaxed over experts (in float32,
    then cast back to the input dtype). The ``num_expert`` largest weights are
    kept and the rest are zeroed; kept weights are *not* renormalized. The layer
    returns ``inputs + sum_i w_i * expert_i(inputs, k)``.

    In training, every expert is evaluated (unselected ones are scaled by zero),
    so all experts run on each forward pass. In eval, only the union of experts
    selected by any sample in the batch is evaluated. This gives the same result
    for any batch size; the previous eval path only worked for batch size 1.
    """

    def __init__(self, experts: List[nn.Module], gate: nn.Module, num_expert: int = 1):
        super().__init__()
        assert len(experts) > 0
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.num_expert = num_expert

    def forward(self, inputs: torch.Tensor, k: torch.Tensor):
        logits = self.gate(inputs)
        weights = F.softmax(logits, dim=1, dtype=torch.float).to(inputs.dtype)

        topk_weights, topk_experts = torch.topk(weights, self.num_expert)

        # (B, num_experts) weights: top-k values at their positions, zero elsewhere.
        exp_weights = torch.zeros_like(weights)
        exp_weights.scatter_(1, topk_experts, topk_weights)

        if self.training:
            active = range(len(self.experts))
        else:
            # Skip experts no sample selected; their contribution would be zero.
            active = torch.unique(topk_experts).tolist()

        out = inputs.clone()
        for i in active:
            out += self.experts[i](inputs, k) * exp_weights[:, i:i+1, None, None]
        return out


class Expert(nn.Module):
    """Low-rank expert.

    ``x`` and ``k`` are each projected to ``low_dim`` channels with 1x1 convs and
    multiplied elementwise, with no sigmoid on the gate. The product is projected
    back to ``in_ch`` channels.
    """

    def __init__(self,
                 in_ch: int,
                 low_dim: int,):
        super().__init__()
        self.conv_1 = nn.Conv2d(in_ch, low_dim, kernel_size=1, padding=0)
        self.conv_2 = nn.Conv2d(in_ch, low_dim, kernel_size=1, padding=0)
        self.conv_3 = nn.Conv2d(low_dim, in_ch, kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        x = self.conv_1(x)
        x = self.conv_2(k) * x
        x = self.conv_3(x)
        return x


class Router(nn.Module):
    """Per-sample expert logits.

    Global average pooling reduces (B, C, H, W) to (B, C), which a bias-free
    linear layer maps to (B, num_experts).
    """

    def __init__(self,
                 in_ch: int,
                 num_experts: int):
        super().__init__()

        # nn.Flatten(1) matches einops 'b c 1 1 -> b c' after a 1x1 pool and
        # keeps the same Sequential indices, so state_dict keys are unchanged.
        self.body = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(1),
            nn.Linear(in_ch, num_experts, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


#################
# Utilities
#################


class StripedConv2d(nn.Module):
    """Separable "striped" convolution: a 1xk conv followed by a kx1 conv.

    Both convs are per-channel when ``depthwise`` is set. Padding is
    ``kernel_size // 2``, so spatial size is preserved for odd ``kernel_size``.
    """

    def __init__(self, in_ch: int, kernel_size: int, depthwise: bool = False):
        super().__init__()
        self.in_ch = in_ch
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        groups = in_ch if depthwise else 1

        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=(1, self.kernel_size), padding=(0, self.padding), groups=groups),
            nn.Conv2d(in_ch, in_ch, kernel_size=(self.kernel_size, 1), padding=(self.padding, 0), groups=groups),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


def channel_shuffle(x, groups=2):
    """Interleave channels across ``groups`` (ShuffleNet-style).

    ``x`` has shape (B, C, H, W), and C must be divisible by ``groups``. For
    example, with 4 channels and 2 groups, the order [0, 1, 2, 3] becomes
    [0, 2, 1, 3].
    """
    batch, channels, height, width = x.shape
    group_c = channels // groups

    x = x.view(batch, groups, group_c, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batch, -1, height, width)
    return x


class GatedFFN(nn.Module):
    """Gated pointwise feed-forward block.

    ``fn_1`` expands the input to ``in_ch * mlp_ratio`` channels, which are split
    in half. The second half goes through a depthwise ``kernel_size`` conv and
    gates the first half multiplicatively. ``fn_2`` then projects back to
    ``in_ch`` channels.

    Args:
        in_ch: Input/output channels.
        mlp_ratio: Expansion ratio of the hidden layer.
        kernel_size: Kernel size of the depthwise gate conv.
        act_layer: Activation module *instance*; the same instance follows both
            pointwise convs.
    """

    def __init__(self,
                 in_ch,
                 mlp_ratio,
                 kernel_size,
                 act_layer,):
        super().__init__()
        mlp_ch = in_ch * mlp_ratio
        hidden_ch = mlp_ch // 2

        self.fn_1 = nn.Sequential(
            nn.Conv2d(in_ch, mlp_ch, kernel_size=1, padding=0),
            act_layer,
        )
        # Input width is the gated half (mlp_ch // 2). It equals in_ch for
        # mlp_ratio=2, as used in this file; any other ratio used to crash here.
        self.fn_2 = nn.Sequential(
            nn.Conv2d(hidden_ch, in_ch, kernel_size=1, padding=0),
            act_layer,
        )

        self.gate = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=kernel_size,
                              padding=kernel_size // 2, groups=hidden_ch)

    def feat_decompose(self, x):
        # NOTE(review): not called anywhere in this file, and `self.sigma` is never
        # defined, so calling this raises AttributeError. Kept for interface
        # compatibility; define `sigma` before using it.
        s = x - self.gate(x)
        x = x + self.sigma * s
        return x

    def forward(self, x: torch.Tensor):
        x = self.fn_1(x)
        x, gate = torch.chunk(x, 2, dim=1)

        gate = self.gate(gate)
        x = x * gate

        x = self.fn_2(x)
        return x


class LayerNorm(nn.Module):
    r"""LayerNorm supporting two data formats: ``channels_last`` (default) or ``channels_first``.

    ``data_format`` gives the ordering of the input dimensions.
    ``channels_last`` expects inputs of shape (batch_size, height, width, channels).
    ``channels_first`` expects inputs of shape (batch_size, channels, height, width)
    and normalizes over dim 1 using the biased variance.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x