|
import torch
|
|
from torch import nn
|
|
from einops import repeat
|
|
from .helper_funcs import default
|
|
|
|
|
|
class PixelShuffleUpsample(nn.Module):
    """Spatially upsample a feature map with a sub-pixel (pixel-shuffle) convolution.

    A 1x1 convolution expands ``dim`` channels to ``dim_out * scale_factor ** 2``
    channels. SiLU is applied next, then ``nn.PixelShuffle(scale_factor)``, which
    rearranges those channels into a spatial grid ``scale_factor`` times larger.
    An input of shape ``(N, dim, H, W)`` therefore becomes
    ``(N, dim_out, H * scale_factor, W * scale_factor)``.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels; falls back to ``dim`` when ``None``.
        scale_factor: Positive integer spatial upsampling factor. It defaults to
            2, the value previously hard-coded.

    Raises:
        ValueError: If ``scale_factor`` is not a positive int.
    """

    def __init__(self, dim, dim_out=None, scale_factor=2):
        super().__init__()

        if not isinstance(scale_factor, int) or scale_factor < 1:
            raise ValueError(
                f"scale_factor must be a positive int, got {scale_factor!r}"
            )

        dim_out = default(dim_out, dim)
        self.scale_factor = scale_factor

        # One conv output channel per (output channel, sub-pixel position) pair.
        conv = nn.Conv2d(dim, dim_out * scale_factor ** 2, 1)

        self.net = nn.Sequential(conv, nn.SiLU(), nn.PixelShuffle(scale_factor))

        self.init_conv_(conv)

    def init_conv_(self, conv):
        """Initialise ``conv`` in place so each pixel-shuffle group shares one kernel.

        One kernel is Kaiming-uniform initialised per final output channel. It is
        then repeated ``scale_factor ** 2`` times consecutively along the
        output-channel axis, and the bias is zeroed. PixelShuffle gathers input
        channels ``[c * r2, (c + 1) * r2)`` into output channel ``c``, so all
        sub-pixels of each ``scale_factor x scale_factor`` patch start out
        identical (ICNR-style initialisation).

        Args:
            conv: The ``nn.Conv2d`` whose output-channel count is a multiple of
                ``scale_factor ** 2``.
        """
        # Fall back to the historical factor of 2 if the attribute is absent
        # (e.g. an instance created without running __init__).
        r2 = getattr(self, "scale_factor", 2) ** 2

        o, i, h, w = conv.weight.shape

        conv_weight = torch.empty(o // r2, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)

        # "(o r)" places r innermost: each kernel is repeated r2 times back to back,
        # matching the channel grouping PixelShuffle uses.
        conv_weight = repeat(conv_weight, "o ... -> (o r) ...", r=r2)

        # Mutate the parameters without recording the ops in autograd
        # (preferred over writing through ``.data``).
        with torch.no_grad():
            conv.weight.copy_(conv_weight)
            if conv.bias is not None:
                conv.bias.zero_()

    def forward(self, x):
        """Apply conv -> SiLU -> PixelShuffle to ``x``."""
        return self.net(x)
|
|
|