Spaces:
Running
on
Zero
Running
on
Zero
'''
Code is from:
https://github.com/jaxony/unet-pytorch/blob/master/model.py
'''
import torch | |
import torch.nn as nn | |
from diffusers import UNet2DModel | |
import einops | |
class UNetPP(nn.Module):
    '''
    Wrapper for UNet in diffusers

    Builds a fixed seven-stage ``UNet2DModel`` with 32 output channels and,
    when ``in_channels`` exceeds 12, a learned zero-initialised plane of
    ``in_channels - 12`` channels that is concatenated to inputs carrying
    fewer than ``in_channels`` channels.
    '''

    # Channel count subtracted from ``in_channels`` to size the learned plane.
    _BASE_CHANNELS = 12
    # Spatial size handed to the UNet and used for the learned plane.
    _HEIGHT = 256
    _WIDTH = 256 * 3

    def __init__(self, in_channels):
        '''
        Args:
            in_channels: number of channels the wrapped UNet consumes.
        '''
        super().__init__()
        self.in_channels = in_channels
        self.unet = UNet2DModel(
            sample_size=[self._HEIGHT, self._WIDTH],
            in_channels=in_channels,
            out_channels=32,
            layers_per_block=2,
            block_out_channels=(64, 128, 128, 128*2, 128*2, 128*4, 128*4),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )
        self.unet.enable_xformers_memory_efficient_attention()

        extra_channels = in_channels - self._BASE_CHANNELS
        if extra_channels > 0:
            self.learned_plane = nn.Parameter(
                torch.zeros([1, extra_channels, self._HEIGHT, self._WIDTH])
            )
        else:
            # Always define the attribute so forward() never hits an
            # AttributeError; a None parameter adds nothing to state_dict.
            self.register_parameter("learned_plane", None)

    def forward(self, x, t=256):
        '''
        Run the wrapped UNet and return its ``.sample`` output.

        Args:
            x: input tensor; indexed as (batch, channels, ...) here.
               NOTE(review): the padding path presumably expects 12-channel
               inputs at 256x768 so the concatenation reaches
               ``in_channels`` -- confirm against callers.
            t: timestep value forwarded unchanged to the UNet.

        Returns:
            The ``sample`` attribute of the UNet output.
        '''
        # getattr keeps instances that never had the attribute set working.
        plane = getattr(self, "learned_plane", None)
        if plane is not None and x.shape[1] < self.in_channels:
            # Broadcast the single learned plane over the batch; torch.cat
            # materialises the result, so a view is enough here.
            plane = plane.expand(x.shape[0], -1, -1, -1).to(x.device)
            x = torch.cat([x, plane], dim=1)
        return self.unet(x, t).sample