File size: 791 Bytes
002ca81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import torch
from torch import nn
from einops import repeat
from .helper_funcs import default
class PixelShuffleUpsample(nn.Module):
def __init__(self, dim, dim_out=None):
super().__init__()
dim_out = default(dim_out, dim)
conv = nn.Conv2d(dim, dim_out * 4, 1)
self.net = nn.Sequential(conv, nn.SiLU(), nn.PixelShuffle(2))
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, h, w = conv.weight.shape
conv_weight = torch.empty(o // 4, i, h, w)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
return self.net(x)
|