File size: 1,780 Bytes
2a8a75a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
'''
Code is from:
https://github.com/jaxony/unet-pytorch/blob/master/model.py
'''

import torch
import torch.nn as nn
from diffusers import UNet2DModel
import einops
class UNetPP(nn.Module):
    """Wrapper for the diffusers ``UNet2DModel``.

    The wrapped UNet is built for ``in_channels`` input channels and always
    produces 32 output channels.  When ``in_channels`` is larger than 12 the
    module additionally owns a trainable ``learned_plane`` tensor of shape
    ``(1, in_channels - 12, 256, 768)`` that is appended to inputs arriving
    with fewer than ``in_channels`` channels.
    """

    def __init__(self, in_channels):
        """Build the UNet and, if needed, the learned padding plane.

        Args:
            in_channels: Number of input channels the wrapped UNet expects.
        """
        super().__init__()
        self.in_channels = in_channels
        self.unet = UNet2DModel(
            sample_size=[256, 256 * 3],
            in_channels=in_channels,
            out_channels=32,
            layers_per_block=2,
            block_out_channels=(64, 128, 128, 128 * 2, 128 * 2, 128 * 4, 128 * 4),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )

        # NOTE(review): presumably requires the optional xformers package to
        # be installed -- confirm before running in a new environment.
        self.unet.enable_xformers_memory_efficient_attention()

        if in_channels > 12:
            # Trainable channels that complete a 12-channel input up to
            # ``in_channels``; initialised to zeros.
            self.learned_plane = nn.Parameter(
                torch.zeros([1, in_channels - 12, 256, 256 * 3])
            )
        else:
            # Always define the attribute so ``forward`` can test it.  A
            # ``None`` parameter is not added to ``state_dict``, so existing
            # checkpoints keep loading unchanged.
            self.register_parameter("learned_plane", None)

    def forward(self, x, t=256):
        """Run the wrapped UNet, padding ``x`` with the learned plane if needed.

        Args:
            x: Input tensor; dimension 1 is treated as the channel dimension.
                When it has fewer than ``in_channels`` channels, the learned
                plane is concatenated along dimension 1, so the remaining
                dimensions of ``x`` must match those of the plane.
            t: Timestep forwarded to the wrapped UNet as its second argument.

        Returns:
            The ``sample`` attribute of the wrapped UNet's output.

        Raises:
            ValueError: If ``x`` has fewer than ``in_channels`` channels but
                the module has no learned plane (``in_channels <= 12``).
        """
        if x.shape[1] < self.in_channels:
            if self.learned_plane is None:
                raise ValueError(
                    f"input has {x.shape[1]} channels but the model expects "
                    f"{self.in_channels} and has no learned plane to pad with"
                )
            # Broadcast the single plane over the batch dimension before
            # appending it to the input channels.
            learned_plane = self.learned_plane.expand(
                x.shape[0], -1, -1, -1
            ).to(x.device)
            x = torch.cat([x, learned_plane], dim=1)
        return self.unet(x, t).sample