import torch
from torch import nn
from einops import repeat

from .helper_funcs import default


class PixelShuffleUpsample(nn.Module):
    """Upsample a feature map by 2x using a 1x1 convolution and pixel shuffle.

    The pipeline is ``Conv2d(dim, dim_out * 4, 1) -> SiLU -> PixelShuffle(2)``:
    the convolution produces four times the requested number of channels and
    ``nn.PixelShuffle(2)`` folds each group of four channels into a 2x2
    spatial neighbourhood.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels; resolved through ``default`` with
            ``dim`` as the fallback when not supplied.
    """

    def __init__(self, dim, dim_out=None):
        super().__init__()
        out_channels = default(dim_out, dim)

        # PixelShuffle(2) consumes 2 * 2 = 4 channels per output channel.
        expand = nn.Conv2d(dim, out_channels * 4, 1)
        stages = (expand, nn.SiLU(), nn.PixelShuffle(2))
        self.net = nn.Sequential(*stages)

        self.init_conv_(expand)

    def init_conv_(self, conv):
        """Re-initialise ``conv`` in place with weights repeated in groups of 4.

        A kaiming-uniform weight is drawn for a quarter of the output
        filters and every filter is then repeated four times consecutively,
        so each run of four output channels starts out with identical
        weights. The bias is zeroed.

        Args:
            conv: Convolution whose ``weight`` is 4-D and whose ``bias`` is
                present; both are overwritten in place.
        """
        n_out, n_in, k_h, k_w = conv.weight.shape

        # Draw weights for one filter per group of four output channels.
        base_weight = torch.empty(n_out // 4, n_in, k_h, k_w)
        nn.init.kaiming_uniform_(base_weight)

        # Duplicate each filter so the 4 channels of a group are adjacent.
        tiled_weight = repeat(base_weight, "o ... -> (o 4) ...")

        conv.weight.data.copy_(tiled_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        """Apply the conv -> SiLU -> pixel-shuffle pipeline to ``x``."""
        return self.net(x)