MonsterForge-small / LightweightGAN /PixelShuffleUpsample.py
michaelriedl's picture
Initial dump
002ca81
raw
history blame contribute delete
791 Bytes
import torch
from torch import nn
from einops import repeat
from .helper_funcs import default
class PixelShuffleUpsample(nn.Module):
    """Upsample a feature map with a 1x1 conv, SiLU and ``nn.PixelShuffle``.

    The 1x1 convolution expands the channels to ``dim_out * scale_factor**2``.
    After a SiLU, ``nn.PixelShuffle(scale_factor)`` rearranges those channels
    into a spatial grid ``scale_factor`` times larger along height and width.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels; ``None`` means "same as ``dim``".
        scale_factor: Spatial upsampling factor. Keyword-only; the default of
            2 reproduces the previously hard-coded behaviour.

    Raises:
        ValueError: If ``scale_factor`` is not a positive integer.
    """

    def __init__(self, dim, dim_out=None, *, scale_factor=2):
        super().__init__()
        if not isinstance(scale_factor, int) or scale_factor < 1:
            raise ValueError(
                f"scale_factor must be a positive integer, got {scale_factor!r}"
            )
        # Fall back to the input width when no output width is given.
        dim_out = dim if dim_out is None else dim_out
        self.scale_factor = scale_factor

        conv = nn.Conv2d(dim, dim_out * scale_factor ** 2, 1)
        self.net = nn.Sequential(conv, nn.SiLU(), nn.PixelShuffle(scale_factor))
        self.init_conv_(conv)

    def init_conv_(self, conv):
        """Initialise ``conv`` in place with the ICNR scheme.

        PixelShuffle maps input channel ``c * r**2 + k`` onto sub-pixel ``k``
        of output channel ``c`` (``r`` = ``self.scale_factor``). A Kaiming-
        uniform kernel is drawn for each of the ``out_channels // r**2``
        output channels and repeated ``r**2`` times consecutively. The bias is
        zeroed. As a result, at initialisation every ``r x r`` block produced
        by the shuffle is constant, i.e. the layer starts out as
        nearest-neighbour upsampling of the conv output.

        Args:
            conv: The convolution whose weight (and bias, if present) is
                overwritten.

        Raises:
            ValueError: If ``conv``'s output channels are not a multiple of
                ``self.scale_factor ** 2``.
        """
        sub_pixels = self.scale_factor ** 2
        out_channels, in_channels, kernel_h, kernel_w = conv.weight.shape
        if out_channels % sub_pixels != 0:
            raise ValueError(
                f"conv has {out_channels} output channels, which is not a "
                f"multiple of scale_factor**2 = {sub_pixels}"
            )

        # fan_in (in_channels * kh * kw) matches the real conv, so the Kaiming
        # bound is the same as if the full weight had been initialised.
        base = torch.empty(out_channels // sub_pixels, in_channels, kernel_h, kernel_w)
        nn.init.kaiming_uniform_(base)

        with torch.no_grad():
            # repeat_interleave yields [w0]*r^2 + [w1]*r^2 + ..., matching the
            # channel grouping PixelShuffle uses.
            conv.weight.copy_(base.repeat_interleave(sub_pixels, dim=0))
            if conv.bias is not None:
                conv.bias.zero_()

    def forward(self, x):
        """Return ``x`` upsampled by ``scale_factor`` with ``dim_out`` channels."""
        return self.net(x)