guysrn committed on
Commit bd9f5fd
1 Parent(s): aacf3ac

vae: fix attention blocks and timestep conditioning

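For context on what the fix does: the standalone `AttentionResBlocks` decoder block is removed, `UNetMidBlock3D` gains an optional `attention_head_dim` argument that enables its attention path, and plain `res_x` blocks now receive `timestep_conditioning`. Below is a minimal sketch of how an `attn_res_x` block is constructed after this change; the import path follows the file touched in this commit, while the concrete values for `dims`, `in_channels`, and `block_params` are purely illustrative.

```python
from xora.models.autoencoders.causal_video_autoencoder import UNetMidBlock3D

# Illustrative block config; real values come from the Decoder's block spec.
block_params = {"num_layers": 2, "attention_head_dim": 64, "inject_noise": False}

block = UNetMidBlock3D(
    dims=3,                       # illustrative; the Decoder passes its own `dims`
    in_channels=512,              # illustrative
    num_layers=block_params["num_layers"],
    resnet_groups=32,
    norm_layer="group_norm",
    inject_noise=block_params.get("inject_noise", False),
    timestep_conditioning=True,
    attention_head_dim=block_params["attention_head_dim"],  # default -1 disables attention
)
```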
xora/models/autoencoders/causal_video_autoencoder.py CHANGED
@@ -220,7 +220,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
 
     def set_use_tpu_flash_attention(self):
         for block in self.decoder.up_blocks:
-            if isinstance(block, AttentionResBlocks):
+            if isinstance(block, UNetMidBlock3D) and block.attention_blocks:
                 for attention_block in block.attention_blocks:
                     attention_block.set_use_tpu_flash_attention()
 
@@ -497,17 +497,18 @@ class Decoder(nn.Module):
                     resnet_groups=norm_num_groups,
                     norm_layer=norm_layer,
                     inject_noise=block_params.get("inject_noise", False),
+                    timestep_conditioning=timestep_conditioning,
                 )
             elif block_name == "attn_res_x":
-                block = AttentionResBlocks(
+                block = UNetMidBlock3D(
                     dims=dims,
                     in_channels=input_channel,
                     num_layers=block_params["num_layers"],
                     resnet_groups=norm_num_groups,
                     norm_layer=norm_layer,
-                    attention_head_dim=block_params["attention_head_dim"],
                     inject_noise=block_params.get("inject_noise", False),
                     timestep_conditioning=timestep_conditioning,
+                    attention_head_dim=block_params["attention_head_dim"],
                 )
             elif block_name == "res_x_y":
                 output_channel = output_channel // block_params.get("multiplier", 2)
@@ -642,129 +643,6 @@ class Decoder(nn.Module):
         return sample
 
 
-class AttentionResBlocks(nn.Module):
-    """
-    A 3D convolution residual block followed by self attention residual block
-
-    Args:
-        dims (`int` or `Tuple[int, int]`): The number of dimensions to use in convolutions.
-        in_channels (`int`): The number of input channels.
-        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
-        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
-        resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
-        resnet_groups (`int`, *optional*, defaults to 32):
-            The number of groups to use in the group normalization layers of the resnet blocks.
-        norm_layer (`str`, *optional*, defaults to `group_norm`): The normalization layer to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
-        inject_noise (`bool`, *optional*, defaults to `False`): Whether to inject noise or not between convolution layers.
-
-    Returns:
-        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
-        in_channels, height, width)`.
-
-    """
-
-    def __init__(
-        self,
-        dims: Union[int, Tuple[int, int]],
-        in_channels: int,
-        dropout: float = 0.0,
-        num_layers: int = 1,
-        resnet_eps: float = 1e-6,
-        resnet_groups: int = 32,
-        norm_layer: str = "group_norm",
-        attention_head_dim: int = 64,
-        inject_noise: bool = False,
-    ):
-        super().__init__()
-
-        if attention_head_dim > in_channels:
-            raise ValueError(
-                "attention_head_dim must be less than or equal to in_channels"
-            )
-
-        resnet_groups = (
-            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
-        )
-
-        self.res_blocks = []
-        self.attention_blocks = []
-        for i in range(num_layers):
-            self.res_blocks.append(
-                ResnetBlock3D(
-                    dims=dims,
-                    in_channels=in_channels,
-                    out_channels=in_channels,
-                    eps=resnet_eps,
-                    groups=resnet_groups,
-                    dropout=dropout,
-                    norm_layer=norm_layer,
-                    inject_noise=inject_noise,
-                )
-            )
-            self.attention_blocks.append(
-                Attention(
-                    query_dim=in_channels,
-                    heads=in_channels // attention_head_dim,
-                    dim_head=attention_head_dim,
-                    bias=True,
-                    out_bias=True,
-                    qk_norm="rms_norm",
-                    residual_connection=True,
-                )
-            )
-
-        self.res_blocks = nn.ModuleList(self.res_blocks)
-        self.attention_blocks = nn.ModuleList(self.attention_blocks)
-
-    def forward(
-        self, hidden_states: torch.FloatTensor, causal: bool = True
-    ) -> torch.FloatTensor:
-        for resnet, attention in zip(self.res_blocks, self.attention_blocks):
-            hidden_states = resnet(hidden_states, causal=causal)
-
-            # Reshape the hidden states to be (batch_size, frames * height * width, channel)
-            batch_size, channel, frames, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(
-                batch_size, channel, frames * height * width
-            ).transpose(1, 2)
-
-            if attention.use_tpu_flash_attention:
-                # Pad the second dimension to be divisible by block_k_major (block in flash attention)
-                seq_len = hidden_states.shape[1]
-                block_k_major = 512
-                pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
-                if pad_len > 0:
-                    hidden_states = F.pad(
-                        hidden_states, (0, 0, 0, pad_len), "constant", 0
-                    )
-
-                # Create a mask with ones for the original sequence length and zeros for the padded indexes
-                mask = torch.ones(
-                    (hidden_states.shape[0], seq_len),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
-                if pad_len > 0:
-                    mask = F.pad(mask, (0, pad_len), "constant", 0)
-
-            hidden_states = attention(
-                hidden_states,
-                attention_mask=None if not attention.use_tpu_flash_attention else mask,
-            )
-
-            if attention.use_tpu_flash_attention:
-                # Remove the padding
-                if pad_len > 0:
-                    hidden_states = hidden_states[:, :-pad_len, :]
-
-            # Reshape the hidden states back to (batch_size, channel, frames, height, width, channel)
-            hidden_states = hidden_states.transpose(-1, -2).reshape(
-                batch_size, channel, frames, height, width
-            )
-        return hidden_states
-
-
 class UNetMidBlock3D(nn.Module):
     """
     A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
@@ -776,6 +654,14 @@ class UNetMidBlock3D(nn.Module):
         resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
         resnet_groups (`int`, *optional*, defaults to 32):
             The number of groups to use in the group normalization layers of the resnet blocks.
+        norm_layer (`str`, *optional*, defaults to `group_norm`):
+            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
+        inject_noise (`bool`, *optional*, defaults to `False`):
+            Whether to inject noise into the hidden states.
+        timestep_conditioning (`bool`, *optional*, defaults to `False`):
+            Whether to condition the hidden states on the timestep.
+        attention_head_dim (`int`, *optional*, defaults to -1):
+            The dimension of the attention head. If -1, no attention is used.
 
     Returns:
         `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
@@ -794,6 +680,7 @@ class UNetMidBlock3D(nn.Module):
         norm_layer: str = "group_norm",
         inject_noise: bool = False,
         timestep_conditioning: bool = False,
+        attention_head_dim: int = -1,
     ):
         super().__init__()
         resnet_groups = (
@@ -823,6 +710,29 @@ class UNetMidBlock3D(nn.Module):
             ]
         )
 
+        self.attention_blocks = None
+
+        if attention_head_dim > 0:
+            if attention_head_dim > in_channels:
+                raise ValueError(
+                    "attention_head_dim must be less than or equal to in_channels"
+                )
+
+            self.attention_blocks = nn.ModuleList(
+                [
+                    Attention(
+                        query_dim=in_channels,
+                        heads=in_channels // attention_head_dim,
+                        dim_head=attention_head_dim,
+                        bias=True,
+                        out_bias=True,
+                        qk_norm="rms_norm",
+                        residual_connection=True,
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
+
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -845,10 +755,60 @@ class UNetMidBlock3D(nn.Module):
             timestep_embed = timestep_embed.view(
                 batch_size, timestep_embed.shape[-1], 1, 1, 1
            )
-        for resnet in self.res_blocks:
-            hidden_states = resnet(
-                hidden_states, causal=causal, timesteps=timestep_embed
-            )
+
+        if self.attention_blocks:
+            for resnet, attention in zip(self.res_blocks, self.attention_blocks):
+                hidden_states = resnet(
+                    hidden_states, causal=causal, timesteps=timestep_embed
+                )
+
+                # Reshape the hidden states to be (batch_size, frames * height * width, channel)
+                batch_size, channel, frames, height, width = hidden_states.shape
+                hidden_states = hidden_states.view(
+                    batch_size, channel, frames * height * width
+                ).transpose(1, 2)
+
+                if attention.use_tpu_flash_attention:
+                    # Pad the second dimension to be divisible by block_k_major (block in flash attention)
+                    seq_len = hidden_states.shape[1]
+                    block_k_major = 512
+                    pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
+                    if pad_len > 0:
+                        hidden_states = F.pad(
+                            hidden_states, (0, 0, 0, pad_len), "constant", 0
+                        )
+
+                    # Create a mask with ones for the original sequence length and zeros for the padded indexes
+                    mask = torch.ones(
+                        (hidden_states.shape[0], seq_len),
+                        device=hidden_states.device,
+                        dtype=hidden_states.dtype,
+                    )
+                    if pad_len > 0:
+                        mask = F.pad(mask, (0, pad_len), "constant", 0)
+
+                hidden_states = attention(
+                    hidden_states,
+                    attention_mask=(
+                        None if not attention.use_tpu_flash_attention else mask
+                    ),
+                )
+
+                if attention.use_tpu_flash_attention:
+                    # Remove the padding
+                    if pad_len > 0:
+                        hidden_states = hidden_states[:, :-pad_len, :]
+
+                # Reshape the hidden states back to (batch_size, channel, frames, height, width, channel)
+                hidden_states = hidden_states.transpose(-1, -2).reshape(
+                    batch_size, channel, frames, height, width
+                )
+        else:
+            for resnet in self.res_blocks:
+                hidden_states = resnet(
+                    hidden_states, causal=causal, timesteps=timestep_embed
+                )
+
         return hidden_states
 
 
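One detail of the attention path that now lives in `UNetMidBlock3D.forward`, as added above: the block flattens the frame/height/width axes into a single token axis before calling attention and restores the convolutional layout afterwards. A torch-only sketch of that reshape round trip (shapes are illustrative and unrelated to the real decoder features):

```python
import torch

# Illustrative tensor in the decoder's (batch, channels, frames, height, width) layout.
batch, channels, frames, height, width = 1, 128, 2, 4, 4
x = torch.randn(batch, channels, frames, height, width)

# (B, C, F, H, W) -> (B, F*H*W, C): one token per spatio-temporal position.
tokens = x.view(batch, channels, frames * height * width).transpose(1, 2)

# ... self-attention runs on `tokens` here ...

# (B, F*H*W, C) -> (B, C, F, H, W): restore the convolutional layout.
x_restored = tokens.transpose(-1, -2).reshape(batch, channels, frames, height, width)
assert torch.equal(x, x_restored)
```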