origordon committed on
Commit
1f21780
·
1 Parent(s): 6a9d9a1

VAE Decoder: Inject noise between conv layers.

Browse files

1. Add inject_noise flag to res_x, res_x_y blocks.
2. Init the per-channel noise scales to zero in the ResnetBlock3D constructor.
3. Add _feed_spatial_noise method to inject noise between conv layers.

xora/models/autoencoders/causal_video_autoencoder.py CHANGED
@@ -481,6 +481,7 @@ class Decoder(nn.Module):
481
  resnet_eps=1e-6,
482
  resnet_groups=norm_num_groups,
483
  norm_layer=norm_layer,
 
484
  )
485
  elif block_name == "res_x_y":
486
  output_channel = output_channel // block_params.get("multiplier", 2)
@@ -491,6 +492,7 @@ class Decoder(nn.Module):
491
  eps=1e-6,
492
  groups=norm_num_groups,
493
  norm_layer=norm_layer,
 
494
  )
495
  elif block_name == "compress_time":
496
  block = DepthToSpaceUpsample(
@@ -583,6 +585,7 @@ class UNetMidBlock3D(nn.Module):
583
  resnet_eps: float = 1e-6,
584
  resnet_groups: int = 32,
585
  norm_layer: str = "group_norm",
 
586
  ):
587
  super().__init__()
588
  resnet_groups = (
@@ -599,6 +602,7 @@ class UNetMidBlock3D(nn.Module):
599
  groups=resnet_groups,
600
  dropout=dropout,
601
  norm_layer=norm_layer,
 
602
  )
603
  for _ in range(num_layers)
604
  ]
@@ -690,11 +694,13 @@ class ResnetBlock3D(nn.Module):
690
  groups: int = 32,
691
  eps: float = 1e-6,
692
  norm_layer: str = "group_norm",
 
693
  ):
694
  super().__init__()
695
  self.in_channels = in_channels
696
  out_channels = in_channels if out_channels is None else out_channels
697
  self.out_channels = out_channels
 
698
 
699
  if norm_layer == "group_norm":
700
  self.norm1 = nn.GroupNorm(
@@ -717,6 +723,9 @@ class ResnetBlock3D(nn.Module):
717
  causal=True,
718
  )
719
 
 
 
 
720
  if norm_layer == "group_norm":
721
  self.norm2 = nn.GroupNorm(
722
  num_groups=groups, num_channels=out_channels, eps=eps, affine=True
@@ -738,6 +747,9 @@ class ResnetBlock3D(nn.Module):
738
  causal=True,
739
  )
740
 
 
 
 
741
  self.conv_shortcut = (
742
  make_linear_nd(
743
  dims=dims, in_channels=in_channels, out_channels=out_channels
@@ -752,6 +764,20 @@ class ResnetBlock3D(nn.Module):
752
  else nn.Identity()
753
  )
754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755
  def forward(
756
  self,
757
  input_tensor: torch.FloatTensor,
@@ -765,6 +791,11 @@ class ResnetBlock3D(nn.Module):
765
 
766
  hidden_states = self.conv1(hidden_states, causal=causal)
767
 
 
 
 
 
 
768
  hidden_states = self.norm2(hidden_states)
769
 
770
  hidden_states = self.non_linearity(hidden_states)
@@ -773,6 +804,11 @@ class ResnetBlock3D(nn.Module):
773
 
774
  hidden_states = self.conv2(hidden_states, causal=causal)
775
 
 
 
 
 
 
776
  input_tensor = self.norm3(input_tensor)
777
 
778
  input_tensor = self.conv_shortcut(input_tensor)
 
481
  resnet_eps=1e-6,
482
  resnet_groups=norm_num_groups,
483
  norm_layer=norm_layer,
484
+ inject_noise=block_params.get("inject_noise", False),
485
  )
486
  elif block_name == "res_x_y":
487
  output_channel = output_channel // block_params.get("multiplier", 2)
 
492
  eps=1e-6,
493
  groups=norm_num_groups,
494
  norm_layer=norm_layer,
495
+ inject_noise=block_params.get("inject_noise", False),
496
  )
497
  elif block_name == "compress_time":
498
  block = DepthToSpaceUpsample(
 
585
  resnet_eps: float = 1e-6,
586
  resnet_groups: int = 32,
587
  norm_layer: str = "group_norm",
588
+ inject_noise: bool = False,
589
  ):
590
  super().__init__()
591
  resnet_groups = (
 
602
  groups=resnet_groups,
603
  dropout=dropout,
604
  norm_layer=norm_layer,
605
+ inject_noise=inject_noise,
606
  )
607
  for _ in range(num_layers)
608
  ]
 
694
  groups: int = 32,
695
  eps: float = 1e-6,
696
  norm_layer: str = "group_norm",
697
+ inject_noise: bool = False,
698
  ):
699
  super().__init__()
700
  self.in_channels = in_channels
701
  out_channels = in_channels if out_channels is None else out_channels
702
  self.out_channels = out_channels
703
+ self.inject_noise = inject_noise
704
 
705
  if norm_layer == "group_norm":
706
  self.norm1 = nn.GroupNorm(
 
723
  causal=True,
724
  )
725
 
726
+ if inject_noise:
727
+ self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
728
+
729
  if norm_layer == "group_norm":
730
  self.norm2 = nn.GroupNorm(
731
  num_groups=groups, num_channels=out_channels, eps=eps, affine=True
 
747
  causal=True,
748
  )
749
 
750
+ if inject_noise:
751
+ self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
752
+
753
  self.conv_shortcut = (
754
  make_linear_nd(
755
  dims=dims, in_channels=in_channels, out_channels=out_channels
 
764
  else nn.Identity()
765
  )
766
 
767
+ def _feed_spatial_noise(
768
+ self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
769
+ ) -> torch.FloatTensor:
770
+ spatial_shape = hidden_states.shape[-2:]
771
+ device = hidden_states.device
772
+ dtype = hidden_states.dtype
773
+
774
+ # similar to the "explicit noise inputs" method in style-gan
775
+ spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
776
+ scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
777
+ hidden_states = hidden_states + scaled_noise
778
+
779
+ return hidden_states
780
+
781
  def forward(
782
  self,
783
  input_tensor: torch.FloatTensor,
 
791
 
792
  hidden_states = self.conv1(hidden_states, causal=causal)
793
 
794
+ if self.inject_noise:
795
+ hidden_states = self._feed_spatial_noise(
796
+ hidden_states, self.per_channel_scale1
797
+ )
798
+
799
  hidden_states = self.norm2(hidden_states)
800
 
801
  hidden_states = self.non_linearity(hidden_states)
 
804
 
805
  hidden_states = self.conv2(hidden_states, causal=causal)
806
 
807
+ if self.inject_noise:
808
+ hidden_states = self._feed_spatial_noise(
809
+ hidden_states, self.per_channel_scale2
810
+ )
811
+
812
  input_tensor = self.norm3(input_tensor)
813
 
814
  input_tensor = self.conv_shortcut(input_tensor)