bubbliiiing committed on
Commit
43ed08d
1 Parent(s): 08038f7

add requirements

easyanimate/vae/ldm/models/__init__.py ADDED
File without changes
easyanimate/vae/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,337 @@
1
+ import time
2
+ from contextlib import contextmanager
3
+
4
+ import pytorch_lightning as pl
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from ..modules.diffusionmodules.model import Decoder, Encoder
9
+ from ..modules.distributions.distributions import DiagonalGaussianDistribution
10
+ from ..util import instantiate_from_config
11
+ from .enc_dec_pytorch import Decoder as Mag_Decoder
12
+ from .enc_dec_pytorch import Encoder as Mag_Encoder
13
+
14
+
15
+ class AutoencoderKLMagvit(pl.LightningModule):
16
+ def __init__(self,
17
+ ddconfig,
18
+ lossconfig,
19
+ embed_dim,
20
+ ckpt_path=None,
21
+ ignore_keys=[],
22
+ image_key="image",
23
+ colorize_nlabels=None,
24
+ monitor=None,
25
+ ):
26
+ super().__init__()
27
+ self.image_key = image_key
28
+ self.encoder = Mag_Encoder()
29
+ self.decoder = Mag_Decoder()
30
+ self.loss = instantiate_from_config(lossconfig)
31
+ self.quant_conv = torch.nn.Conv3d(16, 16, 1)
32
+ self.post_quant_conv = torch.nn.Conv3d(8, 8, 1)
33
+ self.embed_dim = embed_dim
34
+ if colorize_nlabels is not None:
35
+ assert type(colorize_nlabels)==int
36
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
37
+ if monitor is not None:
38
+ self.monitor = monitor
39
+ if ckpt_path is not None:
40
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
41
+
42
+ def init_from_ckpt(self, path, ignore_keys=list()):
43
+ sd = torch.load(path, map_location="cpu")["state_dict"]
44
+ keys = list(sd.keys())
45
+ for k in keys:
46
+ for ik in ignore_keys:
47
+ if k.startswith(ik):
48
+ print("Deleting key {} from state_dict.".format(k))
49
+ del sd[k]
50
+ self.load_state_dict(sd, strict=False)
51
+ print(f"Restored from {path}")
52
+
53
+ def encode(self, x):
54
+ h = self.encoder(x)
55
+ moments = self.quant_conv(h)
56
+ posterior = DiagonalGaussianDistribution(moments)
57
+ return posterior
58
+
59
+ def decode(self, z):
60
+ z = self.post_quant_conv(z)
61
+ dec = self.decoder(z)
62
+ return dec
63
+
64
+ def forward(self, input, sample_posterior=True):
65
+ if input.ndim==4:
66
+ input = input.unsqueeze(2)
67
+ posterior = self.encode(input)
68
+ if sample_posterior:
69
+ z = posterior.sample()
70
+ else:
71
+ z = posterior.mode()
72
+ dec = self.decode(z)
73
+ return dec, posterior
74
+
75
+ def get_input(self, batch, k):
76
+ x = batch[k]
77
+ if x.ndim==5:
78
+ x = x.permute(0, 4, 1, 2, 3).to(memory_format=torch.contiguous_format).float()
79
+ return x
80
+ if len(x.shape) == 3:
81
+ x = x[..., None]
82
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
83
+ return x
84
+
85
+ def training_step(self, batch, batch_idx, optimizer_idx):
86
+ # tic = time.time()
87
+ inputs = self.get_input(batch, self.image_key)
88
+ # print(f"get_input time {time.time() - tic}")
89
+ # tic = time.time()
90
+ reconstructions, posterior = self(inputs)
91
+ # print(f"model forward time {time.time() - tic}")
92
+
93
+ if optimizer_idx == 0:
94
+ # train encoder+decoder+logvar
95
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
96
+ last_layer=self.get_last_layer(), split="train")
97
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
98
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
99
+ # print(f"cal loss time {time.time() - tic}")
100
+ return aeloss
101
+
102
+ if optimizer_idx == 1:
103
+ # train the discriminator
104
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
105
+ last_layer=self.get_last_layer(), split="train")
106
+
107
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
108
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
109
+ # print(f"cal loss time {time.time() - tic}")
110
+ return discloss
111
+
112
+ def validation_step(self, batch, batch_idx):
113
+ with torch.no_grad():
114
+ inputs = self.get_input(batch, self.image_key)
115
+ reconstructions, posterior = self(inputs)
116
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
117
+ last_layer=self.get_last_layer(), split="val")
118
+
119
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
120
+ last_layer=self.get_last_layer(), split="val")
121
+
122
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
123
+ self.log_dict(log_dict_ae)
124
+ self.log_dict(log_dict_disc)
125
+ return self.log_dict
126
+
127
+ def configure_optimizers(self):
128
+ lr = self.learning_rate
129
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
130
+ list(self.decoder.parameters())+
131
+ list(self.quant_conv.parameters())+
132
+ list(self.post_quant_conv.parameters()),
133
+ lr=lr, betas=(0.5, 0.9))
134
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
135
+ lr=lr, betas=(0.5, 0.9))
136
+ return [opt_ae, opt_disc], []
137
+
138
+ def get_last_layer(self):
139
+ return self.decoder.conv_out.weight
140
+
141
+ @torch.no_grad()
142
+ def log_images(self, batch, only_inputs=False, **kwargs):
143
+ log = dict()
144
+ x = self.get_input(batch, self.image_key)
145
+ x = x.to(self.device)
146
+ if not only_inputs:
147
+ xrec, posterior = self(x)
148
+ if x.shape[1] > 3:
149
+ # colorize with random projection
150
+ assert xrec.shape[1] > 3
151
+ x = self.to_rgb(x)
152
+ xrec = self.to_rgb(xrec)
153
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
154
+ log["reconstructions"] = xrec
155
+ log["inputs"] = x
156
+ return log
157
+
158
+ def to_rgb(self, x):
159
+ assert self.image_key == "segmentation"
160
+ if not hasattr(self, "colorize"):
161
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
162
+ x = F.conv2d(x, weight=self.colorize)
163
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
164
+ return x
165
+
166
+ class AutoencoderKL(pl.LightningModule):
167
+ def __init__(self,
168
+ ddconfig,
169
+ lossconfig,
170
+ embed_dim,
171
+ ckpt_path=None,
172
+ ignore_keys=[],
173
+ image_key="image",
174
+ colorize_nlabels=None,
175
+ monitor=None,
176
+ ):
177
+ super().__init__()
178
+ self.image_key = image_key
179
+ self.encoder = Encoder(**ddconfig)
180
+ self.decoder = Decoder(**ddconfig)
181
+ self.loss = instantiate_from_config(lossconfig)
182
+ assert ddconfig["double_z"]
183
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
184
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
185
+ self.embed_dim = embed_dim
186
+ if colorize_nlabels is not None:
187
+ assert type(colorize_nlabels)==int
188
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
189
+ if monitor is not None:
190
+ self.monitor = monitor
191
+ if ckpt_path is not None:
192
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
193
+
194
+ def init_from_ckpt(self, path, ignore_keys=list()):
195
+ sd = torch.load(path, map_location="cpu")["state_dict"]
196
+ keys = list(sd.keys())
197
+ for k in keys:
198
+ for ik in ignore_keys:
199
+ if k.startswith(ik):
200
+ print("Deleting key {} from state_dict.".format(k))
201
+ del sd[k]
202
+ self.load_state_dict(sd, strict=False)
203
+ print(f"Restored from {path}")
204
+
205
+ def encode(self, x):
206
+ h = self.encoder(x)
207
+ moments = self.quant_conv(h)
208
+ posterior = DiagonalGaussianDistribution(moments)
209
+ return posterior
210
+
211
+ def decode(self, z):
212
+ z = self.post_quant_conv(z)
213
+ dec = self.decoder(z)
214
+ return dec
215
+
216
+ def forward(self, input, sample_posterior=True):
217
+ posterior = self.encode(input)
218
+ if sample_posterior:
219
+ z = posterior.sample()
220
+ else:
221
+ z = posterior.mode()
222
+ dec = self.decode(z)
223
+ return dec, posterior
224
+
225
+ def get_input(self, batch, k):
226
+ x = batch[k]
227
+ if len(x.shape) == 3:
228
+ x = x[..., None]
229
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
230
+ return x
231
+
232
+ def training_step(self, batch, batch_idx, optimizer_idx):
233
+ # tic = time.time()
234
+ inputs = self.get_input(batch, self.image_key)
235
+ # print(f"get_input time {time.time() - tic}")
236
+ # tic = time.time()
237
+ reconstructions, posterior = self(inputs)
238
+ # print(f"model forward time {time.time() - tic}")
239
+ # tic = time.time()
240
+
241
+ if optimizer_idx == 0:
242
+ # train encoder+decoder+logvar
243
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
244
+ last_layer=self.get_last_layer(), split="train")
245
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
246
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
247
+ # print(f"cal loss time {time.time() - tic}")
248
+ return aeloss
249
+
250
+ if optimizer_idx == 1:
251
+ # train the discriminator
252
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
253
+ last_layer=self.get_last_layer(), split="train")
254
+
255
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
256
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
257
+ # print(f"cal loss time {time.time() - tic}")
258
+ return discloss
259
+
260
+ def validation_step(self, batch, batch_idx):
261
+ tic = time.time()
262
+ inputs = self.get_input(batch, self.image_key)
263
+ print(f"get_input time {time.time() - tic}")
264
+ tic = time.time()
265
+ reconstructions, posterior = self(inputs)
266
+ print(f"val forward time {time.time() - tic}")
267
+ tic = time.time()
268
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
269
+ last_layer=self.get_last_layer(), split="val")
270
+
271
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
272
+ last_layer=self.get_last_layer(), split="val")
273
+
274
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
275
+ self.log_dict(log_dict_ae)
276
+ self.log_dict(log_dict_disc)
277
+ print(f"val end time {time.time() - tic}")
278
+ return self.log_dict
279
+
280
+ def configure_optimizers(self):
281
+ lr = self.learning_rate
282
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
283
+ list(self.decoder.parameters())+
284
+ list(self.quant_conv.parameters())+
285
+ list(self.post_quant_conv.parameters()),
286
+ lr=lr, betas=(0.5, 0.9))
287
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
288
+ lr=lr, betas=(0.5, 0.9))
289
+ return [opt_ae, opt_disc], []
290
+
291
+ def get_last_layer(self):
292
+ return self.decoder.conv_out.weight
293
+
294
+ @torch.no_grad()
295
+ def log_images(self, batch, only_inputs=False, **kwargs):
296
+ log = dict()
297
+ x = self.get_input(batch, self.image_key)
298
+ x = x.to(self.device)
299
+ if not only_inputs:
300
+ xrec, posterior = self(x)
301
+ if x.shape[1] > 3:
302
+ # colorize with random projection
303
+ assert xrec.shape[1] > 3
304
+ x = self.to_rgb(x)
305
+ xrec = self.to_rgb(xrec)
306
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
307
+ log["reconstructions"] = xrec
308
+ log["inputs"] = x
309
+ return log
310
+
311
+ def to_rgb(self, x):
312
+ assert self.image_key == "segmentation"
313
+ if not hasattr(self, "colorize"):
314
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
315
+ x = F.conv2d(x, weight=self.colorize)
316
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
317
+ return x
318
+
319
+
320
+ class IdentityFirstStage(torch.nn.Module):
321
+ def __init__(self, *args, vq_interface=False, **kwargs):
322
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
323
+ super().__init__()
324
+
325
+ def encode(self, x, *args, **kwargs):
326
+ return x
327
+
328
+ def decode(self, x, *args, **kwargs):
329
+ return x
330
+
331
+ def quantize(self, x, *args, **kwargs):
332
+ if self.vq_interface:
333
+ return x, None, [None, None, None]
334
+ return x
335
+
336
+ def forward(self, x, *args, **kwargs):
337
+ return x
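
For orientation (editorial sketch, not part of the commit): the `forward()` methods above hand the `quant_conv` output to `DiagonalGaussianDistribution`, which treats half of the channels as the mean and half as the log-variance of the latent posterior. A minimal, self-contained sketch of that sample-vs-mode logic, assuming the standard ldm convention of chunking the moments along the channel axis:

import torch

def sample_posterior(moments, sample_posterior=True):
    # moments: quant_conv output, (B, 2*z_channels, T, H, W); shapes are illustrative only
    mean, logvar = torch.chunk(moments, 2, dim=1)
    logvar = torch.clamp(logvar, -30.0, 20.0)
    if not sample_posterior:
        return mean                                    # posterior.mode()
    std = torch.exp(0.5 * logvar)
    return mean + std * torch.randn_like(mean)         # reparameterized posterior.sample()

moments = torch.randn(1, 16, 3, 8, 8)                  # 16 = 2 * 8 latent channels (AutoencoderKLMagvit)
z = sample_posterior(moments)                          # (1, 8, 3, 8, 8), fed to post_quant_conv / decoder
print(z.shape)
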
easyanimate/vae/ldm/models/enc_dec_pytorch.py ADDED
@@ -0,0 +1,234 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+
6
+ def cast_tuple(t, length = 1):
7
+ return t if isinstance(t, tuple) else ((t,) * length)
8
+
9
+ def divisible_by(num, den):
10
+ return (num % den) == 0
11
+
12
+ def is_odd(n):
13
+ return not divisible_by(n, 2)
14
+
15
+ class CausalConv3d(nn.Module):
16
+ def __init__(
17
+ self,
18
+ chan_in,
19
+ chan_out,
20
+ kernel_size,
21
+ pad_mode = 'constant',
22
+ **kwargs
23
+ ):
24
+ super().__init__()
25
+ kernel_size = cast_tuple(kernel_size, 3)
26
+
27
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
28
+
29
+ assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
30
+
31
+ dilation = kwargs.pop('dilation', 1)
32
+ stride = kwargs.pop('stride', 1)
33
+
34
+ self.pad_mode = pad_mode
35
+ time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
36
+ height_pad = height_kernel_size // 2
37
+ width_pad = width_kernel_size // 2
38
+
39
+ self.time_pad = time_pad
40
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
41
+
42
+ stride = (stride, 1, 1)
43
+ dilation = (dilation, 1, 1)
44
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)
45
+
46
+ def forward(self, x):
47
+ x = F.pad(x, self.time_causal_padding, mode = 'replicate')
48
+ return self.conv(x)
49
+
50
+ class Swish(nn.Module):
51
+ def __init__(self) -> None:
52
+ super().__init__()
53
+
54
+ def forward(self, x):
55
+ return x * torch.sigmoid(x)
56
+
57
+ class ResBlockX(nn.Module):
58
+ def __init__(self, inchannel) -> None:
59
+ super().__init__()
60
+ self.conv = nn.Sequential(
61
+ nn.GroupNorm(32, inchannel),
62
+ Swish(),
63
+ CausalConv3d(inchannel, inchannel, 3),
64
+ nn.GroupNorm(32, inchannel),
65
+ Swish(),
66
+ CausalConv3d(inchannel, inchannel, 3)
67
+ )
68
+
69
+ def forward(self, x):
70
+ return x + self.conv(x)
71
+
72
+ class ResBlockXY(nn.Module):
73
+ def __init__(self, inchannel, outchannel) -> None:
74
+ super().__init__()
75
+ self.conv = nn.Sequential(
76
+ nn.GroupNorm(32, inchannel),
77
+ Swish(),
78
+ CausalConv3d(inchannel, outchannel, 3),
79
+ nn.GroupNorm(32, outchannel),
80
+ Swish(),
81
+ CausalConv3d(outchannel, outchannel, 3)
82
+ )
83
+ self.conv_1 = nn.Conv3d(inchannel, outchannel, 1)
84
+
85
+ def forward(self, x):
86
+ return self.conv_1(x) + self.conv(x)
87
+
88
+ class PoolDown222(nn.Module):
89
+ def __init__(self) -> None:
90
+ super().__init__()
91
+ self.pool = nn.AvgPool3d(2, 2)
92
+
93
+ def forward(self, x):
94
+ x = F.pad(x, (0, 0, 0, 0, 1, 0), 'replicate')
95
+ return self.pool(x)
96
+
97
+ class PoolDown122(nn.Module):
98
+ def __init__(self) -> None:
99
+ super().__init__()
100
+ self.pool = nn.AvgPool3d((1, 2, 2), (1, 2, 2))
101
+
102
+ def forward(self, x):
103
+ return self.pool(x)
104
+
105
+ class Unpool222(nn.Module):
106
+ def __init__(self) -> None:
107
+ super().__init__()
108
+ self.up = nn.Upsample(scale_factor=2, mode='nearest')
109
+
110
+ def forward(self, x):
111
+ x = self.up(x)
112
+ return x[:, :, 1:]
113
+
114
+ class Unpool122(nn.Module):
115
+ def __init__(self) -> None:
116
+ super().__init__()
117
+ self.up = nn.Upsample(scale_factor=(1, 2, 2), mode='nearest')
118
+
119
+ def forward(self, x):
120
+ x = self.up(x)
121
+ return x
122
+
123
+ class ResBlockDown(nn.Module):
124
+ def __init__(self, inchannel, outchannel) -> None:
125
+ super().__init__()
126
+ self.block = nn.Sequential(
127
+ CausalConv3d(inchannel, outchannel, 3),
128
+ nn.LeakyReLU(inplace=True),
129
+ PoolDown222(),
130
+ CausalConv3d(outchannel, outchannel, 3),
131
+ nn.LeakyReLU(inplace=True)
132
+ )
133
+ self.res = nn.Sequential(
134
+ PoolDown222(),
135
+ nn.Conv3d(inchannel, outchannel, 1)
136
+ )
137
+
138
+ def forward(self, x):
139
+ return self.res(x) + self.block(x)
140
+
141
+
142
+ class Discriminator(nn.Module):
143
+ def __init__(self) -> None:
144
+ super().__init__()
145
+ self.block = nn.Sequential(
146
+ CausalConv3d(3, 64, 3),
147
+ nn.LeakyReLU(inplace=True),
148
+ ResBlockDown(64, 128),
149
+ ResBlockDown(128, 256),
150
+ ResBlockDown(256, 256),
151
+ ResBlockDown(256, 256),
152
+ ResBlockDown(256, 256),
153
+ CausalConv3d(256, 256, 3),
154
+ nn.LeakyReLU(inplace=True),
155
+ nn.AdaptiveAvgPool3d(1),
156
+ nn.Flatten(),
157
+ nn.Linear(256, 256),
158
+ nn.LeakyReLU(inplace=True),
159
+ nn.Linear(256, 1)
160
+ )
161
+
162
+ def forward(self, x):
163
+ if x.ndim==4:
164
+ x = x.unsqueeze(2)
165
+ return self.block(x)
166
+
167
+
168
+
169
+ class Encoder(nn.Module):
170
+ def __init__(self) -> None:
171
+ super().__init__()
172
+ self.encoder = nn.Sequential(
173
+ CausalConv3d(3, 64, 3),
174
+ ResBlockX(64),
175
+ ResBlockX(64),
176
+ PoolDown222(),
177
+ ResBlockXY(64, 128),
178
+ ResBlockX(128),
179
+ PoolDown222(),
180
+ ResBlockX(128),
181
+ ResBlockX(128),
182
+ PoolDown122(),
183
+ ResBlockXY(128, 256),
184
+ ResBlockX(256),
185
+ ResBlockX(256),
186
+ ResBlockX(256),
187
+ nn.GroupNorm(32, 256),
188
+ Swish(),
189
+ nn.Conv3d(256, 16, 1)
190
+ )
191
+
192
+ def forward(self, x):
193
+ return self.encoder(x)
194
+
195
+ class Decoder(nn.Module):
196
+ def __init__(self) -> None:
197
+ super().__init__()
198
+ self.decoder = nn.Sequential(
199
+ CausalConv3d(8, 256, 3),
200
+ ResBlockX(256),
201
+ ResBlockX(256),
202
+ ResBlockX(256),
203
+ ResBlockX(256),
204
+ Unpool122(),
205
+ CausalConv3d(256, 256, 3),
206
+ ResBlockXY(256, 128),
207
+ ResBlockX(128),
208
+ Unpool222(),
209
+ CausalConv3d(128, 128, 3),
210
+ ResBlockX(128),
211
+ ResBlockX(128),
212
+ Unpool222(),
213
+ CausalConv3d(128, 128, 3),
214
+ ResBlockXY(128, 64),
215
+ ResBlockX(64),
216
+ nn.GroupNorm(32, 64),
217
+ Swish(),
218
+ CausalConv3d(64, 64, 3)
219
+ )
220
+ self.conv_out = nn.Conv3d(64, 3, 1)
221
+
222
+ def forward(self, x):
223
+ return self.conv_out(self.decoder(x))
224
+
225
+
226
+ if __name__=='__main__':
227
+ encoder = Encoder()
228
+ decoder = Decoder()
229
+ dis = Discriminator()
230
+ x = torch.randn((1, 3, 1, 64, 64))
231
+ embedding = encoder(x)
232
+ y = decoder(embedding[:, :8]) # encoder emits 16 channels (mean/logvar); the decoder expects 8
233
+ tmp = torch.randn((1, 4, 1, 64, 64))
234
+ print(embedding.shape, y.shape)
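
As a side note (editorial sketch, not part of the commit): `CausalConv3d` above pads the temporal axis only on the left, so every output frame depends only on the current and earlier input frames, while height and width keep the usual symmetric padding. The arithmetic for a (3, 3, 3) kernel with stride and dilation 1:

import torch
import torch.nn.functional as F

# F.pad orders its tuple as (w_left, w_right, h_top, h_bottom, t_front, t_back).
time_pad, height_pad, width_pad = 2, 1, 1                  # dilation*(kt-1)+(1-stride), kh//2, kw//2
pad = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

x = torch.randn(1, 3, 5, 16, 16)                           # (B, C, T, H, W)
x_padded = F.pad(x, pad, mode='replicate')                 # -> (1, 3, 7, 18, 18)
y = torch.nn.Conv3d(3, 8, kernel_size=3)(x_padded)         # -> (1, 8, 5, 16, 16): T, H, W preserved
print(y.shape)
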
easyanimate/vae/ldm/models/omnigen_casual3dcnn.py ADDED
@@ -0,0 +1,321 @@
1
+ import itertools
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from ..util import instantiate_from_config
12
+ from .omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
13
+ from .omnigen_enc_dec import Encoder as omnigen_Mag_Encoder
14
+
15
+
16
+ class DiagonalGaussianDistribution:
17
+ def __init__(
18
+ self,
19
+ mean: torch.Tensor,
20
+ logvar: torch.Tensor,
21
+ deterministic: bool = False,
22
+ ):
23
+ self.mean = mean
24
+ self.logvar = torch.clamp(logvar, -30.0, 20.0)
25
+ self.deterministic = deterministic
26
+
27
+ if deterministic:
28
+ self.var = self.std = torch.zeros_like(self.mean)
29
+ else:
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+
33
+ def sample(self, generator = None) -> torch.FloatTensor:
34
+ x = torch.randn(
35
+ self.mean.shape,
36
+ generator=generator,
37
+ device=self.mean.device,
38
+ dtype=self.mean.dtype,
39
+ )
40
+ return self.mean + self.std * x
41
+
42
+ def mode(self):
43
+ return self.mean
44
+
45
+ def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
46
+ dims = list(range(1, self.mean.ndim))
47
+
48
+ if self.deterministic:
49
+ return torch.Tensor([0.0])
50
+ else:
51
+ if other is None:
52
+ return 0.5 * torch.sum(
53
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
54
+ dim=dims,
55
+ )
56
+ else:
57
+ return 0.5 * torch.sum(
58
+ torch.pow(self.mean - other.mean, 2) / other.var
59
+ + self.var / other.var
60
+ - 1.0
61
+ - self.logvar
62
+ + other.logvar,
63
+ dim=dims,
64
+ )
65
+
66
+ def nll(self, sample: torch.Tensor) -> torch.Tensor:
67
+ dims = list(range(1, self.mean.ndim))
68
+
69
+ if self.deterministic:
70
+ return torch.Tensor([0.0])
71
+
72
+ logtwopi = np.log(2.0 * np.pi)
73
+ return 0.5 * torch.sum(
74
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
75
+ dim=dims,
76
+ )
77
+
78
+ @dataclass
79
+ class EncoderOutput:
80
+ latent_dist: DiagonalGaussianDistribution
81
+
82
+ @dataclass
83
+ class DecoderOutput:
84
+ sample: torch.Tensor
85
+
86
+ def str_eval(item):
87
+ if type(item) == str:
88
+ return eval(item)
89
+ else:
90
+ return item
91
+
92
+ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
93
+ def __init__(
94
+ self,
95
+ in_channels: int = 3,
96
+ out_channels: int = 3,
97
+ ch = 128,
98
+ ch_mult = [ 1,2,4,4 ],
99
+ use_gc_blocks = None,
100
+ down_block_types: tuple = None,
101
+ up_block_types: tuple = None,
102
+ mid_block_type: str = "MidBlock3D",
103
+ mid_block_use_attention: bool = True,
104
+ mid_block_attention_type: str = "3d",
105
+ mid_block_num_attention_heads: int = 1,
106
+ layers_per_block: int = 2,
107
+ act_fn: str = "silu",
108
+ num_attention_heads: int = 1,
109
+ latent_channels: int = 4,
110
+ norm_num_groups: int = 32,
111
+ image_key="image",
112
+ monitor=None,
113
+ ckpt_path=None,
114
+ lossconfig=None,
115
+ slice_compression_vae=False,
116
+ mini_batch_encoder=9,
117
+ mini_batch_decoder=3,
118
+ train_decoder_only=False,
119
+ ):
120
+ super().__init__()
121
+ self.image_key = image_key
122
+ down_block_types = str_eval(down_block_types)
123
+ up_block_types = str_eval(up_block_types)
124
+ self.encoder = omnigen_Mag_Encoder(
125
+ in_channels=in_channels,
126
+ out_channels=latent_channels,
127
+ down_block_types=down_block_types,
128
+ ch = ch,
129
+ ch_mult = ch_mult,
130
+ use_gc_blocks=use_gc_blocks,
131
+ mid_block_type=mid_block_type,
132
+ mid_block_use_attention=mid_block_use_attention,
133
+ mid_block_attention_type=mid_block_attention_type,
134
+ mid_block_num_attention_heads=mid_block_num_attention_heads,
135
+ layers_per_block=layers_per_block,
136
+ norm_num_groups=norm_num_groups,
137
+ act_fn=act_fn,
138
+ num_attention_heads=num_attention_heads,
139
+ double_z=True,
140
+ slice_compression_vae=slice_compression_vae,
141
+ mini_batch_encoder=mini_batch_encoder,
142
+ )
143
+
144
+ self.decoder = omnigen_Mag_Decoder(
145
+ in_channels=latent_channels,
146
+ out_channels=out_channels,
147
+ up_block_types=up_block_types,
148
+ ch = ch,
149
+ ch_mult = ch_mult,
150
+ use_gc_blocks=use_gc_blocks,
151
+ mid_block_type=mid_block_type,
152
+ mid_block_use_attention=mid_block_use_attention,
153
+ mid_block_attention_type=mid_block_attention_type,
154
+ mid_block_num_attention_heads=mid_block_num_attention_heads,
155
+ layers_per_block=layers_per_block,
156
+ norm_num_groups=norm_num_groups,
157
+ act_fn=act_fn,
158
+ num_attention_heads=num_attention_heads,
159
+ slice_compression_vae=slice_compression_vae,
160
+ mini_batch_decoder=mini_batch_decoder,
161
+ )
162
+
163
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
164
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
165
+
166
+ self.mini_batch_encoder = mini_batch_encoder
167
+ self.mini_batch_decoder = mini_batch_decoder
168
+ self.train_decoder_only = train_decoder_only
169
+ if train_decoder_only:
170
+ self.encoder.requires_grad_(False)
171
+ self.quant_conv.requires_grad_(False)
172
+ if monitor is not None:
173
+ self.monitor = monitor
174
+ if ckpt_path is not None:
175
+ self.init_from_ckpt(ckpt_path, ignore_keys=["loss"])
176
+ if lossconfig is not None:
177
+ self.loss = instantiate_from_config(lossconfig)
178
+
179
+ def init_from_ckpt(self, path, ignore_keys=list()):
180
+ if path.endswith("safetensors"):
181
+ from safetensors.torch import load_file, safe_open
182
+ sd = load_file(path)
183
+ else:
184
+ sd = torch.load(path, map_location="cpu")
185
+ if "state_dict" in list(sd.keys()):
186
+ sd = sd["state_dict"]
187
+ keys = list(sd.keys())
188
+ for k in keys:
189
+ for ik in ignore_keys:
190
+ if k.startswith(ik):
191
+ print("Deleting key {} from state_dict.".format(k))
192
+ del sd[k]
193
+ self.load_state_dict(sd, strict=False) # loss.item can be ignored successfully
194
+ print(f"Restored from {path}")
195
+
196
+ def encode(self, x: torch.Tensor) -> EncoderOutput:
197
+ h = self.encoder(x)
198
+
199
+ moments: torch.Tensor = self.quant_conv(h)
200
+ mean, logvar = moments.chunk(2, dim=1)
201
+ posterior = DiagonalGaussianDistribution(mean, logvar)
202
+
203
+ # return EncoderOutput(latent_dist=posterior)
204
+ return posterior
205
+
206
+ def decode(self, z: torch.Tensor) -> DecoderOutput:
207
+ z = self.post_quant_conv(z)
208
+
209
+ decoded = self.decoder(z)
210
+
211
+ # return DecoderOutput(sample=decoded)
212
+ return decoded
213
+
214
+
215
+ def forward(self, input, sample_posterior=True):
216
+ if input.ndim==4:
217
+ input = input.unsqueeze(2)
218
+ posterior = self.encode(input)
219
+ if sample_posterior:
220
+ z = posterior.sample()
221
+ else:
222
+ z = posterior.mode()
223
+ # print("stt latent shape", z.shape)
224
+ dec = self.decode(z)
225
+ return dec, posterior
226
+
227
+ def get_input(self, batch, k):
228
+ x = batch[k]
229
+ if x.ndim==5:
230
+ x = x.permute(0, 4, 1, 2, 3).to(memory_format=torch.contiguous_format).float()
231
+ return x
232
+ if len(x.shape) == 3:
233
+ x = x[..., None]
234
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
235
+ return x
236
+
237
+ def training_step(self, batch, batch_idx, optimizer_idx):
238
+ # tic = time.time()
239
+ inputs = self.get_input(batch, self.image_key)
240
+ # print(f"get_input time {time.time() - tic}")
241
+ # tic = time.time()
242
+ reconstructions, posterior = self(inputs)
243
+ # print(f"model forward time {time.time() - tic}")
244
+
245
+ if optimizer_idx == 0:
246
+ # train encoder+decoder+logvar
247
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
248
+ last_layer=self.get_last_layer(), split="train")
249
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
250
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
251
+ # print(f"cal loss time {time.time() - tic}")
252
+ return aeloss
253
+
254
+ if optimizer_idx == 1:
255
+ # train the discriminator
256
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
257
+ last_layer=self.get_last_layer(), split="train")
258
+
259
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
260
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
261
+ # print(f"cal loss time {time.time() - tic}")
262
+ return discloss
263
+
264
+ def validation_step(self, batch, batch_idx):
265
+ with torch.no_grad():
266
+ inputs = self.get_input(batch, self.image_key)
267
+ reconstructions, posterior = self(inputs)
268
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
269
+ last_layer=self.get_last_layer(), split="val")
270
+
271
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
272
+ last_layer=self.get_last_layer(), split="val")
273
+
274
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
275
+ self.log_dict(log_dict_ae)
276
+ self.log_dict(log_dict_disc)
277
+ return self.log_dict
278
+
279
+ def configure_optimizers(self):
280
+ lr = self.learning_rate
281
+ if self.train_decoder_only:
282
+ opt_ae = torch.optim.Adam(list(self.decoder.parameters())+
283
+ list(self.post_quant_conv.parameters()),
284
+ lr=lr, betas=(0.5, 0.9))
285
+ else:
286
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
287
+ list(self.decoder.parameters())+
288
+ list(self.quant_conv.parameters())+
289
+ list(self.post_quant_conv.parameters()),
290
+ lr=lr, betas=(0.5, 0.9))
291
+ opt_disc = torch.optim.Adam(list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
292
+ lr=lr, betas=(0.5, 0.9))
293
+ return [opt_ae, opt_disc], []
294
+
295
+ def get_last_layer(self):
296
+ return self.decoder.conv_out.weight
297
+
298
+ @torch.no_grad()
299
+ def log_images(self, batch, only_inputs=False, **kwargs):
300
+ log = dict()
301
+ x = self.get_input(batch, self.image_key)
302
+ x = x.to(self.device)
303
+ if not only_inputs:
304
+ xrec, posterior = self(x)
305
+ if x.shape[1] > 3:
306
+ # colorize with random projection
307
+ assert xrec.shape[1] > 3
308
+ x = self.to_rgb(x)
309
+ xrec = self.to_rgb(xrec)
310
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
311
+ log["reconstructions"] = xrec
312
+ log["inputs"] = x
313
+ return log
314
+
315
+ def to_rgb(self, x):
316
+ assert self.image_key == "segmentation"
317
+ if not hasattr(self, "colorize"):
318
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
319
+ x = F.conv2d(x, weight=self.colorize)
320
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
321
+ return x
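
For reference (editorial note, not part of the commit): `DiagonalGaussianDistribution.kl()` above is the closed-form KL divergence between a diagonal Gaussian and the standard normal, summed over every non-batch dimension. A quick numerical sketch with illustrative shapes:

import torch

# KL( N(mu, sigma^2) || N(0, 1) ) = 0.5 * sum( mu^2 + sigma^2 - 1 - log sigma^2 )
mean = torch.randn(2, 4, 1, 8, 8)
logvar = torch.randn(2, 4, 1, 8, 8).clamp(-30.0, 20.0)
var = logvar.exp()

kl = 0.5 * torch.sum(mean.pow(2) + var - 1.0 - logvar, dim=list(range(1, mean.ndim)))
print(kl.shape)   # one value per batch element: torch.Size([2])
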
easyanimate/vae/ldm/models/omnigen_enc_dec.py ADDED
@@ -0,0 +1,396 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from ..modules.vaemodules.activations import get_activation
5
+ from ..modules.vaemodules.common import CausalConv3d
6
+ from ..modules.vaemodules.down_blocks import get_down_block
7
+ from ..modules.vaemodules.mid_blocks import get_mid_block
8
+ from ..modules.vaemodules.up_blocks import get_up_block
9
+
10
+
11
+ class Encoder(nn.Module):
12
+ r"""
13
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
14
+
15
+ Args:
16
+ in_channels (`int`, *optional*, defaults to 3):
17
+ The number of input channels.
18
+ out_channels (`int`, *optional*, defaults to 8):
19
+ The number of output channels.
20
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("SpatialDownBlock3D",)`):
21
+ The types of down blocks to use.
22
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
23
+ The number of output channels for each block.
24
+ use_gc_blocks (`Tuple[bool, ...]`, *optional*, defaults to `None`):
25
+ Whether to use global context blocks for each down block.
26
+ mid_block_type (`str`, *optional*, defaults to `"MidBlock3D"`):
27
+ The type of mid block to use.
28
+ layers_per_block (`int`, *optional*, defaults to 2):
29
+ The number of layers per block.
30
+ norm_num_groups (`int`, *optional*, defaults to 32):
31
+ The number of groups for normalization.
32
+ act_fn (`str`, *optional*, defaults to `"silu"`):
33
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
34
+ num_attention_heads (`int`, *optional*, defaults to 1):
35
+ The number of attention heads to use.
36
+ double_z (`bool`, *optional*, defaults to `True`):
37
+ Whether to double the number of output channels for the last block.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ in_channels: int = 3,
43
+ out_channels: int = 8,
44
+ down_block_types = ("SpatialDownBlock3D",),
45
+ ch = 128,
46
+ ch_mult = [1,2,4,4,],
47
+ use_gc_blocks = None,
48
+ mid_block_type: str = "MidBlock3D",
49
+ mid_block_use_attention: bool = True,
50
+ mid_block_attention_type: str = "3d",
51
+ mid_block_num_attention_heads: int = 1,
52
+ layers_per_block: int = 2,
53
+ norm_num_groups: int = 32,
54
+ act_fn: str = "silu",
55
+ num_attention_heads: int = 1,
56
+ double_z: bool = True,
57
+ slice_compression_vae: bool = False,
58
+ mini_batch_encoder: int = 9,
59
+ verbose = False,
60
+ ):
61
+ super().__init__()
62
+ block_out_channels = [ch * i for i in ch_mult]
63
+ assert len(down_block_types) == len(block_out_channels), (
64
+ "Number of down block types must match number of block output channels."
65
+ )
66
+ if use_gc_blocks is not None:
67
+ assert len(use_gc_blocks) == len(down_block_types), (
68
+ "Number of GC blocks must match number of down block types."
69
+ )
70
+ else:
71
+ use_gc_blocks = [False] * len(down_block_types)
72
+ self.conv_in = CausalConv3d(
73
+ in_channels,
74
+ block_out_channels[0],
75
+ kernel_size=3,
76
+ )
77
+
78
+ self.down_blocks = nn.ModuleList([])
79
+
80
+ output_channels = block_out_channels[0]
81
+ for i, down_block_type in enumerate(down_block_types):
82
+ input_channels = output_channels
83
+ output_channels = block_out_channels[i]
84
+ is_final_block = (i == len(block_out_channels) - 1)
85
+ down_block = get_down_block(
86
+ down_block_type,
87
+ in_channels=input_channels,
88
+ out_channels=output_channels,
89
+ num_layers=layers_per_block,
90
+ act_fn=act_fn,
91
+ norm_num_groups=norm_num_groups,
92
+ norm_eps=1e-6,
93
+ num_attention_heads=num_attention_heads,
94
+ add_gc_block=use_gc_blocks[i],
95
+ add_downsample=not is_final_block,
96
+ )
97
+ self.down_blocks.append(down_block)
98
+
99
+ self.mid_block = get_mid_block(
100
+ mid_block_type,
101
+ in_channels=block_out_channels[-1],
102
+ num_layers=layers_per_block,
103
+ act_fn=act_fn,
104
+ norm_num_groups=norm_num_groups,
105
+ norm_eps=1e-6,
106
+ add_attention=mid_block_use_attention,
107
+ attention_type=mid_block_attention_type,
108
+ num_attention_heads=mid_block_num_attention_heads,
109
+ )
110
+
111
+ self.conv_norm_out = nn.GroupNorm(
112
+ num_channels=block_out_channels[-1],
113
+ num_groups=norm_num_groups,
114
+ eps=1e-6,
115
+ )
116
+ self.conv_act = get_activation(act_fn)
117
+
118
+ conv_out_channels = 2 * out_channels if double_z else out_channels
119
+ self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
120
+
121
+ self.slice_compression_vae = slice_compression_vae
122
+ self.mini_batch_encoder = mini_batch_encoder
123
+ self.features_share = False
124
+ self.verbose = verbose
125
+
126
+ def set_padding_one_frame(self):
127
+ def _set_padding_one_frame(name, module):
128
+ if hasattr(module, 'padding_flag'):
129
+ if self.verbose:
130
+ print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
131
+ module.padding_flag = 1
132
+ for sub_name, sub_mod in module.named_children():
133
+ _set_padding_one_frame(sub_name, sub_mod)
134
+ for name, module in self.named_children():
135
+ _set_padding_one_frame(name, module)
136
+
137
+ def set_padding_more_frame(self):
138
+ def _set_padding_more_frame(name, module):
139
+ if hasattr(module, 'padding_flag'):
140
+ if self.verbose:
141
+ print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
142
+ module.padding_flag = 2
143
+ for sub_name, sub_mod in module.named_children():
144
+ _set_padding_more_frame(sub_name, sub_mod)
145
+ for name, module in self.named_children():
146
+ _set_padding_more_frame(name, module)
147
+
148
+ def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
149
+ # x: (B, C, T, H, W)
150
+ if self.features_share and previous_features is not None and after_features is None:
151
+ x = torch.concat([previous_features, x], 2)
152
+ elif self.features_share and previous_features is None and after_features is not None:
153
+ x = torch.concat([x, after_features], 2)
154
+ elif self.features_share and previous_features is not None and after_features is not None:
155
+ x = torch.concat([previous_features, x, after_features], 2)
156
+
157
+ x = self.conv_in(x)
158
+
159
+ for down_block in self.down_blocks:
160
+ x = down_block(x)
161
+
162
+ x = self.mid_block(x)
163
+
164
+ x = self.conv_norm_out(x)
165
+ x = self.conv_act(x)
166
+ x = self.conv_out(x)
167
+
168
+ if self.features_share and previous_features is not None and after_features is None:
169
+ x = x[:, :, 1:]
170
+ elif self.features_share and previous_features is None and after_features is not None:
171
+ x = x[:, :, :2]
172
+ elif self.features_share and previous_features is not None and after_features is not None:
173
+ x = x[:, :, 1:3]
174
+ return x
175
+
176
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
177
+ if self.slice_compression_vae:
178
+ _, _, f, _, _ = x.size()
179
+ if f % 2 != 0:
180
+ self.set_padding_one_frame()
181
+ first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
182
+ self.set_padding_more_frame()
183
+
184
+ new_pixel_values = [first_frames]
185
+ start_index = 1
186
+ else:
187
+ self.set_padding_more_frame()
188
+ new_pixel_values = []
189
+ start_index = 0
190
+
191
+ previous_features = None
192
+ for i in range(start_index, x.shape[2], self.mini_batch_encoder):
193
+ after_features = x[:, :, i + self.mini_batch_encoder: i + self.mini_batch_encoder + 4, :, :] if i + self.mini_batch_encoder < x.shape[2] else None
194
+ next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], previous_features, after_features)
195
+ previous_features = x[:, :, i + self.mini_batch_encoder - 4: i + self.mini_batch_encoder, :, :]
196
+ new_pixel_values.append(next_frames)
197
+ new_pixel_values = torch.cat(new_pixel_values, dim=2)
198
+ else:
199
+ new_pixel_values = self.single_forward(x, None, None)
200
+ return new_pixel_values
201
+
202
+ class Decoder(nn.Module):
203
+ r"""
204
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
205
+
206
+ Args:
207
+ in_channels (`int`, *optional*, defaults to 8):
208
+ The number of input channels.
209
+ out_channels (`int`, *optional*, defaults to 3):
210
+ The number of output channels.
211
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("SpatialUpBlock3D",)`):
212
+ The types of up blocks to use.
213
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
214
+ The number of output channels for each block.
215
+ use_gc_blocks (`Tuple[bool, ...]`, *optional*, defaults to `None`):
216
+ Whether to use global context blocks for each down block.
217
+ mid_block_type (`str`, *optional*, defaults to `"MidBlock3D"`):
218
+ The type of mid block to use.
219
+ layers_per_block (`int`, *optional*, defaults to 2):
220
+ The number of layers per block.
221
+ norm_num_groups (`int`, *optional*, defaults to 32):
222
+ The number of groups for normalization.
223
+ act_fn (`str`, *optional*, defaults to `"silu"`):
224
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
225
+ num_attention_heads (`int`, *optional*, defaults to 1):
226
+ The number of attention heads to use.
227
+ """
228
+
229
+ def __init__(
230
+ self,
231
+ in_channels: int = 8,
232
+ out_channels: int = 3,
233
+ up_block_types = ("SpatialUpBlock3D",),
234
+ ch = 128,
235
+ ch_mult = [1,2,4,4,],
236
+ use_gc_blocks = None,
237
+ mid_block_type: str = "MidBlock3D",
238
+ mid_block_use_attention: bool = True,
239
+ mid_block_attention_type: str = "3d",
240
+ mid_block_num_attention_heads: int = 1,
241
+ layers_per_block: int = 2,
242
+ norm_num_groups: int = 32,
243
+ act_fn: str = "silu",
244
+ num_attention_heads: int = 1,
245
+ slice_compression_vae: bool = False,
246
+ mini_batch_decoder: int = 3,
247
+ verbose = False,
248
+ ):
249
+ super().__init__()
250
+ block_out_channels = [ch * i for i in ch_mult]
251
+ assert len(up_block_types) == len(block_out_channels), (
252
+ "Number of up block types must match number of block output channels."
253
+ )
254
+ if use_gc_blocks is not None:
255
+ assert len(use_gc_blocks) == len(up_block_types), (
256
+ "Number of GC blocks must match number of up block types."
257
+ )
258
+ else:
259
+ use_gc_blocks = [False] * len(up_block_types)
260
+
261
+ self.conv_in = CausalConv3d(
262
+ in_channels,
263
+ block_out_channels[-1],
264
+ kernel_size=3,
265
+ )
266
+
267
+ self.mid_block = get_mid_block(
268
+ mid_block_type,
269
+ in_channels=block_out_channels[-1],
270
+ num_layers=layers_per_block,
271
+ act_fn=act_fn,
272
+ norm_num_groups=norm_num_groups,
273
+ norm_eps=1e-6,
274
+ add_attention=mid_block_use_attention,
275
+ attention_type=mid_block_attention_type,
276
+ num_attention_heads=mid_block_num_attention_heads,
277
+ )
278
+
279
+ self.up_blocks = nn.ModuleList([])
280
+
281
+ reversed_block_out_channels = list(reversed(block_out_channels))
282
+ output_channels = reversed_block_out_channels[0]
283
+ for i, up_block_type in enumerate(up_block_types):
284
+ input_channels = output_channels
285
+ output_channels = reversed_block_out_channels[i]
286
+ # is_first_block = i == 0
287
+ is_final_block = i == len(block_out_channels) - 1
288
+
289
+ up_block = get_up_block(
290
+ up_block_type,
291
+ in_channels=input_channels,
292
+ out_channels=output_channels,
293
+ num_layers=layers_per_block + 1,
294
+ act_fn=act_fn,
295
+ norm_num_groups=norm_num_groups,
296
+ norm_eps=1e-6,
297
+ num_attention_heads=num_attention_heads,
298
+ add_gc_block=use_gc_blocks[i],
299
+ add_upsample=not is_final_block,
300
+ )
301
+ self.up_blocks.append(up_block)
302
+
303
+ self.conv_norm_out = nn.GroupNorm(
304
+ num_channels=block_out_channels[0],
305
+ num_groups=norm_num_groups,
306
+ eps=1e-6,
307
+ )
308
+ self.conv_act = get_activation(act_fn)
309
+
310
+ self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
311
+
312
+ self.slice_compression_vae = slice_compression_vae
313
+ self.mini_batch_decoder = mini_batch_decoder
314
+ self.features_share = True
315
+ self.verbose = verbose
316
+
317
+ def set_padding_one_frame(self):
318
+ def _set_padding_one_frame(name, module):
319
+ if hasattr(module, 'padding_flag'):
320
+ if self.verbose:
321
+ print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
322
+ module.padding_flag = 1
323
+ for sub_name, sub_mod in module.named_children():
324
+ _set_padding_one_frame(sub_name, sub_mod)
325
+ for name, module in self.named_children():
326
+ _set_padding_one_frame(name, module)
327
+
328
+ def set_padding_more_frame(self):
329
+ def _set_padding_more_frame(name, module):
330
+ if hasattr(module, 'padding_flag'):
331
+ if self.verbose:
332
+ print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
333
+ module.padding_flag = 2
334
+ for sub_name, sub_mod in module.named_children():
335
+ _set_padding_more_frame(sub_name, sub_mod)
336
+ for name, module in self.named_children():
337
+ _set_padding_more_frame(name, module)
338
+
339
+ def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
340
+ # x: (B, C, T, H, W)
341
+ if self.features_share and previous_features is not None and after_features is None:
342
+ b, c, t, h, w = x.size()
343
+ x = torch.concat([previous_features, x], 2)
344
+ x = self.conv_in(x)
345
+ x = self.mid_block(x)
346
+ x = x[:, :, -t:]
347
+ elif self.features_share and previous_features is None and after_features is not None:
348
+ b, c, t, h, w = x.size()
349
+ x = torch.concat([x, after_features], 2)
350
+ x = self.conv_in(x)
351
+ x = self.mid_block(x)
352
+ x = x[:, :, :t]
353
+ elif self.features_share and previous_features is not None and after_features is not None:
354
+ _, _, t_1, _, _ = previous_features.size()
355
+ _, _, t_2, _, _ = x.size()
356
+ x = torch.concat([previous_features, x, after_features], 2)
357
+ x = self.conv_in(x)
358
+ x = self.mid_block(x)
359
+ x = x[:, :, t_1:(t_1 + t_2)]
360
+ else:
361
+ x = self.conv_in(x)
362
+ x = self.mid_block(x)
363
+
364
+ for up_block in self.up_blocks:
365
+ x = up_block(x)
366
+
367
+ x = self.conv_norm_out(x)
368
+ x = self.conv_act(x)
369
+ x = self.conv_out(x)
370
+
371
+ return x
372
+
373
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
374
+ if self.slice_compression_vae:
375
+ _, _, f, _, _ = x.size()
376
+ if f % 2 != 0:
377
+ self.set_padding_one_frame()
378
+ first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
379
+ self.set_padding_more_frame()
380
+ new_pixel_values = [first_frames]
381
+ start_index = 1
382
+ else:
383
+ self.set_padding_more_frame()
384
+ new_pixel_values = []
385
+ start_index = 0
386
+
387
+ previous_features = None
388
+ for i in range(start_index, x.shape[2], self.mini_batch_decoder):
389
+ after_features = x[:, :, i + self.mini_batch_decoder: i + 2 * self.mini_batch_decoder, :, :] if i + self.mini_batch_decoder < x.shape[2] else None
390
+ next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], previous_features, after_features)
391
+ previous_features = x[:, :, i: i + self.mini_batch_decoder, :, :]
392
+ new_pixel_values.append(next_frames)
393
+ new_pixel_values = torch.cat(new_pixel_values, dim=2)
394
+ else:
395
+ new_pixel_values = self.single_forward(x, None, None)
396
+ return new_pixel_values
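
Editorial sketch (not part of the commit): when `slice_compression_vae` is enabled, both `forward()` methods above stream long clips through the network in fixed-size temporal chunks, handling an odd frame count by pushing frame 0 through on its own first. The chunking pattern, with illustrative sizes:

import torch

def chunk_time(x, mini_batch):
    # Mirrors the slicing in Encoder.forward / Decoder.forward above (context frames omitted).
    chunks, start = [], 0
    if x.shape[2] % 2 != 0:              # odd frame count: the first frame goes through alone
        chunks.append(x[:, :, 0:1])
        start = 1
    for i in range(start, x.shape[2], mini_batch):
        chunks.append(x[:, :, i:i + mini_batch])
    return chunks

x = torch.randn(1, 3, 17, 32, 32)        # 17 frames
parts = chunk_time(x, mini_batch=8)
print([p.shape[2] for p in parts])        # [1, 8, 8]
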
easyanimate/video_caption/datasets/put preprocess datasets here.txt ADDED
File without changes