smatta committed
Commit e4bd436 · verified · 1 Parent(s): ad632ac

Upload vae/dcae.py with huggingface_hub

Files changed (1):
  1. vae/dcae.py +271 -0
vae/dcae.py ADDED
@@ -0,0 +1,271 @@
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import torch
from torch import nn
import torch.nn.functional as F

from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations


def bake_weight_norm(model: nn.Module) -> nn.Module:
    """Remove weight_norm parametrizations, baking normalized weights into regular tensors.

    This is required for torch.compile/CUDA graph compatibility since weight_norm
    performs in-place updates during forward passes.
    """
    for module in model.modules():
        if hasattr(module, "parametrizations") and "weight" in getattr(module, "parametrizations", {}):
            remove_parametrizations(module, "weight", leave_parametrized=True)
    return model
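
# A minimal usage sketch (hedged: `Decoder`, `config`, and `state_dict` stand in
# for a weight-normed module and its checkpoint; `torch.compile` is stock PyTorch):
#
#     model = Decoder(config).eval()
#     model.load_state_dict(state_dict)  # load weights *before* baking
#     model = bake_weight_norm(model)    # parametrizations become plain tensors
#     model = torch.compile(model)       # now safe for compile / CUDA graphs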

# === General Blocks ===

def WeightNormConv2d(*args, **kwargs):
    return weight_norm(nn.Conv2d(*args, **kwargs))

class ResBlock(nn.Module):
    def __init__(self, ch):
        super().__init__()

        hidden = 2 * ch
        # 16 channels per group (matches checkpoint shapes like [128,16,3,3] when ch=64)
        n_grps = max(1, hidden // 16)

        self.conv1 = WeightNormConv2d(ch, hidden, 1, 1, 0)
        self.conv2 = WeightNormConv2d(hidden, hidden, 3, 1, 1, groups=n_grps)
        self.conv3 = WeightNormConv2d(hidden, ch, 1, 1, 0, bias=False)

        self.act1 = nn.LeakyReLU(inplace=False)
        self.act2 = nn.LeakyReLU(inplace=False)

    def forward(self, x):
        h = self.conv1(x)
        h = self.act1(h)
        h = self.conv2(h)
        h = self.act2(h)
        h = self.conv3(h)
        return x + h
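
# Shape note: ResBlock(ch) preserves [B, ch, H, W]. With ch=64, hidden=128 and
# n_grps = 128 // 16 = 8, so conv2's weight is [128, 16, 3, 3], matching the
# checkpoint shape cited in the comment above.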

# === Encoder ===

class LandscapeToSquare(nn.Module):
    # Strict assumption of 360p
    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)

    def forward(self, x):
        x = F.interpolate(x, (512, 512), mode='bicubic')
        x = self.proj(x)
        return x

class Downsample(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.proj = WeightNormConv2d(ch_in, ch_out, 1, 1, 0, bias=False)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=0.5, mode='bicubic')
        x = self.proj(x)
        return x

class DownBlock(nn.Module):
    def __init__(self, ch_in, ch_out, num_res=1):
        super().__init__()

        self.down = Downsample(ch_in, ch_out)
        blocks = []
        for _ in range(num_res):
            blocks.append(ResBlock(ch_in))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.down(x)
        return x

class SpaceToChannel(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.proj = WeightNormConv2d(ch_in, ch_out // 4, 3, 1, 1)

    def forward(self, x):
        x = self.proj(x)
        x = F.pixel_unshuffle(x, 2).contiguous()
        return x
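
# Shape note: the conv maps [B, ch_in, H, W] to [B, ch_out // 4, H, W], then
# F.pixel_unshuffle(x, 2) folds each 2x2 spatial patch into channels, yielding
# [B, ch_out, H/2, W/2]: the skip-path counterpart to DownBlock's halving.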

class ChannelAverage(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
        self.grps = ch_in // ch_out
        self.scale = (self.grps) ** 0.5

    def forward(self, x):
        res = x
        x = self.proj(x.contiguous())  # [b, ch_out, h, w]

        # Residual goes through channel avg
        res = res.view(res.shape[0], self.grps, res.shape[1] // self.grps, res.shape[2], res.shape[3]).contiguous()
        res = res.mean(dim=1) * self.scale  # [b, ch_out, h, w]

        return res + x
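
# Why scale = sqrt(grps): averaging g roughly independent, same-variance channel
# groups divides the variance by g, so multiplying the mean by sqrt(g) keeps the
# shortcut on the same scale as the conv branch it is summed with.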

# === Decoder ===

class SquareToLandscape(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)

    def forward(self, x):
        x = self.proj(x)  # TODO This ordering is wrong for both
        x = F.interpolate(x, (360, 640), mode='bicubic')
        return x

class Upsample(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.proj = nn.Identity() if ch_in == ch_out else WeightNormConv2d(
            ch_in, ch_out, 1, 1, 0, bias=False
        )

    def forward(self, x):
        x = self.proj(x)
        x = F.interpolate(x, scale_factor=2.0, mode='bicubic')
        return x

class UpBlock(nn.Module):
    def __init__(self, ch_in, ch_out, num_res=1):
        super().__init__()

        self.up = Upsample(ch_in, ch_out)
        blocks = []
        for _ in range(num_res):
            blocks.append(ResBlock(ch_out))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        x = self.up(x)
        for block in self.blocks:
            x = block(x)
        return x

class ChannelToSpace(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.proj = WeightNormConv2d(ch_in, ch_out * 4, 3, 1, 1)

    def forward(self, x):
        x = self.proj(x)
        x = F.pixel_shuffle(x, 2).contiguous()
        return x

class ChannelDuplication(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
        self.reps = ch_out // ch_in
        self.scale = (self.reps) ** -0.5

    def forward(self, x):
        res = x
        x = self.proj(x.contiguous())

        b, c, h, w = res.shape
        res = res.unsqueeze(2)  # [b, c, 1, h, w]
        res = res.expand(b, c, self.reps, h, w)  # [b, c, reps, h, w]
        res = res.reshape(b, c * self.reps, h, w).contiguous()
        res = res * self.scale

        return res + x
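
# Decoder-side counterpart of ChannelAverage: every latent channel is repeated
# `reps` times and scaled by reps**-0.5 (the mirror of the encoder's sqrt(grps)
# factor), presumably to keep the duplicated shortcut variance-matched with the
# conv branch it is added to.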
206
+
207
+ # === Main AE ===
208
+
209
+ class Encoder(nn.Module):
210
+ def __init__(self, config):
211
+ super().__init__()
212
+
213
+ self.conv_in = LandscapeToSquare(config.channels, config.ch_0)
214
+
215
+ blocks = []
216
+ residuals = []
217
+
218
+ ch = config.ch_0
219
+ for block_count in config.encoder_blocks_per_stage:
220
+ next_ch = min(ch*2, config.ch_max)
221
+
222
+ blocks.append(DownBlock(ch, next_ch, block_count))
223
+ residuals.append(SpaceToChannel(ch, next_ch))
224
+
225
+ ch = next_ch
226
+
227
+ self.blocks = nn.ModuleList(blocks)
228
+ self.residuals = nn.ModuleList(residuals)
229
+ self.conv_out = ChannelAverage(ch, config.latent_channels)
230
+
231
+ self.skip_logvar = bool(getattr(config, "skip_logvar", False))
232
+ if not self.skip_logvar:
233
+ # Checkpoint expects a 1-channel logvar head: [1, ch, 3, 3]
234
+ self.conv_out_logvar = WeightNormConv2d(ch, 1, 3, 1, 1)
235
+
236
+ def forward(self, x):
237
+ x = self.conv_in(x)
238
+ for block, residual in zip(self.blocks, self.residuals):
239
+ x = block(x) + residual(x)
240
+ return self.conv_out(x)
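
# Shape sketch (stage count and widths come from config): a [B, channels, 360, 640]
# frame becomes [B, ch_0, 512, 512] after conv_in; each of the n stages halves H
# and W while widening channels (capped at ch_max), so the latent ends up at
# [B, latent_channels, 512 / 2**n, 512 / 2**n].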
241
+
242
+ class Decoder(nn.Module):
243
+ def __init__(self, config):
244
+ super().__init__()
245
+
246
+ self.conv_in = ChannelDuplication(config.latent_channels, config.ch_max)
247
+
248
+ blocks = []
249
+ residuals = []
250
+
251
+ ch = config.ch_0
252
+ for block_count in reversed(config.decoder_blocks_per_stage):
253
+ next_ch = min(ch*2, config.ch_max)
254
+
255
+ blocks.append(UpBlock(next_ch, ch, block_count))
256
+ residuals.append(ChannelToSpace(next_ch, ch))
257
+
258
+ ch = next_ch
259
+
260
+ self.blocks = nn.ModuleList(reversed(blocks))
261
+ self.residuals = nn.ModuleList(reversed(residuals))
262
+
263
+ self.act_out = nn.SiLU()
264
+ self.conv_out = SquareToLandscape(config.ch_0, config.channels)
265
+
266
+ def forward(self, x):
267
+ x = self.conv_in(x)
268
+ for block, residual in zip(self.blocks, self.residuals):
269
+ x = block(x) + residual(x)
270
+ x = self.act_out(x)
271
+ return self.conv_out(x)
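
# End-to-end sketch (hedged: the config fields below follow how Encoder/Decoder
# read them, but the concrete values are illustrative, not taken from this repo):
#
#     from types import SimpleNamespace
#
#     config = SimpleNamespace(
#         channels=3, ch_0=64, ch_max=512, latent_channels=16,
#         encoder_blocks_per_stage=[1, 1, 1, 1],
#         decoder_blocks_per_stage=[1, 1, 1, 1],
#         skip_logvar=True,
#     )
#     enc, dec = Encoder(config), Decoder(config)
#     z = enc(torch.randn(1, 3, 360, 640))  # [1, 16, 32, 32] with these 4 stages
#     x_hat = dec(z)                        # [1, 3, 360, 640]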