Spaces:
Paused
Paused
Merge pull request #12 from LightricksResearch/cvae-arch-refactoring
Browse files
xora/models/autoencoders/causal_conv3d.py
CHANGED
@@ -11,6 +11,8 @@ class CausalConv3d(nn.Module):
|
|
11 |
out_channels,
|
12 |
kernel_size: int = 3,
|
13 |
stride: Union[int, Tuple[int]] = 1,
|
|
|
|
|
14 |
**kwargs,
|
15 |
):
|
16 |
super().__init__()
|
@@ -21,7 +23,6 @@ class CausalConv3d(nn.Module):
|
|
21 |
kernel_size = (kernel_size, kernel_size, kernel_size)
|
22 |
self.time_kernel_size = kernel_size[0]
|
23 |
|
24 |
-
dilation = kwargs.pop("dilation", 1)
|
25 |
dilation = (dilation, 1, 1)
|
26 |
|
27 |
height_pad = kernel_size[1] // 2
|
@@ -36,6 +37,7 @@ class CausalConv3d(nn.Module):
|
|
36 |
dilation=dilation,
|
37 |
padding=padding,
|
38 |
padding_mode="zeros",
|
|
|
39 |
)
|
40 |
|
41 |
def forward(self, x, causal: bool = True):
|
|
|
11 |
out_channels,
|
12 |
kernel_size: int = 3,
|
13 |
stride: Union[int, Tuple[int]] = 1,
|
14 |
+
dilation: int = 1,
|
15 |
+
groups: int = 1,
|
16 |
**kwargs,
|
17 |
):
|
18 |
super().__init__()
|
|
|
23 |
kernel_size = (kernel_size, kernel_size, kernel_size)
|
24 |
self.time_kernel_size = kernel_size[0]
|
25 |
|
|
|
26 |
dilation = (dilation, 1, 1)
|
27 |
|
28 |
height_pad = kernel_size[1] // 2
|
|
|
37 |
dilation=dilation,
|
38 |
padding=padding,
|
39 |
padding_mode="zeros",
|
40 |
+
groups=groups,
|
41 |
)
|
42 |
|
43 |
def forward(self, x, causal: bool = True):
|
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
@@ -78,7 +78,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
78 |
dims=config["dims"],
|
79 |
in_channels=config.get("in_channels", 3),
|
80 |
out_channels=config["latent_channels"],
|
81 |
-
blocks=config
|
82 |
patch_size=config.get("patch_size", 1),
|
83 |
latent_log_var=latent_log_var,
|
84 |
norm_layer=config.get("norm_layer", "group_norm"),
|
@@ -88,7 +88,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
88 |
dims=config["dims"],
|
89 |
in_channels=config["latent_channels"],
|
90 |
out_channels=config.get("out_channels", 3),
|
91 |
-
blocks=config
|
92 |
patch_size=config.get("patch_size", 1),
|
93 |
norm_layer=config.get("norm_layer", "group_norm"),
|
94 |
causal=config.get("causal_decoder", False),
|
@@ -112,7 +112,8 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
112 |
out_channels=self.decoder.conv_out.out_channels
|
113 |
// self.decoder.patch_size**2,
|
114 |
latent_channels=self.decoder.conv_in.in_channels,
|
115 |
-
|
|
|
116 |
scaling_factor=1.0,
|
117 |
norm_layer=self.encoder.norm_layer,
|
118 |
patch_size=self.encoder.patch_size,
|
@@ -242,7 +243,7 @@ class Encoder(nn.Module):
|
|
242 |
dims: Union[int, Tuple[int, int]] = 3,
|
243 |
in_channels: int = 3,
|
244 |
out_channels: int = 3,
|
245 |
-
blocks: List[Tuple[str, int]] = [("res_x", 1)],
|
246 |
base_channels: int = 128,
|
247 |
norm_num_groups: int = 32,
|
248 |
patch_size: Union[int, Tuple[int]] = 1,
|
@@ -271,20 +272,22 @@ class Encoder(nn.Module):
|
|
271 |
|
272 |
self.down_blocks = nn.ModuleList([])
|
273 |
|
274 |
-
for block_name,
|
275 |
input_channel = output_channel
|
|
|
|
|
276 |
|
277 |
if block_name == "res_x":
|
278 |
block = UNetMidBlock3D(
|
279 |
dims=dims,
|
280 |
in_channels=input_channel,
|
281 |
-
num_layers=num_layers,
|
282 |
resnet_eps=1e-6,
|
283 |
resnet_groups=norm_num_groups,
|
284 |
norm_layer=norm_layer,
|
285 |
)
|
286 |
elif block_name == "res_x_y":
|
287 |
-
output_channel = 2 * output_channel
|
288 |
block = ResnetBlock3D(
|
289 |
dims=dims,
|
290 |
in_channels=input_channel,
|
@@ -320,6 +323,16 @@ class Encoder(nn.Module):
|
|
320 |
stride=(2, 2, 2),
|
321 |
causal=True,
|
322 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
else:
|
324 |
raise ValueError(f"unknown block: {block_name}")
|
325 |
|
@@ -421,7 +434,7 @@ class Decoder(nn.Module):
|
|
421 |
dims,
|
422 |
in_channels: int = 3,
|
423 |
out_channels: int = 3,
|
424 |
-
blocks: List[Tuple[str, int]] = [("res_x", 1)],
|
425 |
base_channels: int = 128,
|
426 |
layers_per_block: int = 2,
|
427 |
norm_num_groups: int = 32,
|
@@ -433,9 +446,15 @@ class Decoder(nn.Module):
|
|
433 |
self.patch_size = patch_size
|
434 |
self.layers_per_block = layers_per_block
|
435 |
out_channels = out_channels * patch_size**2
|
436 |
-
num_channel_doubles = len([x for x in blocks if x[0] == "res_x_y"])
|
437 |
-
output_channel = base_channels * 2**num_channel_doubles
|
438 |
self.causal = causal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
|
440 |
self.conv_in = make_conv_nd(
|
441 |
dims,
|
@@ -449,20 +468,22 @@ class Decoder(nn.Module):
|
|
449 |
|
450 |
self.up_blocks = nn.ModuleList([])
|
451 |
|
452 |
-
for block_name,
|
453 |
input_channel = output_channel
|
|
|
|
|
454 |
|
455 |
if block_name == "res_x":
|
456 |
block = UNetMidBlock3D(
|
457 |
dims=dims,
|
458 |
in_channels=input_channel,
|
459 |
-
num_layers=num_layers,
|
460 |
resnet_eps=1e-6,
|
461 |
resnet_groups=norm_num_groups,
|
462 |
norm_layer=norm_layer,
|
463 |
)
|
464 |
elif block_name == "res_x_y":
|
465 |
-
output_channel = output_channel // 2
|
466 |
block = ResnetBlock3D(
|
467 |
dims=dims,
|
468 |
in_channels=input_channel,
|
@@ -481,7 +502,10 @@ class Decoder(nn.Module):
|
|
481 |
)
|
482 |
elif block_name == "compress_all":
|
483 |
block = DepthToSpaceUpsample(
|
484 |
-
dims=dims,
|
|
|
|
|
|
|
485 |
)
|
486 |
else:
|
487 |
raise ValueError(f"unknown layer: {block_name}")
|
@@ -590,7 +614,7 @@ class UNetMidBlock3D(nn.Module):
|
|
590 |
|
591 |
|
592 |
class DepthToSpaceUpsample(nn.Module):
|
593 |
-
def __init__(self, dims, in_channels, stride):
|
594 |
super().__init__()
|
595 |
self.stride = stride
|
596 |
self.out_channels = np.prod(stride) * in_channels
|
@@ -602,8 +626,21 @@ class DepthToSpaceUpsample(nn.Module):
|
|
602 |
stride=1,
|
603 |
causal=True,
|
604 |
)
|
|
|
605 |
|
606 |
def forward(self, x, causal: bool = True):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
607 |
x = self.conv(x, causal=causal)
|
608 |
x = rearrange(
|
609 |
x,
|
@@ -614,6 +651,8 @@ class DepthToSpaceUpsample(nn.Module):
|
|
614 |
)
|
615 |
if self.stride[0] == 2:
|
616 |
x = x[:, :, 1:, :, :]
|
|
|
|
|
617 |
return x
|
618 |
|
619 |
|
@@ -647,7 +686,6 @@ class ResnetBlock3D(nn.Module):
|
|
647 |
dims: Union[int, Tuple[int, int]],
|
648 |
in_channels: int,
|
649 |
out_channels: Optional[int] = None,
|
650 |
-
conv_shortcut: bool = False,
|
651 |
dropout: float = 0.0,
|
652 |
groups: int = 32,
|
653 |
eps: float = 1e-6,
|
@@ -657,7 +695,6 @@ class ResnetBlock3D(nn.Module):
|
|
657 |
self.in_channels = in_channels
|
658 |
out_channels = in_channels if out_channels is None else out_channels
|
659 |
self.out_channels = out_channels
|
660 |
-
self.use_conv_shortcut = conv_shortcut
|
661 |
|
662 |
if norm_layer == "group_norm":
|
663 |
self.norm1 = nn.GroupNorm(
|
|
|
78 |
dims=config["dims"],
|
79 |
in_channels=config.get("in_channels", 3),
|
80 |
out_channels=config["latent_channels"],
|
81 |
+
blocks=config.get("encoder_blocks", config.get("blocks")),
|
82 |
patch_size=config.get("patch_size", 1),
|
83 |
latent_log_var=latent_log_var,
|
84 |
norm_layer=config.get("norm_layer", "group_norm"),
|
|
|
88 |
dims=config["dims"],
|
89 |
in_channels=config["latent_channels"],
|
90 |
out_channels=config.get("out_channels", 3),
|
91 |
+
blocks=config.get("decoder_blocks", config.get("blocks")),
|
92 |
patch_size=config.get("patch_size", 1),
|
93 |
norm_layer=config.get("norm_layer", "group_norm"),
|
94 |
causal=config.get("causal_decoder", False),
|
|
|
112 |
out_channels=self.decoder.conv_out.out_channels
|
113 |
// self.decoder.patch_size**2,
|
114 |
latent_channels=self.decoder.conv_in.in_channels,
|
115 |
+
encoder_blocks=self.encoder.blocks_desc,
|
116 |
+
decoder_blocks=self.decoder.blocks_desc,
|
117 |
scaling_factor=1.0,
|
118 |
norm_layer=self.encoder.norm_layer,
|
119 |
patch_size=self.encoder.patch_size,
|
|
|
243 |
dims: Union[int, Tuple[int, int]] = 3,
|
244 |
in_channels: int = 3,
|
245 |
out_channels: int = 3,
|
246 |
+
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
|
247 |
base_channels: int = 128,
|
248 |
norm_num_groups: int = 32,
|
249 |
patch_size: Union[int, Tuple[int]] = 1,
|
|
|
272 |
|
273 |
self.down_blocks = nn.ModuleList([])
|
274 |
|
275 |
+
for block_name, block_params in blocks:
|
276 |
input_channel = output_channel
|
277 |
+
if isinstance(block_params, int):
|
278 |
+
block_params = {"num_layers": block_params}
|
279 |
|
280 |
if block_name == "res_x":
|
281 |
block = UNetMidBlock3D(
|
282 |
dims=dims,
|
283 |
in_channels=input_channel,
|
284 |
+
num_layers=block_params["num_layers"],
|
285 |
resnet_eps=1e-6,
|
286 |
resnet_groups=norm_num_groups,
|
287 |
norm_layer=norm_layer,
|
288 |
)
|
289 |
elif block_name == "res_x_y":
|
290 |
+
output_channel = block_params.get("multiplier", 2) * output_channel
|
291 |
block = ResnetBlock3D(
|
292 |
dims=dims,
|
293 |
in_channels=input_channel,
|
|
|
323 |
stride=(2, 2, 2),
|
324 |
causal=True,
|
325 |
)
|
326 |
+
elif block_name == "compress_all_x_y":
|
327 |
+
output_channel = block_params.get("multiplier", 2) * output_channel
|
328 |
+
block = make_conv_nd(
|
329 |
+
dims=dims,
|
330 |
+
in_channels=input_channel,
|
331 |
+
out_channels=output_channel,
|
332 |
+
kernel_size=3,
|
333 |
+
stride=(2, 2, 2),
|
334 |
+
causal=True,
|
335 |
+
)
|
336 |
else:
|
337 |
raise ValueError(f"unknown block: {block_name}")
|
338 |
|
|
|
434 |
dims,
|
435 |
in_channels: int = 3,
|
436 |
out_channels: int = 3,
|
437 |
+
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
|
438 |
base_channels: int = 128,
|
439 |
layers_per_block: int = 2,
|
440 |
norm_num_groups: int = 32,
|
|
|
446 |
self.patch_size = patch_size
|
447 |
self.layers_per_block = layers_per_block
|
448 |
out_channels = out_channels * patch_size**2
|
|
|
|
|
449 |
self.causal = causal
|
450 |
+
self.blocks_desc = blocks
|
451 |
+
|
452 |
+
# Compute output channel to be product of all channel-multiplier blocks
|
453 |
+
output_channel = base_channels
|
454 |
+
for block_name, block_params in list(reversed(blocks)):
|
455 |
+
block_params = block_params if isinstance(block_params, dict) else {}
|
456 |
+
if block_name == "res_x_y":
|
457 |
+
output_channel = output_channel * block_params.get("multiplier", 2)
|
458 |
|
459 |
self.conv_in = make_conv_nd(
|
460 |
dims,
|
|
|
468 |
|
469 |
self.up_blocks = nn.ModuleList([])
|
470 |
|
471 |
+
for block_name, block_params in list(reversed(blocks)):
|
472 |
input_channel = output_channel
|
473 |
+
if isinstance(block_params, int):
|
474 |
+
block_params = {"num_layers": block_params}
|
475 |
|
476 |
if block_name == "res_x":
|
477 |
block = UNetMidBlock3D(
|
478 |
dims=dims,
|
479 |
in_channels=input_channel,
|
480 |
+
num_layers=block_params["num_layers"],
|
481 |
resnet_eps=1e-6,
|
482 |
resnet_groups=norm_num_groups,
|
483 |
norm_layer=norm_layer,
|
484 |
)
|
485 |
elif block_name == "res_x_y":
|
486 |
+
output_channel = output_channel // block_params.get("multiplier", 2)
|
487 |
block = ResnetBlock3D(
|
488 |
dims=dims,
|
489 |
in_channels=input_channel,
|
|
|
502 |
)
|
503 |
elif block_name == "compress_all":
|
504 |
block = DepthToSpaceUpsample(
|
505 |
+
dims=dims,
|
506 |
+
in_channels=input_channel,
|
507 |
+
stride=(2, 2, 2),
|
508 |
+
residual=block_params.get("residual", False),
|
509 |
)
|
510 |
else:
|
511 |
raise ValueError(f"unknown layer: {block_name}")
|
|
|
614 |
|
615 |
|
616 |
class DepthToSpaceUpsample(nn.Module):
|
617 |
+
def __init__(self, dims, in_channels, stride, residual=False):
|
618 |
super().__init__()
|
619 |
self.stride = stride
|
620 |
self.out_channels = np.prod(stride) * in_channels
|
|
|
626 |
stride=1,
|
627 |
causal=True,
|
628 |
)
|
629 |
+
self.residual = residual
|
630 |
|
631 |
def forward(self, x, causal: bool = True):
|
632 |
+
if self.residual:
|
633 |
+
# Reshape and duplicate the input to match the output shape
|
634 |
+
x_in = rearrange(
|
635 |
+
x,
|
636 |
+
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
637 |
+
p1=self.stride[0],
|
638 |
+
p2=self.stride[1],
|
639 |
+
p3=self.stride[2],
|
640 |
+
)
|
641 |
+
x_in = x_in.repeat(1, np.prod(self.stride), 1, 1, 1)
|
642 |
+
if self.stride[0] == 2:
|
643 |
+
x_in = x_in[:, :, 1:, :, :]
|
644 |
x = self.conv(x, causal=causal)
|
645 |
x = rearrange(
|
646 |
x,
|
|
|
651 |
)
|
652 |
if self.stride[0] == 2:
|
653 |
x = x[:, :, 1:, :, :]
|
654 |
+
if self.residual:
|
655 |
+
x = x + x_in
|
656 |
return x
|
657 |
|
658 |
|
|
|
686 |
dims: Union[int, Tuple[int, int]],
|
687 |
in_channels: int,
|
688 |
out_channels: Optional[int] = None,
|
|
|
689 |
dropout: float = 0.0,
|
690 |
groups: int = 32,
|
691 |
eps: float = 1e-6,
|
|
|
695 |
self.in_channels = in_channels
|
696 |
out_channels = in_channels if out_channels is None else out_channels
|
697 |
self.out_channels = out_channels
|
|
|
698 |
|
699 |
if norm_layer == "group_norm":
|
700 |
self.norm1 = nn.GroupNorm(
|