File size: 4,446 Bytes
002ca81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch.nn.functional as F
from torch import nn
from math import log2
from einops import rearrange

from .Blur import Blur
from .Noise import Noise
from .FCANet import FCANet
from .PreNorm import PreNorm
from .Conv2dSame import Conv2dSame
from .GlobalContext import GlobalContext
from .LinearAttention import LinearAttention
from .PixelShuffleUpsample import PixelShuffleUpsample
from .helper_funcs import exists, is_power_of_two, default


class Generator(nn.Module):
    """Progressive upsampling generator (lightweight-GAN style).

    Maps a latent vector of size ``latent_dim`` to an image with side length
    ``image_size`` (which must be a power of two) by repeatedly doubling the
    spatial resolution with pixel-shuffle upsampling blocks, optionally
    interleaved with linear attention and skip-layer-excitation (SLE) gating
    between low- and high-resolution feature maps.
    """

    def __init__(
        self,
        *,
        image_size,
        latent_dim=256,
        fmap_max=512,
        fmap_inverse_coef=12,
        transparent=False,
        greyscale=False,
        attn_res_layers=None,
        freq_chan_attn=False,
        syncbatchnorm=False,
        antialias=False,
    ):
        """
        Args:
            image_size: output image side length; must be a power of 2.
            latent_dim: size of the input latent vector.
            fmap_max: cap on per-layer feature channels (``None`` -> ``latent_dim``).
            fmap_inverse_coef: channels at resolution ``2**n`` are ``2**(coef - n)``.
            transparent: emit 4 output channels (RGBA).
            greyscale: emit 1 output channel; ``transparent`` takes precedence.
            attn_res_layers: image widths at which to insert linear attention.
            freq_chan_attn: use ``FCANet`` instead of ``GlobalContext`` for SLE.
            syncbatchnorm: use ``nn.SyncBatchNorm`` instead of ``nn.BatchNorm2d``.
            antialias: apply a ``Blur`` layer after each upsample.
        """
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), "image size must be a power of 2"

        # Avoid a mutable default argument; None means "no attention layers".
        if attn_res_layers is None:
            attn_res_layers = []

        # Choose normalization and (optional) anti-aliasing blur layers.
        # BUGFIX: the previous code assigned to a local named `Blur`, which
        # shadowed the imported class for the whole function scope and raised
        # UnboundLocalError whenever antialias=True. Binding the choice to a
        # fresh name keeps the imported `Blur` class reachable.
        norm_class = nn.SyncBatchNorm if syncbatchnorm else nn.BatchNorm2d
        blur_class = Blur if antialias else nn.Identity

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        self.latent_dim = latent_dim

        # Allow callers to pass fmap_max=None meaning "cap at latent_dim".
        fmap_max = default(fmap_max, latent_dim)

        # 1x1 latent -> 4x4 feature map; GLU(dim=1) halves the doubled channels.
        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim=1),
        )

        # Channel schedule per resolution exponent n (image width 2**n):
        # 2**(fmap_inverse_coef - n), capped at fmap_max, and forced to 3
        # channels from resolution exponent 8 (256px) upward.
        num_layers = int(resolution) - 2
        features = [(n, 2 ** (fmap_inverse_coef - n)) for n in range(2, num_layers + 2)]
        features = [(n, min(chan, fmap_max)) for n, chan in features]
        features = [3 if n >= 8 else chan for n, chan in features]
        features = [latent_dim, *features]

        in_out_features = list(zip(features[:-1], features[1:]))

        self.res_layers = range(2, num_layers + 2)
        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        # SLE skip map: low-resolution layer -> high-resolution layer it gates.
        # Pairs whose target exceeds the model's resolution are dropped.
        self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
        self.sle_map = dict(
            (low, high)
            for low, high in self.sle_map
            if low <= resolution and high <= resolution
        )

        self.num_layers_spatial_res = 1

        for res, (chan_in, chan_out) in zip(self.res_layers, in_out_features):
            image_width = 2**res

            # Optional pre-normalized linear attention at this width.
            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                # The gate's channels must match the *input* of the gated
                # layer, i.e. the output of the layer one resolution below it.
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

                if freq_chan_attn:
                    sle = FCANet(
                        chan_in=chan_out, chan_out=sle_chan_out, width=2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(chan_in=chan_out, chan_out=sle_chan_out)

            # Each entry is [upsample pipeline, SLE gate or None, attn or None].
            layer = nn.ModuleList(
                [
                    nn.Sequential(
                        PixelShuffleUpsample(chan_in),
                        blur_class(),
                        Conv2dSame(chan_in, chan_out * 2, 4),
                        Noise(),
                        norm_class(chan_out * 2),
                        nn.GLU(dim=1),
                    ),
                    sle,
                    attn,
                ]
            )
            self.layers.append(layer)

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)

    def forward(self, x):
        """Generate images from latent vectors ``x`` of shape (batch, latent_dim)."""
        # (b, c) -> (b, c, 1, 1) so the transposed conv can expand it spatially.
        x = rearrange(x, "b c -> b c () ()")
        x = self.initial_conv(x)
        x = F.normalize(x, dim=1)

        residuals = dict()

        for res, (up, sle, attn) in zip(self.res_layers, self.layers):
            if exists(attn):
                # Residual linear attention at this resolution.
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                # Compute the SLE gate now; it is applied multiplicatively
                # once the loop reaches the mapped higher resolution.
                out_res = self.sle_map[res]
                residuals[out_res] = sle(x)

            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)