Sapir committed on
Commit
86b1a7e
1 Parent(s): e7d5e3c
xora/examples/image_to_video.py ADDED
@@ -0,0 +1,87 @@
1
+ import torch
2
+ from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
3
+ from xora.models.transformers.transformer3d import Transformer3DModel
4
+ from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
5
+ from xora.schedulers.rf import RectifiedFlowScheduler
6
+ from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
7
+ from pathlib import Path
8
+ from transformers import T5EncoderModel
9
+
10
+
11
+ model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
12
+ vae_local_path = Path("/opt/models/checkpoints/vae_training/causal_vvae_32x32x8_420m_cont_32/step_2296000")
13
+ dtype = torch.float32
14
+ vae = CausalVideoAutoencoder.from_pretrained(
15
+ pretrained_model_name_or_path=vae_local_path,
16
+ revision=False,
17
+ torch_dtype=torch.bfloat16,
18
+ load_in_8bit=False,
19
+ ).cuda()
20
+ transformer_config_path = Path("/opt/txt2img/txt2img/config/transformer3d/xora_v1.2-L.json")
21
+ transformer_config = Transformer3DModel.load_config(transformer_config_path)
22
+ transformer = Transformer3DModel.from_config(transformer_config)
23
+ transformer_local_path = Path("/opt/models/logs/v1.2-vae-mf-medHR-mr-cvae-first-frame-cond-4k-seq/ckpt/01822000/model.pt")
24
+ transformer_ckpt_state_dict = torch.load(transformer_local_path)
25
+ transformer.load_state_dict(transformer_ckpt_state_dict, True)
26
+ transformer = transformer.cuda()
27
+ unet = transformer
28
+ scheduler_config_path = Path("/opt/txt2img/txt2img/config/scheduler/RF_SD3_shifted.json")
29
+ scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
30
+ scheduler = RectifiedFlowScheduler.from_config(scheduler_config)
31
+ patchifier = SymmetricPatchifier(patch_size=1)
32
+ # text_encoder = T5EncoderModel.from_pretrained("t5-v1_1-xxl")
33
+
34
+ submodel_dict = {
35
+ "unet": unet,
36
+ "transformer": transformer,
37
+ "patchifier": patchifier,
38
+ "text_encoder": None,
39
+ "scheduler": scheduler,
40
+ "vae": vae,
41
+
42
+ }
43
+
44
+ pipeline = VideoPixArtAlphaPipeline.from_pretrained(model_name_or_path,
45
+ safety_checker=None,
46
+ revision=None,
47
+ torch_dtype=dtype,
48
+ **submodel_dict,
49
+ )
50
+
51
+ num_inference_steps=20
52
+ num_images_per_prompt=2
53
+ guidance_scale=3
54
+ height=512
55
+ width=768
56
+ num_frames=57
57
+ frame_rate=25
58
+ # sample = {
59
+ # "prompt": "A cat", # (B, L, E)
60
+ # 'prompt_attention_mask': None, # (B , L)
61
+ # 'negative_prompt': "Ugly deformed",
62
+ # 'negative_prompt_attention_mask': None # (B , L)
63
+ # }
64
+
65
+ sample = torch.load("/opt/sample.pt")
66
+ for _, item in sample.items():
67
+ if item is not None:
68
+ item = item.cuda()
69
+ media_items = torch.load("/opt/sample_media.pt")
70
+
71
+ images = pipeline(
72
+ num_inference_steps=num_inference_steps,
73
+ num_images_per_prompt=num_images_per_prompt,
74
+ guidance_scale=guidance_scale,
75
+ generator=None,
76
+ output_type="pt",
77
+ callback_on_step_end=None,
78
+ height=height,
79
+ width=width,
80
+ num_frames=num_frames,
81
+ frame_rate=frame_rate,
82
+ **sample,
83
+ is_video=True,
84
+ vae_per_channel_normalize=True,
85
+ ).images
86
+
87
+ print()
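For reference, a minimal post-processing sketch for the script above (not part of the commit): it assumes the pipeline returns a tensor of shape (batch, channels, frames, height, width) with values in [0, 1] when output_type="pt" and is_video=True, and that imageio with its ffmpeg plugin is available.

import imageio

video = images[0]                                   # assumed (C, F, H, W), values in [0, 1]
frames = (video.clamp(0, 1) * 255).to(torch.uint8)  # convert to 8-bit
frames = frames.permute(1, 2, 3, 0).cpu().numpy()   # (F, H, W, C) frame stack for the writer
imageio.mimsave("sample_video.mp4", list(frames), fps=frame_rate)  # frame_rate is defined above (25)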
xora/models/autoencoders/causal_video_autoencoder.py CHANGED
@@ -8,11 +8,13 @@ import torch
8
  import numpy as np
9
  from einops import rearrange
10
  from torch import nn
 
11
 
12
  from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
13
  from xora.models.autoencoders.pixel_norm import PixelNorm
14
  from xora.models.autoencoders.vae import AutoencoderKLWrapper
15
 
 
16
 
17
  class CausalVideoAutoencoder(AutoencoderKLWrapper):
18
  @classmethod
@@ -138,7 +140,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
138
  key = key.replace(k, v)
139
 
140
  if "norm" in key and key not in model_keys:
141
- print(f"Removing key {key} from state_dict as it is not present in the model")
142
  continue
143
 
144
  converted_state_dict[key] = value
 
8
  import numpy as np
9
  from einops import rearrange
10
  from torch import nn
11
+ from diffusers.utils import logging
12
 
13
  from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
14
  from xora.models.autoencoders.pixel_norm import PixelNorm
15
  from xora.models.autoencoders.vae import AutoencoderKLWrapper
16
 
17
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18
 
19
  class CausalVideoAutoencoder(AutoencoderKLWrapper):
20
  @classmethod
 
140
  key = key.replace(k, v)
141
 
142
  if "norm" in key and key not in model_keys:
143
+ logger.info(f"Removing key {key} from state_dict as it is not present in the model")
144
  continue
145
 
146
  converted_state_dict[key] = value
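Since the print was replaced with logger.info, the key-removal messages above are hidden by default; if they are needed while debugging checkpoint conversion, diffusers' logging verbosity can be raised before loading, e.g.:

from diffusers.utils import logging

logging.set_verbosity_info()  # surfaces logger.info messages such as the state_dict key removals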
xora/models/autoencoders/vae_encode.py CHANGED
@@ -1,44 +1,12 @@
1
  import torch
2
- from torch import nn
3
  from diffusers import AutoencoderKL
4
  from einops import rearrange
5
  from torch import Tensor
6
- from torch.nn import functional
7
 
8
 
9
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
10
-
11
- class Downsample3D(nn.Module):
12
- def __init__(self, dims, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
13
- super().__init__()
14
- stride: int = 2
15
- self.padding = padding
16
- self.in_channels = in_channels
17
- self.dims = dims
18
- self.conv = make_conv_nd(
19
- dims=dims,
20
- in_channels=in_channels,
21
- out_channels=out_channels,
22
- kernel_size=kernel_size,
23
- stride=stride,
24
- padding=padding,
25
- )
26
-
27
- def forward(self, x, downsample_in_time=True):
28
- conv = self.conv
29
- if self.padding == 0:
30
- if self.dims == 2:
31
- padding = (0, 1, 0, 1)
32
- else:
33
- padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)
34
-
35
- x = functional.pad(x, padding, mode="constant", value=0)
36
-
37
- if self.dims == (2, 1) and not downsample_in_time:
38
- return conv(x, skip_time_conv=True)
39
-
40
- return conv(x)
41
-
42
 
43
 
44
  def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
@@ -78,7 +46,7 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
78
  if channels != 3:
79
  raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
80
 
81
- if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
82
  media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
83
  if split_size > 1:
84
  if len(media_items) % split_size != 0:
@@ -86,14 +54,16 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
86
  encode_bs = len(media_items) // split_size
87
  # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
88
  latents = []
 
89
  for image_batch in media_items.split(encode_bs):
90
  latents.append(vae.encode(image_batch).latent_dist.sample())
 
91
  latents = torch.cat(latents, dim=0)
92
  else:
93
  latents = vae.encode(media_items).latent_dist.sample()
94
 
95
  latents = normalize_latents(latents, vae, vae_per_channel_normalize)
96
- if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
97
  latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
98
  return latents
99
 
@@ -104,7 +74,7 @@ def vae_decode(
104
  is_video_shaped = latents.dim() == 5
105
  batch_size = latents.shape[0]
106
 
107
- if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
108
  latents = rearrange(latents, "b c n h w -> (b n) c h w")
109
  if split_size > 1:
110
  if len(latents) % split_size != 0:
@@ -118,13 +88,13 @@ def vae_decode(
118
  else:
119
  images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
120
 
121
- if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
122
  images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
123
  return images
124
 
125
 
126
  def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False) -> Tensor:
127
- if isinstance(vae, (CausalVideoAutoencoder)):
128
  *_, fl, hl, wl = latents.shape
129
  temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
130
  latents = latents.to(vae.dtype)
@@ -148,7 +118,7 @@ def get_vae_size_scale_factor(vae: AutoencoderKL) -> float:
148
  else:
149
  down_blocks = len([block for block in vae.encoder.down_blocks if isinstance(block.downsample, Downsample3D)])
150
  spatial = vae.config.patch_size * 2**down_blocks
151
- temporal = vae.config.patch_size_t * 2 ** down_blocks if isinstance(vae) else 1
152
 
153
  return (temporal, spatial, spatial)
154
 
@@ -168,4 +138,4 @@ def un_normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_no
168
  + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
169
  if vae_per_channel_normalize
170
  else latents / vae.config.scaling_factor
171
- )
 
1
  import torch
 
2
  from diffusers import AutoencoderKL
3
  from einops import rearrange
4
  from torch import Tensor
 
5
 
6
 
7
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
8
+ from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
9
+ import xora.utils.dist_util
10
 
11
 
12
  def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
 
46
  if channels != 3:
47
  raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
48
 
49
+ if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
50
  media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
51
  if split_size > 1:
52
  if len(media_items) % split_size != 0:
 
54
  encode_bs = len(media_items) // split_size
55
  # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
56
  latents = []
57
+ dist_util.execute_graph()
58
  for image_batch in media_items.split(encode_bs):
59
  latents.append(vae.encode(image_batch).latent_dist.sample())
60
+ dist_util.execute_graph()
61
  latents = torch.cat(latents, dim=0)
62
  else:
63
  latents = vae.encode(media_items).latent_dist.sample()
64
 
65
  latents = normalize_latents(latents, vae, vae_per_channel_normalize)
66
+ if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
67
  latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
68
  return latents
69
 
 
74
  is_video_shaped = latents.dim() == 5
75
  batch_size = latents.shape[0]
76
 
77
+ if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
78
  latents = rearrange(latents, "b c n h w -> (b n) c h w")
79
  if split_size > 1:
80
  if len(latents) % split_size != 0:
 
88
  else:
89
  images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
90
 
91
+ if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
92
  images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
93
  return images
94
 
95
 
96
  def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False) -> Tensor:
97
+ if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
98
  *_, fl, hl, wl = latents.shape
99
  temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
100
  latents = latents.to(vae.dtype)
 
118
  else:
119
  down_blocks = len([block for block in vae.encoder.down_blocks if isinstance(block.downsample, Downsample3D)])
120
  spatial = vae.config.patch_size * 2**down_blocks
121
+ temporal = vae.config.patch_size_t * 2 ** down_blocks if isinstance(vae, VideoAutoencoder) else 1
122
 
123
  return (temporal, spatial, spatial)
124
 
 
138
  + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
139
  if vae_per_channel_normalize
140
  else latents / vae.config.scaling_factor
141
+ )
xora/models/autoencoders/video_autoencoder.py ADDED
@@ -0,0 +1,912 @@
1
+ import json
2
+ import os
3
+ from functools import partial
4
+ from types import SimpleNamespace
5
+ from typing import Any, Mapping, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from einops import rearrange
9
+ from torch import nn
10
+ from torch.nn import functional
11
+
12
+ from diffusers.utils import logging
13
+
14
+ from txt2img.models.layers.nn import Identity
15
+ from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
16
+ from xora.models.autoencoders.pixel_norm import PixelNorm
17
+ from xora.models.autoencoders.vae import AutoencoderKLWrapper
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class VideoAutoencoder(AutoencoderKLWrapper):
23
+ @classmethod
24
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *args, **kwargs):
25
+ config_local_path = pretrained_model_name_or_path / "config.json"
26
+ config = cls.load_config(config_local_path, **kwargs)
27
+ video_vae = cls.from_config(config)
28
+ video_vae.to(kwargs["torch_dtype"])
29
+
30
+ model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
31
+ ckpt_state_dict = torch.load(model_local_path)
32
+ video_vae.load_state_dict(ckpt_state_dict)
33
+
34
+ statistics_local_path = pretrained_model_name_or_path / "per_channel_statistics.json"
35
+ if statistics_local_path.exists():
36
+ with open(statistics_local_path, "r") as file:
37
+ data = json.load(file)
38
+ transposed_data = list(zip(*data["data"]))
39
+ data_dict = {col: torch.tensor(vals) for col, vals in zip(data["columns"], transposed_data)}
40
+ video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
41
+ video_vae.register_buffer(
42
+ "mean_of_means", data_dict.get("mean-of-means", torch.zeros_like(data_dict["std-of-means"]))
43
+ )
44
+
45
+ return video_vae
46
+
47
+ @staticmethod
48
+ def from_config(config):
49
+ assert config["_class_name"] == "VideoAutoencoder", "config must have _class_name=VideoAutoencoder"
50
+ if isinstance(config["dims"], list):
51
+ config["dims"] = tuple(config["dims"])
52
+
53
+ assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
54
+
55
+ double_z = config.get("double_z", True)
56
+ latent_log_var = config.get("latent_log_var", "per_channel" if double_z else "none")
57
+ use_quant_conv = config.get("use_quant_conv", True)
58
+
59
+ if use_quant_conv and latent_log_var == "uniform":
60
+ raise ValueError("uniform latent_log_var requires use_quant_conv=False")
61
+
62
+ encoder = Encoder(
63
+ dims=config["dims"],
64
+ in_channels=config.get("in_channels", 3),
65
+ out_channels=config["latent_channels"],
66
+ block_out_channels=config["block_out_channels"],
67
+ patch_size=config.get("patch_size", 1),
68
+ latent_log_var=latent_log_var,
69
+ norm_layer=config.get("norm_layer", "group_norm"),
70
+ patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
71
+ add_channel_padding=config.get("add_channel_padding", False),
72
+ )
73
+
74
+ decoder = Decoder(
75
+ dims=config["dims"],
76
+ in_channels=config["latent_channels"],
77
+ out_channels=config.get("out_channels", 3),
78
+ block_out_channels=config["block_out_channels"],
79
+ patch_size=config.get("patch_size", 1),
80
+ norm_layer=config.get("norm_layer", "group_norm"),
81
+ patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
82
+ add_channel_padding=config.get("add_channel_padding", False),
83
+ )
84
+
85
+ dims = config["dims"]
86
+ return VideoAutoencoder(
87
+ encoder=encoder,
88
+ decoder=decoder,
89
+ latent_channels=config["latent_channels"],
90
+ dims=dims,
91
+ use_quant_conv=use_quant_conv,
92
+ )
93
+
94
+ @property
95
+ def config(self):
96
+ return SimpleNamespace(
97
+ _class_name="VideoAutoencoder",
98
+ dims=self.dims,
99
+ in_channels=self.encoder.conv_in.in_channels // (self.encoder.patch_size_t * self.encoder.patch_size**2),
100
+ out_channels=self.decoder.conv_out.out_channels // (self.decoder.patch_size_t * self.decoder.patch_size**2),
101
+ latent_channels=self.decoder.conv_in.in_channels,
102
+ block_out_channels=[
103
+ self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
104
+ for i in range(len(self.encoder.down_blocks))
105
+ ],
106
+ scaling_factor=1.0,
107
+ norm_layer=self.encoder.norm_layer,
108
+ patch_size=self.encoder.patch_size,
109
+ latent_log_var=self.encoder.latent_log_var,
110
+ use_quant_conv=self.use_quant_conv,
111
+ patch_size_t=self.encoder.patch_size_t,
112
+ add_channel_padding=self.encoder.add_channel_padding,
113
+ )
114
+
115
+ @property
116
+ def is_video_supported(self):
117
+ """
118
+ Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
119
+ """
120
+ return self.dims != 2
121
+
122
+ @property
123
+ def downscale_factor(self):
124
+ return self.encoder.downsample_factor
125
+
126
+ def to_json_string(self) -> str:
127
+ import json
128
+
129
+ return json.dumps(self.config.__dict__)
130
+
131
+ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
132
+ model_keys = set(name for name, _ in self.named_parameters())
133
+
134
+ key_mapping = {
135
+ ".resnets.": ".res_blocks.",
136
+ "downsamplers.0": "downsample",
137
+ "upsamplers.0": "upsample",
138
+ }
139
+
140
+ converted_state_dict = {}
141
+ for key, value in state_dict.items():
142
+ for k, v in key_mapping.items():
143
+ key = key.replace(k, v)
144
+
145
+ if "norm" in key and key not in model_keys:
146
+ logger.info(f"Removing key {key} from state_dict as it is not present in the model")
147
+ continue
148
+
149
+ converted_state_dict[key] = value
150
+
151
+ super().load_state_dict(converted_state_dict, strict=strict)
152
+
153
+ def last_layer(self):
154
+ if hasattr(self.decoder, "conv_out"):
155
+ if isinstance(self.decoder.conv_out, nn.Sequential):
156
+ last_layer = self.decoder.conv_out[-1]
157
+ else:
158
+ last_layer = self.decoder.conv_out
159
+ else:
160
+ last_layer = self.decoder.layers[-1]
161
+ return last_layer
162
+
163
+
164
+ class Encoder(nn.Module):
165
+ r"""
166
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
167
+
168
+ Args:
169
+ in_channels (`int`, *optional*, defaults to 3):
170
+ The number of input channels.
171
+ out_channels (`int`, *optional*, defaults to 3):
172
+ The number of output channels.
173
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
174
+ The number of output channels for each block.
175
+ layers_per_block (`int`, *optional*, defaults to 2):
176
+ The number of layers per block.
177
+ norm_num_groups (`int`, *optional*, defaults to 32):
178
+ The number of groups for normalization.
179
+ patch_size (`int`, *optional*, defaults to 1):
180
+ The patch size to use. Should be a power of 2.
181
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
182
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
183
+ latent_log_var (`str`, *optional*, defaults to `per_channel`):
184
+ The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
185
+ """
186
+
187
+ def __init__(
188
+ self,
189
+ dims: Union[int, Tuple[int, int]] = 3,
190
+ in_channels: int = 3,
191
+ out_channels: int = 3,
192
+ block_out_channels: Tuple[int, ...] = (64,),
193
+ layers_per_block: int = 2,
194
+ norm_num_groups: int = 32,
195
+ patch_size: Union[int, Tuple[int]] = 1,
196
+ norm_layer: str = "group_norm", # group_norm, pixel_norm
197
+ latent_log_var: str = "per_channel",
198
+ patch_size_t: Optional[int] = None,
199
+ add_channel_padding: Optional[bool] = False,
200
+ ):
201
+ super().__init__()
202
+ self.patch_size = patch_size
203
+ self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
204
+ self.add_channel_padding = add_channel_padding
205
+ self.layers_per_block = layers_per_block
206
+ self.norm_layer = norm_layer
207
+ self.latent_channels = out_channels
208
+ self.latent_log_var = latent_log_var
209
+ if add_channel_padding:
210
+ in_channels = in_channels * self.patch_size**3
211
+ else:
212
+ in_channels = in_channels * self.patch_size_t * self.patch_size**2
213
+ self.in_channels = in_channels
214
+ output_channel = block_out_channels[0]
215
+
216
+ self.conv_in = make_conv_nd(
217
+ dims=dims,
218
+ in_channels=in_channels,
219
+ out_channels=output_channel,
220
+ kernel_size=3,
221
+ stride=1,
222
+ padding=1,
223
+ )
224
+
225
+ self.down_blocks = nn.ModuleList([])
226
+
227
+ for i in range(len(block_out_channels)):
228
+ input_channel = output_channel
229
+ output_channel = block_out_channels[i]
230
+ is_final_block = i == len(block_out_channels) - 1
231
+
232
+ down_block = DownEncoderBlock3D(
233
+ dims=dims,
234
+ in_channels=input_channel,
235
+ out_channels=output_channel,
236
+ num_layers=self.layers_per_block,
237
+ add_downsample=not is_final_block and 2**i >= patch_size,
238
+ resnet_eps=1e-6,
239
+ downsample_padding=0,
240
+ resnet_groups=norm_num_groups,
241
+ norm_layer=norm_layer,
242
+ )
243
+ self.down_blocks.append(down_block)
244
+
245
+ self.mid_block = UNetMidBlock3D(
246
+ dims=dims,
247
+ in_channels=block_out_channels[-1],
248
+ num_layers=self.layers_per_block,
249
+ resnet_eps=1e-6,
250
+ resnet_groups=norm_num_groups,
251
+ norm_layer=norm_layer,
252
+ )
253
+
254
+ # out
255
+ if norm_layer == "group_norm":
256
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
257
+ elif norm_layer == "pixel_norm":
258
+ self.conv_norm_out = PixelNorm()
259
+ self.conv_act = nn.SiLU()
260
+
261
+ conv_out_channels = out_channels
262
+ if latent_log_var == "per_channel":
263
+ conv_out_channels *= 2
264
+ elif latent_log_var == "uniform":
265
+ conv_out_channels += 1
266
+ elif latent_log_var != "none":
267
+ raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
268
+ self.conv_out = make_conv_nd(dims, block_out_channels[-1], conv_out_channels, 3, padding=1)
269
+
270
+ self.gradient_checkpointing = False
271
+
272
+ @property
273
+ def downscale_factor(self):
274
+ return (
275
+ 2 ** len([block for block in self.down_blocks if isinstance(block.downsample, Downsample3D)])
276
+ * self.patch_size
277
+ )
278
+
279
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
280
+ r"""The forward method of the `Encoder` class."""
281
+
282
+ downsample_in_time = sample.shape[2] != 1
283
+
284
+ # patchify
285
+ patch_size_t = self.patch_size_t if downsample_in_time else 1
286
+ sample = patchify(
287
+ sample,
288
+ patch_size_hw=self.patch_size,
289
+ patch_size_t=patch_size_t,
290
+ add_channel_padding=self.add_channel_padding,
291
+ )
292
+
293
+ sample = self.conv_in(sample)
294
+
295
+ checkpoint_fn = (
296
+ partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
297
+ if self.gradient_checkpointing and self.training
298
+ else lambda x: x
299
+ )
300
+
301
+ for down_block in self.down_blocks:
302
+ sample = checkpoint_fn(down_block)(sample, downsample_in_time=downsample_in_time)
303
+
304
+ sample = checkpoint_fn(self.mid_block)(sample)
305
+
306
+ # post-process
307
+ sample = self.conv_norm_out(sample)
308
+ sample = self.conv_act(sample)
309
+ sample = self.conv_out(sample)
310
+
311
+ if self.latent_log_var == "uniform":
312
+ last_channel = sample[:, -1:, ...]
313
+ num_dims = sample.dim()
314
+
315
+ if num_dims == 4:
316
+ # For shape (B, C, H, W)
317
+ repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1)
318
+ sample = torch.cat([sample, repeated_last_channel], dim=1)
319
+ elif num_dims == 5:
320
+ # For shape (B, C, F, H, W)
321
+ repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1, 1)
322
+ sample = torch.cat([sample, repeated_last_channel], dim=1)
323
+ else:
324
+ raise ValueError(f"Invalid input shape: {sample.shape}")
325
+
326
+ return sample
327
+
328
+
329
+ class Decoder(nn.Module):
330
+ r"""
331
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
332
+
333
+ Args:
334
+ in_channels (`int`, *optional*, defaults to 3):
335
+ The number of input channels.
336
+ out_channels (`int`, *optional*, defaults to 3):
337
+ The number of output channels.
338
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
339
+ The number of output channels for each block.
340
+ layers_per_block (`int`, *optional*, defaults to 2):
341
+ The number of layers per block.
342
+ norm_num_groups (`int`, *optional*, defaults to 32):
343
+ The number of groups for normalization.
344
+ patch_size (`int`, *optional*, defaults to 1):
345
+ The patch size to use. Should be a power of 2.
346
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
347
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
348
+ """
349
+
350
+ def __init__(
351
+ self,
352
+ dims,
353
+ in_channels: int = 3,
354
+ out_channels: int = 3,
355
+ block_out_channels: Tuple[int, ...] = (64,),
356
+ layers_per_block: int = 2,
357
+ norm_num_groups: int = 32,
358
+ patch_size: int = 1,
359
+ norm_layer: str = "group_norm",
360
+ patch_size_t: Optional[int] = None,
361
+ add_channel_padding: Optional[bool] = False,
362
+ ):
363
+ super().__init__()
364
+ self.patch_size = patch_size
365
+ self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
366
+ self.add_channel_padding = add_channel_padding
367
+ self.layers_per_block = layers_per_block
368
+ if add_channel_padding:
369
+ out_channels = out_channels * self.patch_size**3
370
+ else:
371
+ out_channels = out_channels * self.patch_size_t * self.patch_size**2
372
+ self.out_channels = out_channels
373
+
374
+ self.conv_in = make_conv_nd(
375
+ dims,
376
+ in_channels,
377
+ block_out_channels[-1],
378
+ kernel_size=3,
379
+ stride=1,
380
+ padding=1,
381
+ )
382
+
383
+ self.mid_block = None
384
+ self.up_blocks = nn.ModuleList([])
385
+
386
+ self.mid_block = UNetMidBlock3D(
387
+ dims=dims,
388
+ in_channels=block_out_channels[-1],
389
+ num_layers=self.layers_per_block,
390
+ resnet_eps=1e-6,
391
+ resnet_groups=norm_num_groups,
392
+ norm_layer=norm_layer,
393
+ )
394
+
395
+ reversed_block_out_channels = list(reversed(block_out_channels))
396
+ output_channel = reversed_block_out_channels[0]
397
+ for i in range(len(reversed_block_out_channels)):
398
+ prev_output_channel = output_channel
399
+ output_channel = reversed_block_out_channels[i]
400
+
401
+ is_final_block = i == len(block_out_channels) - 1
402
+
403
+ up_block = UpDecoderBlock3D(
404
+ dims=dims,
405
+ num_layers=self.layers_per_block + 1,
406
+ in_channels=prev_output_channel,
407
+ out_channels=output_channel,
408
+ add_upsample=not is_final_block and 2 ** (len(block_out_channels) - i - 1) > patch_size,
409
+ resnet_eps=1e-6,
410
+ resnet_groups=norm_num_groups,
411
+ norm_layer=norm_layer,
412
+ )
413
+ self.up_blocks.append(up_block)
414
+
415
+ if norm_layer == "group_norm":
416
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
417
+ elif norm_layer == "pixel_norm":
418
+ self.conv_norm_out = PixelNorm()
419
+
420
+ self.conv_act = nn.SiLU()
421
+ self.conv_out = make_conv_nd(dims, block_out_channels[0], out_channels, 3, padding=1)
422
+
423
+ self.gradient_checkpointing = False
424
+
425
+ def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
426
+ r"""The forward method of the `Decoder` class."""
427
+ assert target_shape is not None, "target_shape must be provided"
428
+ upsample_in_time = sample.shape[2] < target_shape[2]
429
+
430
+ sample = self.conv_in(sample)
431
+
432
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
433
+
434
+ checkpoint_fn = (
435
+ partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
436
+ if self.gradient_checkpointing and self.training
437
+ else lambda x: x
438
+ )
439
+
440
+ sample = checkpoint_fn(self.mid_block)(sample)
441
+ sample = sample.to(upscale_dtype)
442
+
443
+ for up_block in self.up_blocks:
444
+ sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time)
445
+
446
+ # post-process
447
+ sample = self.conv_norm_out(sample)
448
+ sample = self.conv_act(sample)
449
+ sample = self.conv_out(sample)
450
+
451
+ # un-patchify
452
+ patch_size_t = self.patch_size_t if upsample_in_time else 1
453
+ sample = unpatchify(
454
+ sample,
455
+ patch_size_hw=self.patch_size,
456
+ patch_size_t=patch_size_t,
457
+ add_channel_padding=self.add_channel_padding,
458
+ )
459
+
460
+ return sample
461
+
462
+
463
+ class DownEncoderBlock3D(nn.Module):
464
+ def __init__(
465
+ self,
466
+ dims: Union[int, Tuple[int, int]],
467
+ in_channels: int,
468
+ out_channels: int,
469
+ dropout: float = 0.0,
470
+ num_layers: int = 1,
471
+ resnet_eps: float = 1e-6,
472
+ resnet_groups: int = 32,
473
+ add_downsample: bool = True,
474
+ downsample_padding: int = 1,
475
+ norm_layer: str = "group_norm",
476
+ ):
477
+ super().__init__()
478
+ res_blocks = []
479
+
480
+ for i in range(num_layers):
481
+ in_channels = in_channels if i == 0 else out_channels
482
+ res_blocks.append(
483
+ ResnetBlock3D(
484
+ dims=dims,
485
+ in_channels=in_channels,
486
+ out_channels=out_channels,
487
+ eps=resnet_eps,
488
+ groups=resnet_groups,
489
+ dropout=dropout,
490
+ norm_layer=norm_layer,
491
+ )
492
+ )
493
+
494
+ self.res_blocks = nn.ModuleList(res_blocks)
495
+
496
+ if add_downsample:
497
+ self.downsample = Downsample3D(dims, out_channels, out_channels=out_channels, padding=downsample_padding)
498
+ else:
499
+ self.downsample = Identity()
500
+
501
+ def forward(self, hidden_states: torch.FloatTensor, downsample_in_time) -> torch.FloatTensor:
502
+ for resnet in self.res_blocks:
503
+ hidden_states = resnet(hidden_states)
504
+
505
+ hidden_states = self.downsample(hidden_states, downsample_in_time=downsample_in_time)
506
+
507
+ return hidden_states
508
+
509
+
510
+ class UNetMidBlock3D(nn.Module):
511
+ """
512
+ A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
513
+
514
+ Args:
515
+ in_channels (`int`): The number of input channels.
516
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
517
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
518
+ resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
519
+ resnet_groups (`int`, *optional*, defaults to 32):
520
+ The number of groups to use in the group normalization layers of the resnet blocks.
521
+
522
+ Returns:
523
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
524
+ in_channels, height, width)`.
525
+
526
+ """
527
+
528
+ def __init__(
529
+ self,
530
+ dims: Union[int, Tuple[int, int]],
531
+ in_channels: int,
532
+ dropout: float = 0.0,
533
+ num_layers: int = 1,
534
+ resnet_eps: float = 1e-6,
535
+ resnet_groups: int = 32,
536
+ norm_layer: str = "group_norm",
537
+ ):
538
+ super().__init__()
539
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
540
+
541
+ self.res_blocks = nn.ModuleList(
542
+ [
543
+ ResnetBlock3D(
544
+ dims=dims,
545
+ in_channels=in_channels,
546
+ out_channels=in_channels,
547
+ eps=resnet_eps,
548
+ groups=resnet_groups,
549
+ dropout=dropout,
550
+ norm_layer=norm_layer,
551
+ )
552
+ for _ in range(num_layers)
553
+ ]
554
+ )
555
+
556
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
557
+ for resnet in self.res_blocks:
558
+ hidden_states = resnet(hidden_states)
559
+
560
+ return hidden_states
561
+
562
+
563
+ class UpDecoderBlock3D(nn.Module):
564
+ def __init__(
565
+ self,
566
+ dims: Union[int, Tuple[int, int]],
567
+ in_channels: int,
568
+ out_channels: int,
569
+ resolution_idx: Optional[int] = None,
570
+ dropout: float = 0.0,
571
+ num_layers: int = 1,
572
+ resnet_eps: float = 1e-6,
573
+ resnet_groups: int = 32,
574
+ add_upsample: bool = True,
575
+ norm_layer: str = "group_norm",
576
+ ):
577
+ super().__init__()
578
+ res_blocks = []
579
+
580
+ for i in range(num_layers):
581
+ input_channels = in_channels if i == 0 else out_channels
582
+
583
+ res_blocks.append(
584
+ ResnetBlock3D(
585
+ dims=dims,
586
+ in_channels=input_channels,
587
+ out_channels=out_channels,
588
+ eps=resnet_eps,
589
+ groups=resnet_groups,
590
+ dropout=dropout,
591
+ norm_layer=norm_layer,
592
+ )
593
+ )
594
+
595
+ self.res_blocks = nn.ModuleList(res_blocks)
596
+
597
+ if add_upsample:
598
+ self.upsample = Upsample3D(dims=dims, channels=out_channels, out_channels=out_channels)
599
+ else:
600
+ self.upsample = Identity()
601
+
602
+ self.resolution_idx = resolution_idx
603
+
604
+ def forward(self, hidden_states: torch.FloatTensor, upsample_in_time=True) -> torch.FloatTensor:
605
+ for resnet in self.res_blocks:
606
+ hidden_states = resnet(hidden_states)
607
+
608
+ hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time)
609
+
610
+ return hidden_states
611
+
612
+
613
+ class ResnetBlock3D(nn.Module):
614
+ r"""
615
+ A Resnet block.
616
+
617
+ Parameters:
618
+ in_channels (`int`): The number of channels in the input.
619
+ out_channels (`int`, *optional*, default to be `None`):
620
+ The number of output channels for the first conv layer. If None, same as `in_channels`.
621
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
622
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
623
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
624
+ """
625
+
626
+ def __init__(
627
+ self,
628
+ dims: Union[int, Tuple[int, int]],
629
+ in_channels: int,
630
+ out_channels: Optional[int] = None,
631
+ conv_shortcut: bool = False,
632
+ dropout: float = 0.0,
633
+ groups: int = 32,
634
+ eps: float = 1e-6,
635
+ norm_layer: str = "group_norm",
636
+ ):
637
+ super().__init__()
638
+ self.in_channels = in_channels
639
+ out_channels = in_channels if out_channels is None else out_channels
640
+ self.out_channels = out_channels
641
+ self.use_conv_shortcut = conv_shortcut
642
+
643
+ if norm_layer == "group_norm":
644
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
645
+ elif norm_layer == "pixel_norm":
646
+ self.norm1 = PixelNorm()
647
+
648
+ self.non_linearity = nn.SiLU()
649
+
650
+ self.conv1 = make_conv_nd(dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1)
651
+
652
+ if norm_layer == "group_norm":
653
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
654
+ elif norm_layer == "pixel_norm":
655
+ self.norm2 = PixelNorm()
656
+
657
+ self.dropout = torch.nn.Dropout(dropout)
658
+
659
+ self.conv2 = make_conv_nd(dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1)
660
+
661
+ self.conv_shortcut = (
662
+ make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels)
663
+ if in_channels != out_channels
664
+ else nn.Identity()
665
+ )
666
+
667
+ def forward(
668
+ self,
669
+ input_tensor: torch.FloatTensor,
670
+ ) -> torch.FloatTensor:
671
+ hidden_states = input_tensor
672
+
673
+ hidden_states = self.norm1(hidden_states)
674
+
675
+ hidden_states = self.non_linearity(hidden_states)
676
+
677
+ hidden_states = self.conv1(hidden_states)
678
+
679
+ hidden_states = self.norm2(hidden_states)
680
+
681
+ hidden_states = self.non_linearity(hidden_states)
682
+
683
+ hidden_states = self.dropout(hidden_states)
684
+
685
+ hidden_states = self.conv2(hidden_states)
686
+
687
+ input_tensor = self.conv_shortcut(input_tensor)
688
+
689
+ output_tensor = input_tensor + hidden_states
690
+
691
+ return output_tensor
692
+
693
+
694
+ class Downsample3D(nn.Module):
695
+ def __init__(self, dims, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
696
+ super().__init__()
697
+ stride: int = 2
698
+ self.padding = padding
699
+ self.in_channels = in_channels
700
+ self.dims = dims
701
+ self.conv = make_conv_nd(
702
+ dims=dims,
703
+ in_channels=in_channels,
704
+ out_channels=out_channels,
705
+ kernel_size=kernel_size,
706
+ stride=stride,
707
+ padding=padding,
708
+ )
709
+
710
+ def forward(self, x, downsample_in_time=True):
711
+ conv = self.conv
712
+ if self.padding == 0:
713
+ if self.dims == 2:
714
+ padding = (0, 1, 0, 1)
715
+ else:
716
+ padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)
717
+
718
+ x = functional.pad(x, padding, mode="constant", value=0)
719
+
720
+ if self.dims == (2, 1) and not downsample_in_time:
721
+ return conv(x, skip_time_conv=True)
722
+
723
+ return conv(x)
724
+
725
+
726
+ class Upsample3D(nn.Module):
727
+ """
728
+ An upsampling layer for 3D tensors of shape (B, C, D, H, W).
729
+
730
+ :param channels: channels in the inputs and outputs.
731
+ """
732
+
733
+ def __init__(self, dims, channels, out_channels=None):
734
+ super().__init__()
735
+ self.dims = dims
736
+ self.channels = channels
737
+ self.out_channels = out_channels or channels
738
+ self.conv = make_conv_nd(dims, channels, out_channels, kernel_size=3, padding=1, bias=True)
739
+
740
+ def forward(self, x, upsample_in_time):
741
+ if self.dims == 2:
742
+ x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest")
743
+ else:
744
+ time_scale_factor = 2 if upsample_in_time else 1
745
+ # print("before:", x.shape)
746
+ b, c, d, h, w = x.shape
747
+ x = rearrange(x, "b c d h w -> (b d) c h w")
748
+ # height and width interpolate
749
+ x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest")
750
+ _, _, h, w = x.shape
751
+
752
+ if not upsample_in_time and self.dims == (2, 1):
753
+ x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w)
754
+ return self.conv(x, skip_time_conv=True)
755
+
756
+ # Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension
757
+ x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b)
758
+
759
+ # (b h w) c 1 d
760
+ new_d = x.shape[-1] * time_scale_factor
761
+ x = functional.interpolate(x, (1, new_d), mode="nearest")
762
+ # (b h w) c 1 new_d
763
+ x = rearrange(x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d)
764
+ # b c d h w
765
+
766
+ # x = functional.interpolate(
767
+ # x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
768
+ # )
769
+ # print("after:", x.shape)
770
+
771
+ return self.conv(x)
772
+
773
+
774
+ def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
775
+ if patch_size_hw == 1 and patch_size_t == 1:
776
+ return x
777
+ if x.dim() == 4:
778
+ x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw)
779
+ elif x.dim() == 5:
780
+ x = rearrange(x, "b c (f p) (h q) (w r) -> b (c p r q) f h w", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)
781
+ else:
782
+ raise ValueError(f"Invalid input shape: {x.shape}")
783
+
784
+ if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding):
785
+ channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
786
+ padding_zeros = torch.zeros(
787
+ x.shape[0],
788
+ channels_to_pad,
789
+ x.shape[2],
790
+ x.shape[3],
791
+ x.shape[4],
792
+ device=x.device,
793
+ dtype=x.dtype,
794
+ )
795
+ x = torch.cat([padding_zeros, x], dim=1)
796
+
797
+ return x
798
+
799
+
800
+ def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
801
+ if patch_size_hw == 1 and patch_size_t == 1:
802
+ return x
803
+
804
+ if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding):
805
+ channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
806
+ x = x[:, :channels_to_keep, :, :, :]
807
+
808
+ if x.dim() == 4:
809
+ x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw)
810
+ elif x.dim() == 5:
811
+ x = rearrange(x, "b (c p r q) f h w -> b c (f p) (h q) (w r)", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)
812
+
813
+ return x
814
+
815
+
816
+ def create_video_autoencoder_config(
817
+ latent_channels: int = 4,
818
+ ):
819
+ config = {
820
+ "_class_name": "VideoAutoencoder",
821
+ "dims": (2, 1), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
822
+ "in_channels": 3, # Number of input color channels (e.g., RGB)
823
+ "out_channels": 3, # Number of output color channels
824
+ "latent_channels": latent_channels, # Number of channels in the latent space representation
825
+ "block_out_channels": [128, 256, 512, 512], # Number of output channels of each encoder / decoder inner block
826
+ "patch_size": 1,
827
+ }
828
+
829
+ return config
830
+
831
+
832
+ def create_video_autoencoder_pathify4x4x4_config(
833
+ latent_channels: int = 4,
834
+ ):
835
+ config = {
836
+ "_class_name": "VideoAutoencoder",
837
+ "dims": (2, 1), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
838
+ "in_channels": 3, # Number of input color channels (e.g., RGB)
839
+ "out_channels": 3, # Number of output color channels
840
+ "latent_channels": latent_channels, # Number of channels in the latent space representation
841
+ "block_out_channels": [512] * 4, # Number of output channels of each encoder / decoder inner block
842
+ "patch_size": 4,
843
+ "latent_log_var": "uniform",
844
+ }
845
+
846
+ return config
847
+
848
+
849
+ def create_video_autoencoder_pathify4x4_config(
850
+ latent_channels: int = 4,
851
+ ):
852
+ config = {
853
+ "_class_name": "VideoAutoencoder",
854
+ "dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
855
+ "in_channels": 3, # Number of input color channels (e.g., RGB)
856
+ "out_channels": 3, # Number of output color channels
857
+ "latent_channels": latent_channels, # Number of channels in the latent space representation
858
+ "block_out_channels": [512] * 4, # Number of output channels of each encoder / decoder inner block
859
+ "patch_size": 4,
860
+ "norm_layer": "pixel_norm",
861
+ }
862
+
863
+ return config
864
+
865
+
866
+ def test_vae_patchify_unpatchify():
867
+ import torch
868
+
869
+ x = torch.randn(2, 3, 8, 64, 64)
870
+ x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
871
+ x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
872
+ assert torch.allclose(x, x_unpatched)
873
+
874
+
875
+ def demo_video_autoencoder_forward_backward():
876
+ # Configuration for the VideoAutoencoder
877
+ config = create_video_autoencoder_pathify4x4x4_config()
878
+
879
+ # Instantiate the VideoAutoencoder with the specified configuration
880
+ video_autoencoder = VideoAutoencoder.from_config(config)
881
+
882
+ print(video_autoencoder)
883
+
884
+ # Print the total number of parameters in the video autoencoder
885
+ total_params = sum(p.numel() for p in video_autoencoder.parameters())
886
+ print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")
887
+
888
+ # Create a mock input tensor simulating a batch of videos
889
+ # Shape: (batch_size, channels, depth, height, width)
890
+ # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame
891
+ input_videos = torch.randn(2, 3, 8, 64, 64)
892
+
893
+ # Forward pass: encode and decode the input videos
894
+ latent = video_autoencoder.encode(input_videos).latent_dist.mode()
895
+ print(f"input shape={input_videos.shape}")
896
+ print(f"latent shape={latent.shape}")
897
+ reconstructed_videos = video_autoencoder.decode(latent, target_shape=input_videos.shape).sample
898
+
899
+ print(f"reconstructed shape={reconstructed_videos.shape}")
900
+
901
+ # Calculate the loss (e.g., mean squared error)
902
+ loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
903
+
904
+ # Perform backward pass
905
+ loss.backward()
906
+
907
+ print(f"Demo completed with loss: {loss.item()}")
908
+
909
+
910
+ # Ensure to call the demo function to execute the forward and backward pass
911
+ if __name__ == "__main__":
912
+ demo_video_autoencoder_forward_backward()
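A small shape-check sketch for the patchify/unpatchify helpers above, mirroring test_vae_patchify_unpatchify (it assumes the module and its txt2img dependency are importable from the repository root):

import torch
from xora.models.autoencoders.video_autoencoder import patchify, unpatchify

x = torch.randn(2, 3, 8, 64, 64)                    # (B, C, F, H, W)
x_p = patchify(x, patch_size_hw=4, patch_size_t=4)  # folds each 4x4x4 patch into the channel dim
print(x_p.shape)                                    # torch.Size([2, 192, 2, 16, 16]); 192 = 3*4*4*4
x_r = unpatchify(x_p, patch_size_hw=4, patch_size_t=4)
assert torch.allclose(x, x_r)                       # the rearrange is exactly invertible here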
xora/models/transformers/embeddings.py ADDED
@@ -0,0 +1,125 @@
1
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch
6
+ from einops import rearrange
7
+ from torch import nn
8
+
9
+
10
+ def get_timestep_embedding(
11
+ timesteps: torch.Tensor,
12
+ embedding_dim: int,
13
+ flip_sin_to_cos: bool = False,
14
+ downscale_freq_shift: float = 1,
15
+ scale: float = 1,
16
+ max_period: int = 10000,
17
+ ):
18
+ """
19
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
20
+
21
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
22
+ These may be fractional.
23
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
24
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
25
+ """
26
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
27
+
28
+ half_dim = embedding_dim // 2
29
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
30
+ exponent = exponent / (half_dim - downscale_freq_shift)
31
+
32
+ emb = torch.exp(exponent)
33
+ emb = timesteps[:, None].float() * emb[None, :]
34
+
35
+ # scale embeddings
36
+ emb = scale * emb
37
+
38
+ # concat sine and cosine embeddings
39
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
40
+
41
+ # flip sine and cosine embeddings
42
+ if flip_sin_to_cos:
43
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
44
+
45
+ # zero pad
46
+ if embedding_dim % 2 == 1:
47
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
48
+ return emb
49
+
50
+
51
+ def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
52
+ """
53
+ grid: a (3, f*h*w) array holding the (frame, height, width) index of every token.
54
+ return: pos_embed of shape [f*h*w, embed_dim].
55
+ """
56
+ grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w)
57
+ grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w)
58
+ grid = grid.reshape([3, 1, w, h, f])
59
+ pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
60
+ pos_embed = pos_embed.transpose(1, 0, 2, 3)
61
+ return rearrange(pos_embed, "h w f c -> (f h w) c")
62
+
63
+
64
+ def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
65
+ if embed_dim % 3 != 0:
66
+ raise ValueError("embed_dim must be divisible by 3")
67
+
68
+ # use half of dimensions to encode grid_h
69
+ emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3)
70
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3)
71
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3)
72
+
73
+ emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D)
74
+ return emb
75
+
76
+
77
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
78
+ """
79
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
80
+ """
81
+ if embed_dim % 2 != 0:
82
+ raise ValueError("embed_dim must be divisible by 2")
83
+
84
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
85
+ omega /= embed_dim / 2.0
86
+ omega = 1.0 / 10000**omega # (D/2,)
87
+
88
+ pos_shape = pos.shape
89
+
90
+ pos = pos.reshape(-1)
91
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
92
+ out = out.reshape([*pos_shape, -1])[0]
93
+
94
+ emb_sin = np.sin(out) # (M, D/2)
95
+ emb_cos = np.cos(out) # (M, D/2)
96
+
97
+ emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D)
98
+ return emb
99
+
100
+
101
+ class SinusoidalPositionalEmbedding(nn.Module):
102
+ """Apply positional information to a sequence of embeddings.
103
+
104
+ Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
105
+ them
106
+
107
+ Args:
108
+ embed_dim: (int): Dimension of the positional embedding.
109
+ max_seq_length: Maximum sequence length to apply positional embeddings
110
+
111
+ """
112
+
113
+ def __init__(self, embed_dim: int, max_seq_length: int = 32):
114
+ super().__init__()
115
+ position = torch.arange(max_seq_length).unsqueeze(1)
116
+ div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
117
+ pe = torch.zeros(1, max_seq_length, embed_dim)
118
+ pe[0, :, 0::2] = torch.sin(position * div_term)
119
+ pe[0, :, 1::2] = torch.cos(position * div_term)
120
+ self.register_buffer("pe", pe)
121
+
122
+ def forward(self, x):
123
+ _, seq_length, _ = x.shape
124
+ x = x + self.pe[:, :seq_length]
125
+ return x
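A quick shape sketch for the helpers above (assuming the new module is importable as xora.models.transformers.embeddings; for the 3D variant, embed_dim must be divisible by 6):

import numpy as np
import torch
from xora.models.transformers.embeddings import get_timestep_embedding, get_3d_sincos_pos_embed

# Sinusoidal timestep embedding: one row per timestep, embedding_dim columns.
t = torch.tensor([0.0, 250.0, 999.0])
print(get_timestep_embedding(t, embedding_dim=128).shape)  # torch.Size([3, 128])

# 3D sin/cos positional embedding for a 2x4x4 (frames x height x width) token grid.
f, h, w = 2, 4, 4
ff, hh, ww = np.meshgrid(np.arange(f), np.arange(h), np.arange(w), indexing="ij")
grid = np.stack([ff, hh, ww]).reshape(3, -1).astype(np.float64)  # rows: frame, height, width indices
pos = get_3d_sincos_pos_embed(192, grid, w=w, h=h, f=f)
print(pos.shape)                                                 # (32, 192) == (f*h*w, embed_dim)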
xora/models/transformers/transformer3d.py CHANGED
@@ -1,7 +1,7 @@
1
  # Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
2
  import math
3
  from dataclasses import dataclass
4
- from typing import Any, Dict, List, Optional
5
 
6
  import torch
7
  from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -9,10 +9,13 @@ from diffusers.models.embeddings import PixArtAlphaTextProjection
9
  from diffusers.models.modeling_utils import ModelMixin
10
  from diffusers.models.normalization import AdaLayerNormSingle
11
  from diffusers.utils import BaseOutput, is_torch_version
 
12
  from torch import nn
13
 
14
  from xora.models.transformers.attention import BasicTransformerBlock
 
15
 
 
16
 
17
  @dataclass
18
  class Transformer3DModelOutput(BaseOutput):
@@ -143,6 +146,61 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
143
 
144
  self.gradient_checkpointing = False
145
 
146
  def _set_gradient_checkpointing(self, module, value=False):
147
  if hasattr(module, "gradient_checkpointing"):
148
  module.gradient_checkpointing = value
@@ -287,10 +345,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
287
  if self.timestep_scale_multiplier:
288
  timestep = self.timestep_scale_multiplier * timestep
289
 
290
- if self.positional_embedding_type == "rope":
 
  freqs_cis = self.precompute_freqs_cis(indices_grid)
292
- else:
293
- raise NotImplementedError("Only rope pos embed supported.")
294
 
295
  batch_size = hidden_states.shape[0]
296
  timestep, embedded_timestep = self.adaln_single(
@@ -358,3 +420,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
358
 
359
  return Transformer3DModelOutput(sample=hidden_states)
360
 
1
  # Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
2
  import math
3
  from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Literal
5
 
6
  import torch
7
  from diffusers.configuration_utils import ConfigMixin, register_to_config
 
9
  from diffusers.models.modeling_utils import ModelMixin
10
  from diffusers.models.normalization import AdaLayerNormSingle
11
  from diffusers.utils import BaseOutput, is_torch_version
12
+ from diffusers.utils import logging
13
  from torch import nn
14
 
15
  from xora.models.transformers.attention import BasicTransformerBlock
16
+ from xora.models.transformers.embeddings import get_3d_sincos_pos_embed
17
 
18
+ logger = logging.get_logger(__name__)
19
 
20
  @dataclass
21
  class Transformer3DModelOutput(BaseOutput):
 
146
 
147
  self.gradient_checkpointing = False
148
 
149
+ def set_use_tpu_flash_attention(self):
150
+ r"""
151
+ Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
152
+ attention kernel.
153
+ """
154
+ logger.info(" ENABLE TPU FLASH ATTENTION -> TRUE")
155
+ # if using TPU -> configure components to use TPU flash attention
156
+ if dist_util.acceleration_type() == dist_util.AccelerationType.TPU:
157
+ self.use_tpu_flash_attention = True
158
+ # push config down to the attention modules
159
+ for block in self.transformer_blocks:
160
+ block.set_use_tpu_flash_attention()
161
+
162
+ def initialize(self, embedding_std: float, mode: Literal["xora", "pixart"]):
163
+ def _basic_init(module):
164
+ if isinstance(module, nn.Linear):
165
+ torch.nn.init.xavier_uniform_(module.weight)
166
+ if module.bias is not None:
167
+ nn.init.constant_(module.bias, 0)
168
+
169
+ self.apply(_basic_init)
170
+
171
+ # Initialize timestep embedding MLP:
172
+ nn.init.normal_(self.adaln_single.emb.timestep_embedder.linear_1.weight, std=embedding_std)
173
+ nn.init.normal_(self.adaln_single.emb.timestep_embedder.linear_2.weight, std=embedding_std)
174
+ nn.init.normal_(self.adaln_single.linear.weight, std=embedding_std)
175
+
176
+ if hasattr(self.adaln_single.emb, "resolution_embedder"):
177
+ nn.init.normal_(self.adaln_single.emb.resolution_embedder.linear_1.weight, std=embedding_std)
178
+ nn.init.normal_(self.adaln_single.emb.resolution_embedder.linear_2.weight, std=embedding_std)
179
+ if hasattr(self.adaln_single.emb, "aspect_ratio_embedder"):
180
+ nn.init.normal_(self.adaln_single.emb.aspect_ratio_embedder.linear_1.weight, std=embedding_std)
181
+ nn.init.normal_(self.adaln_single.emb.aspect_ratio_embedder.linear_2.weight, std=embedding_std)
182
+
183
+ # Initialize caption embedding MLP:
184
+ nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
185
+ nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
186
+
187
+ # Zero-out adaLN modulation layers in PixArt blocks:
188
+ for block in self.transformer_blocks:
189
+ if mode == "xora":
190
+ nn.init.constant_(block.attn1.to_out[0].weight, 0)
191
+ nn.init.constant_(block.attn1.to_out[0].bias, 0)
192
+
193
+ nn.init.constant_(block.attn2.to_out[0].weight, 0)
194
+ nn.init.constant_(block.attn2.to_out[0].bias, 0)
195
+
196
+ if mode == "xora":
197
+ nn.init.constant_(block.ff.net[2].weight, 0)
198
+ nn.init.constant_(block.ff.net[2].bias, 0)
199
+
200
+ # Zero-out output layers:
201
+ nn.init.constant_(self.proj_out.weight, 0)
202
+ nn.init.constant_(self.proj_out.bias, 0)
203
+
204
  def _set_gradient_checkpointing(self, module, value=False):
205
  if hasattr(module, "gradient_checkpointing"):
206
  module.gradient_checkpointing = value
 
345
  if self.timestep_scale_multiplier:
346
  timestep = self.timestep_scale_multiplier * timestep
347
 
348
+ if self.positional_embedding_type == "absolute":
349
+ pos_embed_3d = self.get_absolute_pos_embed(indices_grid).to(hidden_states.device)
350
+ if self.project_to_2d_pos:
351
+ pos_embed = self.to_2d_proj(pos_embed_3d)
352
+ hidden_states = (hidden_states + pos_embed).to(hidden_states.dtype)
353
+ freqs_cis = None
354
+ elif self.positional_embedding_type == "rope":
355
  freqs_cis = self.precompute_freqs_cis(indices_grid)
356
 
357
  batch_size = hidden_states.shape[0]
358
  timestep, embedded_timestep = self.adaln_single(
 
420
 
421
  return Transformer3DModelOutput(sample=hidden_states)
422
 
423
+ def get_absolute_pos_embed(self, grid):
424
+ grid_np = grid[0].cpu().numpy()
425
+ embed_dim_3d = math.ceil((self.inner_dim / 2) * 3) if self.project_to_2d_pos else self.inner_dim
426
+ pos_embed = get_3d_sincos_pos_embed( # (f h w)
427
+ embed_dim_3d,
428
+ grid_np,
429
+ h=int(max(grid_np[1]) + 1),
430
+ w=int(max(grid_np[2]) + 1),
431
+ f=int(max(grid_np[0] + 1)),
432
+ )
433
+ return torch.from_numpy(pos_embed).float().unsqueeze(0)
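To summarise the positional-embedding branch added above: when `positional_embedding_type` is `"absolute"`, a 3D sin-cos table is built over the (frame, row, column) index grid and, if `project_to_2d_pos` is set, projected down to `inner_dim` before being added to the hidden states; otherwise the `"rope"` path precomputes rotary frequencies. A minimal sketch of the grid and size bookkeeping, where `inner_dim`, the token grid, and the projection layer are illustrative assumptions and `get_3d_sincos_pos_embed` / `to_2d_proj` are left as comments because only their call sites are visible in this diff:

```py
import math
import torch

inner_dim = 1152           # transformer width (assumption)
f, h, w = 8, 16, 24        # latent frames / rows / columns in tokens (assumption)

# One (frame, row, column) triple per token, analogous to indices_grid above.
grid = torch.stack(
    torch.meshgrid(torch.arange(f), torch.arange(h), torch.arange(w), indexing="ij")
).reshape(3, -1)           # (3, f*h*w)

# Sized as ceil(inner_dim / 2) * 3, presumably so the table splits evenly across the
# three axes before to_2d_proj maps it back down to inner_dim.
embed_dim_3d = math.ceil((inner_dim / 2) * 3)

# pos_embed_3d = get_3d_sincos_pos_embed(embed_dim_3d, grid.numpy(), h=h, w=w, f=f)
# pos_embed = to_2d_proj(torch.from_numpy(pos_embed_3d).float())   # (f*h*w, inner_dim)
# hidden_states = hidden_states + pos_embed.unsqueeze(0)
```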
xora/pipelines/pipeline_video_pixart_alpha.py CHANGED
@@ -32,16 +32,106 @@ from xora.models.transformers.symmetric_patchifier import Patchifier
32
  from xora.models.autoencoders.vae_encode import get_vae_size_scale_factor, vae_decode, vae_encode
33
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
34
  from xora.schedulers.rf import TimestepShifter
 
35
 
36
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
 
38
-
39
  if is_bs4_available():
40
  from bs4 import BeautifulSoup
41
 
42
  if is_ftfy_available():
43
  import ftfy
44
 
45
  def retrieve_timesteps(
46
  scheduler,
47
  num_inference_steps: Optional[int] = None,
@@ -520,14 +610,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
520
 
521
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
522
  def prepare_latents(
523
- self,
524
- batch_size,
525
- num_latent_channels,
526
- num_patches,
527
- dtype,
528
- device,
529
- generator,
530
- latents=None,
531
  ):
532
  shape = (
533
  batch_size,
@@ -543,6 +626,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
543
 
544
  if latents is None:
545
  latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
546
  else:
547
  latents = latents.to(device)
548
 
@@ -582,8 +668,8 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
582
 
583
  return samples
584
 
585
-
586
  @torch.no_grad()
 
587
  def __call__(
588
  self,
589
  height: int,
@@ -607,6 +693,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
607
  return_dict: bool = True,
608
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
609
  clean_caption: bool = True,
 
610
  **kwargs,
611
  ) -> Union[ImagePipelineOutput, Tuple]:
612
  """
@@ -736,8 +823,15 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
736
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
737
  prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
738
 
739
- # 4. Prepare latents.
740
  self.video_scale_factor = self.video_scale_factor if is_video else 1
741
  latent_height = height // self.vae_scale_factor
742
  latent_width = width // self.vae_scale_factor
743
  latent_num_frames = num_frames // self.video_scale_factor
@@ -752,7 +846,12 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
752
  dtype=prompt_embeds.dtype,
753
  device=device,
754
  generator=generator,
755
  )
756
 
757
  # 5. Prepare timesteps
758
  retrieve_timesteps_kwargs = {}
@@ -790,7 +889,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
790
  elif len(current_timestep.shape) == 0:
791
  current_timestep = current_timestep[None].to(latent_model_input.device)
792
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
793
- current_timestep = current_timestep.expand(latent_model_input.shape[0])
794
  scale_grid = (
795
  (1 / latent_frame_rates, self.vae_scale_factor, self.vae_scale_factor)
796
  if self.transformer.use_rope
@@ -805,6 +904,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
805
  device=latents.device,
806
  )
807
 
808
  # predict noise model_output
809
  noise_pred = self.transformer(
810
  latent_model_input.to(self.transformer.dtype),
@@ -819,13 +921,20 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
819
  if do_classifier_free_guidance:
820
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
821
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
822
 
823
  # learned sigma
824
  if self.transformer.config.out_channels // 2 == self.transformer.config.in_channels:
825
  noise_pred = noise_pred.chunk(2, dim=1)[0]
826
 
827
  # compute previous image: x_t -> x_t-1
828
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
829
 
830
  # call the callback, if provided
831
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -857,3 +966,62 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
857
  return (image,)
858
 
859
  return ImagePipelineOutput(images=image)
 
32
  from xora.models.autoencoders.vae_encode import get_vae_size_scale_factor, vae_decode, vae_encode
33
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
34
  from xora.schedulers.rf import TimestepShifter
35
+ from xora.utils.conditioning_method import ConditioningMethod
36
 
37
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
 
 
39
  if is_bs4_available():
40
  from bs4 import BeautifulSoup
41
 
42
  if is_ftfy_available():
43
  import ftfy
44
 
45
+ EXAMPLE_DOC_STRING = """
46
+ Examples:
47
+ ```py
48
+ >>> import torch
49
+ >>> from diffusers import PixArtAlphaPipeline
50
+
51
+ >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
52
+ >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
53
+ >>> # Enable memory optimizations.
54
+ >>> pipe.enable_model_cpu_offload()
55
+
56
+ >>> prompt = "A small cactus with a happy face in the Sahara desert."
57
+ >>> image = pipe(prompt).images[0]
58
+ ```
59
+ """
60
+
61
+ ASPECT_RATIO_1024_BIN = {
62
+ "0.25": [512.0, 2048.0],
63
+ "0.28": [512.0, 1856.0],
64
+ "0.32": [576.0, 1792.0],
65
+ "0.33": [576.0, 1728.0],
66
+ "0.35": [576.0, 1664.0],
67
+ "0.4": [640.0, 1600.0],
68
+ "0.42": [640.0, 1536.0],
69
+ "0.48": [704.0, 1472.0],
70
+ "0.5": [704.0, 1408.0],
71
+ "0.52": [704.0, 1344.0],
72
+ "0.57": [768.0, 1344.0],
73
+ "0.6": [768.0, 1280.0],
74
+ "0.68": [832.0, 1216.0],
75
+ "0.72": [832.0, 1152.0],
76
+ "0.78": [896.0, 1152.0],
77
+ "0.82": [896.0, 1088.0],
78
+ "0.88": [960.0, 1088.0],
79
+ "0.94": [960.0, 1024.0],
80
+ "1.0": [1024.0, 1024.0],
81
+ "1.07": [1024.0, 960.0],
82
+ "1.13": [1088.0, 960.0],
83
+ "1.21": [1088.0, 896.0],
84
+ "1.29": [1152.0, 896.0],
85
+ "1.38": [1152.0, 832.0],
86
+ "1.46": [1216.0, 832.0],
87
+ "1.67": [1280.0, 768.0],
88
+ "1.75": [1344.0, 768.0],
89
+ "2.0": [1408.0, 704.0],
90
+ "2.09": [1472.0, 704.0],
91
+ "2.4": [1536.0, 640.0],
92
+ "2.5": [1600.0, 640.0],
93
+ "3.0": [1728.0, 576.0],
94
+ "4.0": [2048.0, 512.0],
95
+ }
96
+
97
+ ASPECT_RATIO_512_BIN = {
98
+ "0.25": [256.0, 1024.0],
99
+ "0.28": [256.0, 928.0],
100
+ "0.32": [288.0, 896.0],
101
+ "0.33": [288.0, 864.0],
102
+ "0.35": [288.0, 832.0],
103
+ "0.4": [320.0, 800.0],
104
+ "0.42": [320.0, 768.0],
105
+ "0.48": [352.0, 736.0],
106
+ "0.5": [352.0, 704.0],
107
+ "0.52": [352.0, 672.0],
108
+ "0.57": [384.0, 672.0],
109
+ "0.6": [384.0, 640.0],
110
+ "0.68": [416.0, 608.0],
111
+ "0.72": [416.0, 576.0],
112
+ "0.78": [448.0, 576.0],
113
+ "0.82": [448.0, 544.0],
114
+ "0.88": [480.0, 544.0],
115
+ "0.94": [480.0, 512.0],
116
+ "1.0": [512.0, 512.0],
117
+ "1.07": [512.0, 480.0],
118
+ "1.13": [544.0, 480.0],
119
+ "1.21": [544.0, 448.0],
120
+ "1.29": [576.0, 448.0],
121
+ "1.38": [576.0, 416.0],
122
+ "1.46": [608.0, 416.0],
123
+ "1.67": [640.0, 384.0],
124
+ "1.75": [672.0, 384.0],
125
+ "2.0": [704.0, 352.0],
126
+ "2.09": [736.0, 352.0],
127
+ "2.4": [768.0, 320.0],
128
+ "2.5": [800.0, 320.0],
129
+ "3.0": [864.0, 288.0],
130
+ "4.0": [1024.0, 256.0],
131
+ }
132
+
133
+
134
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
135
  def retrieve_timesteps(
136
  scheduler,
137
  num_inference_steps: Optional[int] = None,
 
610
 
611
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
612
  def prepare_latents(
613
+ self, batch_size, num_latent_channels, num_patches, dtype, device, generator, latents=None, latents_mask=None
614
  ):
615
  shape = (
616
  batch_size,
 
626
 
627
  if latents is None:
628
  latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
629
+ elif latents_mask is not None:
630
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
631
+ latents = latents * latents_mask[..., None] + noise * (1 - latents_mask[..., None])
632
  else:
633
  latents = latents.to(device)
634
 
 
668
 
669
  return samples
670
 
 
671
  @torch.no_grad()
672
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
673
  def __call__(
674
  self,
675
  height: int,
 
693
  return_dict: bool = True,
694
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
695
  clean_caption: bool = True,
696
+ media_items: Optional[torch.FloatTensor] = None,
697
  **kwargs,
698
  ) -> Union[ImagePipelineOutput, Tuple]:
699
  """
 
823
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
824
  prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
825
 
826
+ # 3b. Encode and prepare conditioning data
827
  self.video_scale_factor = self.video_scale_factor if is_video else 1
828
+ conditioning_method = kwargs.get("conditioning_method", None)
829
+ vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", False)
830
+ init_latents, conditioning_mask = self.prepare_conditioning(
831
+ media_items, num_frames, height, width, conditioning_method, vae_per_channel_normalize
832
+ )
833
+
834
+ # 4. Prepare latents.
835
  latent_height = height // self.vae_scale_factor
836
  latent_width = width // self.vae_scale_factor
837
  latent_num_frames = num_frames // self.video_scale_factor
 
846
  dtype=prompt_embeds.dtype,
847
  device=device,
848
  generator=generator,
849
+ latents=init_latents,
850
+ latents_mask=conditioning_mask,
851
  )
852
+ if conditioning_mask is not None and is_video:
853
+ assert num_images_per_prompt == 1
854
+ conditioning_mask = torch.cat([conditioning_mask] * 2) if do_classifier_free_guidance else conditioning_mask
855
 
856
  # 5. Prepare timesteps
857
  retrieve_timesteps_kwargs = {}
 
889
  elif len(current_timestep.shape) == 0:
890
  current_timestep = current_timestep[None].to(latent_model_input.device)
891
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
892
+ current_timestep = current_timestep.expand(latent_model_input.shape[0]).unsqueeze(-1)
893
  scale_grid = (
894
  (1 / latent_frame_rates, self.vae_scale_factor, self.vae_scale_factor)
895
  if self.transformer.use_rope
 
904
  device=latents.device,
905
  )
906
 
907
+ if conditioning_mask is not None:
908
+ current_timestep = current_timestep * (1 - conditioning_mask)
909
+
910
  # predict noise model_output
911
  noise_pred = self.transformer(
912
  latent_model_input.to(self.transformer.dtype),
 
921
  if do_classifier_free_guidance:
922
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
923
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
924
+ current_timestep, _ = current_timestep.chunk(2)
925
 
926
  # learned sigma
927
  if self.transformer.config.out_channels // 2 == self.transformer.config.in_channels:
928
  noise_pred = noise_pred.chunk(2, dim=1)[0]
929
 
930
  # compute previous image: x_t -> x_t-1
931
+ latents = self.scheduler.step(
932
+ noise_pred,
933
+ t if current_timestep is None else current_timestep,
934
+ latents,
935
+ **extra_step_kwargs,
936
+ return_dict=False,
937
+ )[0]
938
 
939
  # call the callback, if provided
940
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
 
966
  return (image,)
967
 
968
  return ImagePipelineOutput(images=image)
969
+
970
+ def prepare_conditioning(
971
+ self,
972
+ media_items: torch.Tensor,
973
+ num_frames: int,
974
+ height: int,
975
+ width: int,
976
+ method: ConditioningMethod = ConditioningMethod.UNCONDITIONAL,
977
+ vae_per_channel_normalize: bool = False,
978
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
979
+ """
980
+ Prepare the conditioning data for the video generation. If an input media item is provided, encode it
981
+ and set the conditioning_mask to indicate which tokens to condition on. Input media item should have
982
+ the same height and width as the generated video.
983
+
984
+ Args:
985
+ media_items (torch.Tensor): media items to condition on (images or videos)
986
+ num_frames (int): number of frames to generate
987
+ height (int): height of the generated video
988
+ width (int): width of the generated video
989
+ method (ConditioningMethod, optional): conditioning method to use. Defaults to ConditioningMethod.UNCONDITIONAL.
990
+ vae_per_channel_normalize (bool, optional): whether to normalize the input to the VAE per channel. Defaults to False.
991
+
992
+ Returns:
993
+ Tuple[torch.Tensor, torch.Tensor]: the conditioning latents and the conditioning mask
994
+ """
995
+ if media_items is None or method == ConditioningMethod.UNCONDITIONAL:
996
+ return None, None
997
+
998
+ assert media_items.ndim == 5
999
+ assert height == media_items.shape[-2] and width == media_items.shape[-1]
1000
+
1001
+ # Encode the input video and repeat to the required number of frame-tokens
1002
+ init_latents = vae_encode(
1003
+ media_items.to(dtype=self.vae.dtype, device=self.vae.device),
1004
+ self.vae,
1005
+ vae_per_channel_normalize=vae_per_channel_normalize,
1006
+ ).float()
1007
+
1008
+ init_len, target_len = init_latents.shape[2], num_frames // self.video_scale_factor
1009
+ if isinstance(self.vae, CausalVideoAutoencoder):
1010
+ target_len += 1
1011
+ init_latents = init_latents[:, :, :target_len]
1012
+ if target_len > init_len:
1013
+ repeat_factor = (target_len + init_len - 1) // init_len # Ceiling division
1014
+ init_latents = init_latents.repeat(1, 1, repeat_factor, 1, 1)[:, :, :target_len]
1015
+
1016
+ # Prepare the conditioning mask (1.0 = condition on this token)
1017
+ b, n, f, h, w = init_latents.shape
1018
+ conditioning_mask = torch.zeros([b, 1, f, h, w], device=init_latents.device)
1019
+ if method in [ConditioningMethod.FIRST_FRAME, ConditioningMethod.FIRST_AND_LAST_FRAME]:
1020
+ conditioning_mask[:, :, 0] = 1.0
1021
+ if method in [ConditioningMethod.LAST_FRAME, ConditioningMethod.FIRST_AND_LAST_FRAME]:
1022
+ conditioning_mask[:, :, -1] = 1.0
1023
+
1024
+ # Patchify the init latents and the mask
1025
+ conditioning_mask = self.patchifier.patchify(conditioning_mask).squeeze(-1)
1026
+ init_latents = self.patchifier.patchify(latents=init_latents)
1027
+ return init_latents, conditioning_mask
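Taken together, the conditioning changes above reduce to two masking steps: `prepare_latents` keeps the encoded media latents wherever the conditioning mask is 1 and injects noise where it is 0, and inside the denoising loop the per-token timestep is multiplied by `(1 - mask)` so conditioned tokens carry timestep 0 and are left untouched by the scheduler. A stripped-down sketch of that arithmetic with made-up shapes (this is not the pipeline's API, just the tensor math):

```py
import torch

b, n_tokens, channels = 2, 16, 4
init_latents = torch.randn(b, n_tokens, channels)    # stand-in for the patchified VAE latents
conditioning_mask = torch.zeros(b, n_tokens)
conditioning_mask[:, 0] = 1.0                        # e.g. ConditioningMethod.FIRST_FRAME

# prepare_latents: keep conditioned tokens, noise everything else
noise = torch.randn_like(init_latents)
latents = init_latents * conditioning_mask[..., None] + noise * (1 - conditioning_mask[..., None])

# denoising loop: zero the per-token timestep on conditioned tokens
t = torch.full((b, 1), 0.7)                          # broadcast like current_timestep above
current_timestep = t * (1 - conditioning_mask)       # (b, n_tokens); 0 on conditioned tokens
```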
xora/schedulers/rf.py CHANGED
@@ -9,7 +9,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
9
  from diffusers.utils import BaseOutput
10
  from torch import Tensor
11
 
12
- from xora.utils.torch_utils import append_dims
13
 
14
 
15
  def simple_diffusion_resolution_dependent_timestep_shift(
@@ -199,8 +199,17 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
199
  "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
200
  )
201
 
202
- current_index = (self.timesteps - timestep).abs().argmin()
203
- dt = self.delta_timesteps.gather(0, current_index.unsqueeze(0))
204
 
205
  prev_sample = sample - dt * model_output
206
 
@@ -219,4 +228,4 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
219
  sigmas = append_dims(sigmas, original_samples.ndim)
220
  alphas = 1 - sigmas
221
  noisy_samples = alphas * original_samples + sigmas * noise
222
- return noisy_samples
 
9
  from diffusers.utils import BaseOutput
10
  from torch import Tensor
11
 
12
+ from txt2img.common.torch_utils import append_dims
13
 
14
 
15
  def simple_diffusion_resolution_dependent_timestep_shift(
 
199
  "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
200
  )
201
 
202
+ if timestep.ndim == 0:
203
+ # Global timestep
204
+ current_index = (self.timesteps - timestep).abs().argmin()
205
+ dt = self.delta_timesteps.gather(0, current_index.unsqueeze(0))
206
+ else:
207
+ # Timestep per token
208
+ assert timestep.ndim == 2
209
+ current_index = (self.timesteps[:, None, None] - timestep[None]).abs().argmin(dim=0)
210
+ dt = self.delta_timesteps[current_index]
211
+ # Special treatment for zero timestep tokens - set dt to 0 so prev_sample = sample
212
+ dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None]
213
 
214
  prev_sample = sample - dt * model_output
215
 
 
228
  sigmas = append_dims(sigmas, original_samples.ndim)
229
  alphas = 1 - sigmas
230
  noisy_samples = alphas * original_samples + sigmas * noise
231
+ return noisy_samples
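The per-token branch added to `RectifiedFlowScheduler.step` above vectorises the nearest-timestep lookup: for a `(batch, tokens)` timestep tensor, every token finds the index of its closest scheduler timestep, gathers the matching `delta_timesteps` entry, and tokens whose timestep is exactly 0 (the conditioned ones) get `dt = 0`, so `prev_sample == sample` for them. A self-contained sketch of that lookup with made-up timestep tables:

```py
import torch

timesteps = torch.tensor([1.0, 0.75, 0.5, 0.25])                  # illustrative scheduler timesteps
delta_timesteps = timesteps - torch.cat([timesteps[1:], torch.zeros(1)])

token_t = torch.tensor([[1.00, 0.00, 0.50],                       # (batch=2, tokens=3)
                        [0.75, 0.75, 0.00]])

# (steps, 1, 1) vs. (1, batch, tokens) -> nearest scheduler step per token
current_index = (timesteps[:, None, None] - token_t[None]).abs().argmin(dim=0)
dt = delta_timesteps[current_index]                               # (batch, tokens)

# conditioned tokens carry timestep 0 -> no update for them
dt = torch.where(token_t == 0.0, torch.zeros_like(dt), dt)
```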
xora/utils/conditioning_method.py ADDED
@@ -0,0 +1,7 @@
1
+ from enum import Enum
2
+
3
+ class ConditioningMethod(Enum):
4
+ UNCONDITIONAL = "unconditional"
5
+ FIRST_FRAME = "first_frame"
6
+ LAST_FRAME = "last_frame"
7
+ FIRST_AND_LAST_FRAME = "first_and_last_frame"
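The enum above is consumed by the pipeline through `kwargs.get("conditioning_method", None)` together with `media_items`, so image-to-video conditioning is selected at call time. A hedged call sketch (the pipeline construction is assumed to have happened elsewhere, and all argument values are illustrative):

```py
from xora.utils.conditioning_method import ConditioningMethod

# Assumed to exist already: a constructed VideoPixArtAlphaPipeline named `pipeline`
# and a (B, C, 1, H, W) image tensor `first_frame` matching the target height/width.
video = pipeline(
    prompt="A sailboat drifting at sunset",   # assumption: prompt handled as in the docstring example
    height=256,
    width=256,
    num_frames=33,
    frame_rate=24,
    is_video=True,
    media_items=first_frame,
    conditioning_method=ConditioningMethod.FIRST_FRAME,
).images
```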
xora/utils/dist_util.py ADDED
@@ -0,0 +1,11 @@
1
+ from enum import Enum
2
+
3
+ class AccelerationType(Enum):
4
+ CPU = "cpu"
5
+ GPU = "gpu"
6
+ TPU = "tpu"
7
+ MPS = "mps"
8
+
9
+ def execute_graph() -> None:
10
+ if _acceleration_type == AccelerationType.TPU:
11
+ xm.mark_step()
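Two loose ends are visible in the added `dist_util.py`: the 11-line hunk references `xm` and `_acceleration_type` without defining them, and `Transformer3DModel.set_use_tpu_flash_attention` calls `dist_util.acceleration_type()` (and uses `dist_util` without an import visible in this diff, presumably `from xora.utils import dist_util` higher up in the file). A hedged sketch of what the complete module presumably looks like; the detection logic and the torch_xla import are assumptions, not the author's code:

```py
from enum import Enum

import torch

try:
    import torch_xla.core.xla_model as xm  # assumption: only available on TPU hosts
except ImportError:
    xm = None


class AccelerationType(Enum):
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"
    MPS = "mps"


def _detect_acceleration_type() -> AccelerationType:
    # Assumption: pick TPU if torch_xla is importable, otherwise fall back to CUDA / MPS / CPU.
    if xm is not None:
        return AccelerationType.TPU
    if torch.cuda.is_available():
        return AccelerationType.GPU
    if torch.backends.mps.is_available():
        return AccelerationType.MPS
    return AccelerationType.CPU


_acceleration_type = _detect_acceleration_type()


def acceleration_type() -> AccelerationType:
    return _acceleration_type


def execute_graph() -> None:
    # On TPU, flush the lazily-built XLA graph.
    if _acceleration_type == AccelerationType.TPU:
        xm.mark_step()
```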