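# NOTE: this file was captured from a web page; a page status banner
# ("Spaces: Sleeping") preceded the code here and is not part of it.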
from typing import Tuple, List | |
from torch import Tensor | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops.layers.torch import Rearrange | |
######################
# Meta Architecture
######################
class SeemoRe(nn.Module):
    """SeemoRe image super-resolution network.

    Pipeline: mean-subtract and scale the input, a 3x3 shallow-feature conv,
    ``num_layers`` :class:`ResGroup` stages, a channels-first LayerNorm, a
    3x3 conv with a long residual back to the shallow features, and a
    conv + ``PixelShuffle(scale)`` upsampler. The mean is added back at the end.

    Args:
        scale: Upscaling factor; the output spatial size is ``scale`` times the input.
        in_chans: Number of input (and output) image channels.
        num_experts: Experts per MoE layer.
        num_layers: Number of :class:`ResGroup` stages.
        embedding_dim: Feature width of the body.
        img_range: Multiplier applied after mean subtraction (undone at the end).
        use_shuffle: Channel-shuffle inside each MoE block.
        global_kernel_size: Kernel size of the striped convs in the global blocks.
        recursive: Number of stride-4 aggregation steps in MoE calibration.
        lr_space: Expert low-rank width rule: ``"linear"``, ``"exp"`` or ``"double"``.
            Previously defaulted to ``1``, which MoEBlock rejects; ``"linear"``
            matches MoEBlock's own default.
        topk: Experts applied per sample.
    """

    def __init__(self,
                 scale: int = 4,
                 in_chans: int = 3,
                 num_experts: int = 6,
                 num_layers: int = 6,
                 embedding_dim: int = 64,
                 img_range: float = 1.0,
                 use_shuffle: bool = False,
                 global_kernel_size: int = 11,
                 recursive: int = 2,
                 lr_space: str = "linear",
                 topk: int = 2,):
        super().__init__()
        self.scale = scale
        self.num_in_channels = in_chans
        self.num_out_channels = in_chans
        self.img_range = img_range
        # Fixed RGB mean for 3-channel input; other channel counts get a zero
        # mean so the subtraction keeps the channel count intact.
        if in_chans == 3:
            mean = torch.tensor((0.4488, 0.4371, 0.4040))
        else:
            mean = torch.zeros(in_chans)
        # Non-persistent buffer: follows .to()/.cuda() but stays out of the
        # state_dict, so existing checkpoints load unchanged.
        self.register_buffer("mean", mean.view(1, in_chans, 1, 1), persistent=False)

        # -- SHALLOW FEATURES --
        self.conv_1 = nn.Conv2d(self.num_in_channels, embedding_dim, kernel_size=3, padding=1)

        # -- DEEP FEATURES --
        self.body = nn.ModuleList(
            [ResGroup(in_ch=embedding_dim,
                      num_experts=num_experts,
                      use_shuffle=use_shuffle,
                      topk=topk,
                      lr_space=lr_space,
                      recursive=recursive,
                      global_kernel_size=global_kernel_size) for _ in range(num_layers)]
        )

        # -- UPSCALE --
        self.norm = LayerNorm(embedding_dim, data_format='channels_first')
        self.conv_2 = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=3, padding=1)
        self.upsampler = nn.Sequential(
            nn.Conv2d(embedding_dim, (scale**2) * self.num_out_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(scale)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upscale ``x`` (assumed ``(B, in_chans, H, W)``) by ``self.scale``."""
        # Local copy matched to x's dtype/device; the module state is not mutated.
        mean = self.mean.type_as(x)
        x = (x - mean) * self.img_range

        # -- SHALLOW FEATURES --
        x = self.conv_1(x)
        res = x

        # -- DEEP FEATURES --
        for layer in self.body:
            x = layer(x)
        x = self.norm(x)

        # -- HR IMAGE RECONSTRUCTION --
        x = self.conv_2(x) + res
        x = self.upsampler(x)
        return x / self.img_range + mean
#############################
# Components
#############################
class ResGroup(nn.Module):
    """One body stage: a local :class:`RME` block followed by a global :class:`SME` block.

    Args:
        in_ch: Channel width (unchanged by the stage).
        num_experts: Experts in the local MoE block.
        global_kernel_size: Striped-conv kernel size in the global block.
        lr_space: Expert low-rank width rule forwarded to the MoE block
            (``"linear"``, ``"exp"`` or ``"double"``). The former default ``1``
            was rejected by MoEBlock; ``"linear"`` matches MoEBlock's default.
        topk: Experts applied per sample.
        recursive: Stride-4 aggregation steps in MoE calibration.
        use_shuffle: Channel-shuffle inside the MoE block.
    """

    def __init__(self,
                 in_ch: int,
                 num_experts: int,
                 global_kernel_size: int = 11,
                 lr_space: str = "linear",
                 topk: int = 2,
                 recursive: int = 2,
                 use_shuffle: bool = False):
        super().__init__()
        self.local_block = RME(in_ch=in_ch,
                               num_experts=num_experts,
                               use_shuffle=use_shuffle,
                               lr_space=lr_space,
                               topk=topk,
                               recursive=recursive)
        self.global_block = SME(in_ch=in_ch,
                                kernel_size=global_kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the local block, then the global block."""
        x = self.local_block(x)
        x = self.global_block(x)
        return x
#############################
# Global Block
#############################
class SME(nn.Module):
    """Global block: two pre-norm residual sub-layers.

    The first sub-layer is a :class:`StripedConvFormer` token mixer; the second
    is a :class:`GatedFFN` (``mlp_ratio=2``, 3x3 gate, GELU). Both are preceded
    by a channels-first :class:`LayerNorm`.

    Args:
        in_ch: Channel width (unchanged by the block).
        kernel_size: Striped-conv kernel size used by the mixer.
    """

    def __init__(self,
                 in_ch: int,
                 kernel_size: int = 11):
        super().__init__()
        # Creation order is kept so parameter initialisation is unchanged.
        self.norm_1 = LayerNorm(in_ch, data_format='channels_first')
        self.block = StripedConvFormer(in_ch=in_ch, kernel_size=kernel_size)
        self.norm_2 = LayerNorm(in_ch, data_format='channels_first')
        self.ffn = GatedFFN(in_ch, mlp_ratio=2, kernel_size=3, act_layer=nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mixed = x + self.block(self.norm_1(x))
        return mixed + self.ffn(self.norm_2(mixed))
class StripedConvFormer(nn.Module):
    """Convolutional "attention" mixer.

    A 1x1 conv + GELU produces ``2 * in_ch`` channels, which are split into a
    query half and a value half. The query is passed through a depthwise
    :class:`StripedConv2d` (a ``1 x k`` conv followed by a ``k x 1`` conv),
    multiplied elementwise with the value, and projected by a 1x1 conv.

    Args:
        in_ch: Channel width (unchanged by the module).
        kernel_size: Length ``k`` of the striped kernels.
    """

    def __init__(self,
                 in_ch: int,
                 kernel_size: int):
        super().__init__()
        self.in_ch = in_ch
        self.kernel_size = kernel_size
        # Kept as a public attribute; forward does not read it.
        self.padding = kernel_size // 2
        # Creation order (proj, to_qv, attn) is kept for identical initialisation.
        self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)
        self.to_qv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch * 2, kernel_size=1, padding=0),
            nn.GELU(),
        )
        self.attn = StripedConv2d(in_ch, kernel_size=kernel_size, depthwise=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        query, value = self.to_qv(x).chunk(2, dim=1)
        return self.proj(self.attn(query) * value)
#############################
# Local Blocks
#############################
class RME(nn.Module):
    """Local block: a pre-norm residual :class:`MoEBlock` followed by a pre-norm residual :class:`GatedFFN`.

    Args:
        in_ch: Channel width (unchanged by the block).
        num_experts: Experts in the MoE layer.
        topk: Experts applied per sample.
        lr_space: Expert low-rank width rule (``"linear"``, ``"exp"`` or
            ``"double"``). The former default ``1`` was rejected by MoEBlock;
            ``"linear"`` matches MoEBlock's own default.
        recursive: Stride-4 aggregation steps in MoE calibration.
        use_shuffle: Channel-shuffle inside the MoE block.
    """

    def __init__(self,
                 in_ch: int,
                 num_experts: int,
                 topk: int,
                 lr_space: str = "linear",
                 recursive: int = 2,
                 use_shuffle: bool = False,):
        super().__init__()
        self.norm_1 = LayerNorm(in_ch, data_format='channels_first')
        self.block = MoEBlock(in_ch=in_ch, num_experts=num_experts, topk=topk, use_shuffle=use_shuffle, recursive=recursive, lr_space=lr_space,)
        self.norm_2 = LayerNorm(in_ch, data_format='channels_first')
        self.ffn = GatedFFN(in_ch, mlp_ratio=2, kernel_size=3, act_layer=nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``x + block(norm_1(x))``, then ``x + ffn(norm_2(x))``."""
        x = self.block(self.norm_1(x)) + x
        x = self.ffn(self.norm_2(x)) + x
        return x
#################
# MoE Layer
#################
class MoEBlock(nn.Module):
    """Local mixture-of-experts mixing block.

    ``conv_1`` maps the input to ``2 * in_ch`` channels (optionally
    channel-shuffled with 2 groups). These are split into a feature half and a
    context half. The feature half goes through a depthwise striped conv +
    GELU, and the context half is refined by :meth:`calibrate`. Both feed a
    top-k :class:`MoELayer` whose experts differ only in their low-rank width,
    and the result is projected by a 1x1 conv.

    Args:
        in_ch: Channel width (unchanged by the block).
        num_experts: Number of experts.
        topk: Experts applied per sample.
        use_shuffle: Channel-shuffle the ``conv_1`` output before splitting.
        lr_space: Low-rank width of expert ``i``: ``"linear"`` -> ``i + 2``,
            ``"exp"`` -> ``2 ** (i + 1)``, ``"double"`` -> ``2 * i + 2``.
            Any other value raises ``NotImplementedError``.
        recursive: How many times the stride-4 ``agg_conv`` is applied in
            :meth:`calibrate`. The input's H and W must be at least
            ``4 ** recursive``, otherwise the unpadded 4x4 conv fails.
    """

    def __init__(self,
                 in_ch: int,
                 num_experts: int,
                 topk: int,
                 use_shuffle: bool = False,
                 lr_space: str = "linear",
                 recursive: int = 2):
        super().__init__()
        self.use_shuffle = use_shuffle
        self.recursive = recursive

        # NOTE: submodules are created in the original order so that seeded
        # initialisation and state_dict keys are unchanged.
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(in_ch, 2 * in_ch, kernel_size=1, padding=0),
        )
        # Depthwise 4x4/stride-4 downsampler, reused `recursive` times (shared weights).
        self.agg_conv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=4, stride=4, groups=in_ch),
            nn.GELU(),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1, groups=in_ch),
            nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0),
        )
        self.conv_2 = nn.Sequential(
            StripedConv2d(in_ch, kernel_size=3, depthwise=True),
            nn.GELU(),
        )

        if lr_space == "linear":
            low_dims = [i + 2 for i in range(num_experts)]
        elif lr_space == "exp":
            low_dims = [2 ** (i + 1) for i in range(num_experts)]
        elif lr_space == "double":
            low_dims = [2 * i + 2 for i in range(num_experts)]
        else:
            raise NotImplementedError(f"lr_space {lr_space} not implemented")

        expert_list = [Expert(in_ch=in_ch, low_dim=dim) for dim in low_dims]
        self.moe_layer = MoELayer(
            experts=expert_list,
            gate=Router(in_ch=in_ch, num_experts=num_experts),
            num_expert=topk,
        )
        self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)

    def calibrate(self, x: torch.Tensor) -> torch.Tensor:
        """Add a coarse, upsampled context signal to ``x``.

        ``x`` is downsampled ``self.recursive`` times by ``agg_conv`` and
        refined by ``conv``. It is then resized back to x's spatial size with
        bilinear interpolation and added to ``x``.
        """
        _, _, height, width = x.shape
        coarse = x
        for _ in range(self.recursive):
            coarse = self.agg_conv(coarse)
        coarse = self.conv(coarse)
        coarse = F.interpolate(coarse, size=(height, width), mode="bilinear", align_corners=False)
        return x + coarse

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.conv_1(x)
        if self.use_shuffle:
            feats = channel_shuffle(feats, groups=2)
        value, context = torch.chunk(feats, chunks=2, dim=1)
        value = self.conv_2(value)
        context = self.calibrate(context)
        return self.proj(self.moe_layer(value, context))
class MoELayer(nn.Module):
    """Mixture-of-experts layer with per-sample top-k routing and a residual.

    The gate produces one logit per expert for every sample. The logits are
    softmax-normalised, and the ``num_expert`` largest weights select the
    experts applied to that sample. The result is
    ``inputs + sum_j w_j * expert_j(inputs, k)`` over the selected experts,
    where ``w_j`` are the softmax weights (not re-normalised over the top-k).

    Training runs every expert densely with the non-selected weights zeroed.
    Eval runs each expert only on the samples that selected it. Both paths
    give the same result, and eval now works for any batch size (it
    previously assumed batch size 1).

    Args:
        experts: Expert modules, each called as ``expert(inputs, k)``.
        gate: Module mapping ``inputs`` to per-expert logits; expected shape
            ``(batch, len(experts))``.
        num_expert: Number of experts applied per sample (the top-k).

    Raises:
        ValueError: If ``experts`` is empty.
    """

    def __init__(self, experts: List[nn.Module], gate: nn.Module, num_expert: int = 1):
        super().__init__()
        if len(experts) == 0:
            raise ValueError("MoELayer requires at least one expert")
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.num_expert = num_expert

    def forward(self, inputs: torch.Tensor, k: torch.Tensor):
        """Route each sample of ``inputs`` (4-D, ``(B, C, H, W)``) to its top-k experts."""
        logits = self.gate(inputs)
        weights = F.softmax(logits, dim=1, dtype=torch.float).to(inputs.dtype)
        topk_weights, topk_experts = torch.topk(weights, self.num_expert)
        out = inputs.clone()
        if self.training:
            # Dense path: every expert sees the whole batch; weights outside each
            # sample's top-k are zero, so those experts contribute nothing.
            dense_weights = torch.zeros_like(weights)
            dense_weights.scatter_(1, topk_experts, topk_weights)
            for i, expert in enumerate(self.experts):
                out += expert(inputs, k) * dense_weights[:, i:i + 1, None, None]
        else:
            # Sparse path: run each selected expert only on the samples that
            # chose it. topk never repeats an expert within one sample, so each
            # (expert, sample) pair appears at most once.
            for expert_idx in topk_experts.unique().tolist():
                sample_idx, slot = (topk_experts == expert_idx).nonzero(as_tuple=True)
                w = topk_weights[sample_idx, slot].view(-1, 1, 1, 1)
                contrib = self.experts[expert_idx](inputs[sample_idx], k[sample_idx]) * w
                out.index_add_(0, sample_idx, contrib.to(out.dtype))
        return out
class Expert(nn.Module):
    """Low-rank multiplicative expert.

    ``x`` and ``k`` are each projected to ``low_dim`` channels by separate 1x1
    convs. The projections are multiplied elementwise (no sigmoid gating), and
    the product is projected back to ``in_ch`` channels by a third 1x1 conv.

    Args:
        in_ch: Channels of ``x``, ``k`` and the output.
        low_dim: Width of the low-rank bottleneck.
    """

    def __init__(self,
                 in_ch: int,
                 low_dim: int,):
        super().__init__()
        self.conv_1 = nn.Conv2d(in_ch, low_dim, kernel_size=1, padding=0)
        self.conv_2 = nn.Conv2d(in_ch, low_dim, kernel_size=1, padding=0)
        self.conv_3 = nn.Conv2d(low_dim, in_ch, kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        feat_low = self.conv_1(x)
        ctx_low = self.conv_2(k)
        return self.conv_3(ctx_low * feat_low)
class Router(nn.Module):
    """Expert gate: global average pool -> flatten -> bias-free linear layer.

    Maps a ``(B, in_ch, H, W)`` tensor to ``(B, num_experts)`` routing logits.
    ``nn.Flatten(1)`` replaces the previous einops ``Rearrange('b c 1 1 -> b c')``.
    The two are equivalent after ``AdaptiveAvgPool2d(1)``, neither has
    parameters, and the state_dict keys (``body.2.weight``) are unchanged.

    Args:
        in_ch: Input channels.
        num_experts: Number of logits produced per sample.
    """

    def __init__(self,
                 in_ch: int,
                 num_experts: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(1),
            nn.Linear(in_ch, num_experts, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)
#################
# Utilities
#################
class StripedConv2d(nn.Module):
    """Separable "striped" convolution: a ``1 x k`` conv followed by a ``k x 1`` conv.

    Both convs keep the channel count and use ``k // 2`` padding along their
    long axis, so odd ``k`` preserves the spatial size.

    Args:
        in_ch: Input/output channels.
        kernel_size: Stripe length ``k``.
        depthwise: If True, both convs use ``groups=in_ch``.
    """

    def __init__(self,
                 in_ch: int,
                 kernel_size: int,
                 depthwise: bool = False):
        super().__init__()
        self.in_ch = in_ch
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        groups = in_ch if depthwise else 1
        horizontal = nn.Conv2d(in_ch, in_ch, kernel_size=(1, kernel_size),
                               padding=(0, self.padding), groups=groups)
        vertical = nn.Conv2d(in_ch, in_ch, kernel_size=(kernel_size, 1),
                             padding=(self.padding, 0), groups=groups)
        self.conv = nn.Sequential(horizontal, vertical)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)
def channel_shuffle(x, groups=2):
    """Interleave channels across ``groups`` (ShuffleNet-style shuffle).

    Channels are viewed as a ``(groups, C // groups)`` grid, transposed, and
    flattened back. For example, 8 channels with ``groups=2`` become the order
    ``[0, 4, 1, 5, 2, 6, 3, 7]``. ``x`` must be 4-D and contiguous, with C
    divisible by ``groups``. Returns a contiguous tensor of the same shape.
    """
    batch, channels, height, width = x.shape
    grid = x.view(batch, groups, channels // groups, height, width)
    swapped = torch.transpose(grid, 1, 2).contiguous()
    return swapped.flatten(1, 2)
class GatedFFN(nn.Module):
    """Gated feed-forward network on channels-first feature maps.

    A 1x1 conv (followed by ``act_layer``) expands ``in_ch`` to
    ``in_ch * mlp_ratio`` channels, which are split in half. One half passes
    through a depthwise ``kernel_size`` conv (``self.gate``) and multiplies the
    other half elementwise. A final 1x1 conv (followed by ``act_layer``)
    projects the ``in_ch * mlp_ratio / 2`` gated channels back to ``in_ch``.

    Args:
        in_ch: Input and output channels.
        mlp_ratio: Hidden expansion factor; ``in_ch * mlp_ratio`` must be even.
        kernel_size: Kernel size of the depthwise gating conv.
        act_layer: Activation module *instance*; the same instance is used
            after both 1x1 convs.

    Raises:
        ValueError: If ``in_ch * mlp_ratio`` is odd (it cannot be split in half).
    """

    def __init__(self,
                 in_ch,
                 mlp_ratio,
                 kernel_size,
                 act_layer,):
        super().__init__()
        mlp_ch = in_ch * mlp_ratio
        if mlp_ch % 2:
            raise ValueError(
                f"GatedFFN needs an even hidden width, got in_ch * mlp_ratio = {mlp_ch}")
        hidden_ch = mlp_ch // 2
        self.fn_1 = nn.Sequential(
            nn.Conv2d(in_ch, mlp_ch, kernel_size=1, padding=0),
            act_layer,
        )
        # The gated half has hidden_ch channels; this equals in_ch only when
        # mlp_ratio == 2, so the input width must be hidden_ch, not in_ch.
        self.fn_2 = nn.Sequential(
            nn.Conv2d(hidden_ch, in_ch, kernel_size=1, padding=0),
            act_layer,
        )
        self.gate = nn.Conv2d(hidden_ch, hidden_ch,
                              kernel_size=kernel_size, padding=kernel_size // 2, groups=hidden_ch)

    def feat_decompose(self, x):
        """Return ``x + sigma * (x - gate(x))``.

        ``forward`` does not use this method. It needs a ``sigma`` attribute
        that this class never defines, so that attribute must be attached
        externally before calling.
        """
        if not hasattr(self, "sigma"):
            raise AttributeError(
                "GatedFFN.feat_decompose requires a 'sigma' attribute, "
                "which GatedFFN does not define")
        s = x - self.gate(x)
        return x + self.sigma * s

    def forward(self, x: torch.Tensor):
        x = self.fn_1(x)
        x, gate = torch.chunk(x, 2, dim=1)
        gate = self.gate(gate)
        x = x * gate
        x = self.fn_2(x)
        return x
class LayerNorm(nn.Module):
    r"""Layer normalisation over the channel dimension, for two tensor layouts.

    ``data_format`` gives the order of the input's dimensions:

    * ``"channels_last"`` (default): inputs shaped ``(batch, height, width,
      channels)``. Normalisation is delegated to ``F.layer_norm``.
    * ``"channels_first"``: inputs shaped ``(batch, channels, height, width)``.
      The mean and (biased) variance over dim 1 are computed explicitly.

    Args:
        normalized_shape: Number of channels (size of the learnable weight/bias).
        eps: Constant added to the variance for numerical stability.
        data_format: ``"channels_last"`` or ``"channels_first"``; any other
            value raises ``NotImplementedError``.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            mu = x.mean(1, keepdim=True)
            centred = x - mu
            var = centred.pow(2).mean(1, keepdim=True)
            normed = centred / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]