VAE Decoder: Inject noise between conv layers.
Browse files
1. Add inject_noise flag to res_x, res_x_y blocks.
2. Initialize the per-channel noise scales to zero in the ResnetBlock3D constructor.
3. Add _feed_spatial_noise method to inject noise between conv layers.
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
@@ -481,6 +481,7 @@ class Decoder(nn.Module):
|
|
481 |
resnet_eps=1e-6,
|
482 |
resnet_groups=norm_num_groups,
|
483 |
norm_layer=norm_layer,
|
|
|
484 |
)
|
485 |
elif block_name == "res_x_y":
|
486 |
output_channel = output_channel // block_params.get("multiplier", 2)
|
@@ -491,6 +492,7 @@ class Decoder(nn.Module):
|
|
491 |
eps=1e-6,
|
492 |
groups=norm_num_groups,
|
493 |
norm_layer=norm_layer,
|
|
|
494 |
)
|
495 |
elif block_name == "compress_time":
|
496 |
block = DepthToSpaceUpsample(
|
@@ -583,6 +585,7 @@ class UNetMidBlock3D(nn.Module):
|
|
583 |
resnet_eps: float = 1e-6,
|
584 |
resnet_groups: int = 32,
|
585 |
norm_layer: str = "group_norm",
|
|
|
586 |
):
|
587 |
super().__init__()
|
588 |
resnet_groups = (
|
@@ -599,6 +602,7 @@ class UNetMidBlock3D(nn.Module):
|
|
599 |
groups=resnet_groups,
|
600 |
dropout=dropout,
|
601 |
norm_layer=norm_layer,
|
|
|
602 |
)
|
603 |
for _ in range(num_layers)
|
604 |
]
|
@@ -690,11 +694,13 @@ class ResnetBlock3D(nn.Module):
|
|
690 |
groups: int = 32,
|
691 |
eps: float = 1e-6,
|
692 |
norm_layer: str = "group_norm",
|
|
|
693 |
):
|
694 |
super().__init__()
|
695 |
self.in_channels = in_channels
|
696 |
out_channels = in_channels if out_channels is None else out_channels
|
697 |
self.out_channels = out_channels
|
|
|
698 |
|
699 |
if norm_layer == "group_norm":
|
700 |
self.norm1 = nn.GroupNorm(
|
@@ -717,6 +723,9 @@ class ResnetBlock3D(nn.Module):
|
|
717 |
causal=True,
|
718 |
)
|
719 |
|
|
|
|
|
|
|
720 |
if norm_layer == "group_norm":
|
721 |
self.norm2 = nn.GroupNorm(
|
722 |
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
@@ -738,6 +747,9 @@ class ResnetBlock3D(nn.Module):
|
|
738 |
causal=True,
|
739 |
)
|
740 |
|
|
|
|
|
|
|
741 |
self.conv_shortcut = (
|
742 |
make_linear_nd(
|
743 |
dims=dims, in_channels=in_channels, out_channels=out_channels
|
@@ -752,6 +764,20 @@ class ResnetBlock3D(nn.Module):
|
|
752 |
else nn.Identity()
|
753 |
)
|
754 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
755 |
def forward(
|
756 |
self,
|
757 |
input_tensor: torch.FloatTensor,
|
@@ -765,6 +791,11 @@ class ResnetBlock3D(nn.Module):
|
|
765 |
|
766 |
hidden_states = self.conv1(hidden_states, causal=causal)
|
767 |
|
|
|
|
|
|
|
|
|
|
|
768 |
hidden_states = self.norm2(hidden_states)
|
769 |
|
770 |
hidden_states = self.non_linearity(hidden_states)
|
@@ -773,6 +804,11 @@ class ResnetBlock3D(nn.Module):
|
|
773 |
|
774 |
hidden_states = self.conv2(hidden_states, causal=causal)
|
775 |
|
|
|
|
|
|
|
|
|
|
|
776 |
input_tensor = self.norm3(input_tensor)
|
777 |
|
778 |
input_tensor = self.conv_shortcut(input_tensor)
|
|
|
481 |
resnet_eps=1e-6,
|
482 |
resnet_groups=norm_num_groups,
|
483 |
norm_layer=norm_layer,
|
484 |
+
inject_noise=block_params.get("inject_noise", False),
|
485 |
)
|
486 |
elif block_name == "res_x_y":
|
487 |
output_channel = output_channel // block_params.get("multiplier", 2)
|
|
|
492 |
eps=1e-6,
|
493 |
groups=norm_num_groups,
|
494 |
norm_layer=norm_layer,
|
495 |
+
inject_noise=block_params.get("inject_noise", False),
|
496 |
)
|
497 |
elif block_name == "compress_time":
|
498 |
block = DepthToSpaceUpsample(
|
|
|
585 |
resnet_eps: float = 1e-6,
|
586 |
resnet_groups: int = 32,
|
587 |
norm_layer: str = "group_norm",
|
588 |
+
inject_noise: bool = False,
|
589 |
):
|
590 |
super().__init__()
|
591 |
resnet_groups = (
|
|
|
602 |
groups=resnet_groups,
|
603 |
dropout=dropout,
|
604 |
norm_layer=norm_layer,
|
605 |
+
inject_noise=inject_noise,
|
606 |
)
|
607 |
for _ in range(num_layers)
|
608 |
]
|
|
|
694 |
groups: int = 32,
|
695 |
eps: float = 1e-6,
|
696 |
norm_layer: str = "group_norm",
|
697 |
+
inject_noise: bool = False,
|
698 |
):
|
699 |
super().__init__()
|
700 |
self.in_channels = in_channels
|
701 |
out_channels = in_channels if out_channels is None else out_channels
|
702 |
self.out_channels = out_channels
|
703 |
+
self.inject_noise = inject_noise
|
704 |
|
705 |
if norm_layer == "group_norm":
|
706 |
self.norm1 = nn.GroupNorm(
|
|
|
723 |
causal=True,
|
724 |
)
|
725 |
|
726 |
+
if inject_noise:
|
727 |
+
self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
728 |
+
|
729 |
if norm_layer == "group_norm":
|
730 |
self.norm2 = nn.GroupNorm(
|
731 |
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
|
|
747 |
causal=True,
|
748 |
)
|
749 |
|
750 |
+
if inject_noise:
|
751 |
+
self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
752 |
+
|
753 |
self.conv_shortcut = (
|
754 |
make_linear_nd(
|
755 |
dims=dims, in_channels=in_channels, out_channels=out_channels
|
|
|
764 |
else nn.Identity()
|
765 |
)
|
766 |
|
767 |
+
def _feed_spatial_noise(
|
768 |
+
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
|
769 |
+
) -> torch.FloatTensor:
|
770 |
+
spatial_shape = hidden_states.shape[-2:]
|
771 |
+
device = hidden_states.device
|
772 |
+
dtype = hidden_states.dtype
|
773 |
+
|
774 |
+
# similar to the "explicit noise inputs" method in style-gan
|
775 |
+
spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
|
776 |
+
scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
|
777 |
+
hidden_states = hidden_states + scaled_noise
|
778 |
+
|
779 |
+
return hidden_states
|
780 |
+
|
781 |
def forward(
|
782 |
self,
|
783 |
input_tensor: torch.FloatTensor,
|
|
|
791 |
|
792 |
hidden_states = self.conv1(hidden_states, causal=causal)
|
793 |
|
794 |
+
if self.inject_noise:
|
795 |
+
hidden_states = self._feed_spatial_noise(
|
796 |
+
hidden_states, self.per_channel_scale1
|
797 |
+
)
|
798 |
+
|
799 |
hidden_states = self.norm2(hidden_states)
|
800 |
|
801 |
hidden_states = self.non_linearity(hidden_states)
|
|
|
804 |
|
805 |
hidden_states = self.conv2(hidden_states, causal=causal)
|
806 |
|
807 |
+
if self.inject_noise:
|
808 |
+
hidden_states = self._feed_spatial_noise(
|
809 |
+
hidden_states, self.per_channel_scale2
|
810 |
+
)
|
811 |
+
|
812 |
input_tensor = self.norm3(input_tensor)
|
813 |
|
814 |
input_tensor = self.conv_shortcut(input_tensor)
|