Sapir Weissbuch commited on
Commit
645fba0
·
unverified ·
2 Parent(s): 73a9e96 d504563

Merge pull request #18 from LightricksResearch/delete-pixart

Browse files
xora/examples/image_to_video.py CHANGED
@@ -3,7 +3,7 @@ from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoenc
3
  from xora.models.transformers.transformer3d import Transformer3DModel
4
  from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
5
  from xora.schedulers.rf import RectifiedFlowScheduler
6
- from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
7
  from pathlib import Path
8
  from transformers import T5EncoderModel, T5Tokenizer
9
  import safetensors.torch
@@ -180,7 +180,7 @@ def main():
180
  "vae": vae,
181
  }
182
 
183
- pipeline = VideoPixArtAlphaPipeline(**submodel_dict).to("cuda")
184
 
185
  # Load media (video or image)
186
  if args.video_path:
 
3
  from xora.models.transformers.transformer3d import Transformer3DModel
4
  from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
5
  from xora.schedulers.rf import RectifiedFlowScheduler
6
+ from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
7
  from pathlib import Path
8
  from transformers import T5EncoderModel, T5Tokenizer
9
  import safetensors.torch
 
180
  "vae": vae,
181
  }
182
 
183
+ pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")
184
 
185
  # Load media (video or image)
186
  if args.video_path:
xora/examples/text_to_video.py CHANGED
@@ -3,7 +3,7 @@ from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoenc
3
  from xora.models.transformers.transformer3d import Transformer3DModel
4
  from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
5
  from xora.schedulers.rf import RectifiedFlowScheduler
6
- from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
7
  from pathlib import Path
8
  from transformers import T5EncoderModel, T5Tokenizer
9
  import safetensors.torch
@@ -82,7 +82,7 @@ def main():
82
  "vae": vae,
83
  }
84
 
85
- pipeline = VideoPixArtAlphaPipeline(**submodel_dict).to("cuda")
86
 
87
  # Sample input
88
  num_inference_steps = 20
 
3
  from xora.models.transformers.transformer3d import Transformer3DModel
4
  from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
5
  from xora.schedulers.rf import RectifiedFlowScheduler
6
+ from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
7
  from pathlib import Path
8
  from transformers import T5EncoderModel, T5Tokenizer
9
  import safetensors.torch
 
82
  "vae": vae,
83
  }
84
 
85
+ pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")
86
 
87
  # Sample input
88
  num_inference_steps = 20
xora/models/transformers/symmetric_patchifier.py CHANGED
@@ -60,26 +60,19 @@ class Patchifier(ConfigMixin, ABC):
60
  return grid
61
 
62
 
63
- def pixart_alpha_patchify(
64
- latents: Tensor,
65
- patch_size: int,
66
- ) -> Tuple[Tensor, Tensor]:
67
- latents = rearrange(
68
- latents,
69
- "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
70
- p1=patch_size[0],
71
- p2=patch_size[1],
72
- p3=patch_size[2],
73
- )
74
- return latents
75
-
76
-
77
  class SymmetricPatchifier(Patchifier):
78
  def patchify(
79
  self,
80
  latents: Tensor,
81
  ) -> Tuple[Tensor, Tensor]:
82
- return pixart_alpha_patchify(latents, self._patch_size)
 
 
 
 
 
 
 
83
 
84
  def unpatchify(
85
  self,
 
60
  return grid
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  class SymmetricPatchifier(Patchifier):
64
  def patchify(
65
  self,
66
  latents: Tensor,
67
  ) -> Tuple[Tensor, Tensor]:
68
+ latents = rearrange(
69
+ latents,
70
+ "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
71
+ p1=self._patch_size[0],
72
+ p2=self._patch_size[1],
73
+ p3=self._patch_size[2],
74
+ )
75
+ return latents
76
 
77
  def unpatchify(
78
  self,
xora/models/transformers/transformer3d.py CHANGED
@@ -141,12 +141,10 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
141
  )
142
  self.proj_out = nn.Linear(inner_dim, self.out_channels)
143
 
144
- # 5. PixArt-Alpha blocks.
145
  self.adaln_single = AdaLayerNormSingle(
146
  inner_dim, use_additional_conditions=False
147
  )
148
  if adaptive_norm == "single_scale":
149
- # Use 4 channels instead of the 6 for the PixArt-Alpha scale + shift ada norm.
150
  self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
151
 
152
  self.caption_projection = None
@@ -170,7 +168,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
170
  for block in self.transformer_blocks:
171
  block.set_use_tpu_flash_attention(self.device.type)
172
 
173
- def initialize(self, embedding_std: float, mode: Literal["xora", "pixart"]):
174
  def _basic_init(module):
175
  if isinstance(module, nn.Linear):
176
  torch.nn.init.xavier_uniform_(module.weight)
@@ -211,7 +209,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
211
  nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
212
  nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
213
 
214
- # Zero-out adaLN modulation layers in PixArt blocks:
215
  for block in self.transformer_blocks:
216
  if mode.lower() == "xora":
217
  nn.init.constant_(block.attn1.to_out[0].weight, 0)
 
141
  )
142
  self.proj_out = nn.Linear(inner_dim, self.out_channels)
143
 
 
144
  self.adaln_single = AdaLayerNormSingle(
145
  inner_dim, use_additional_conditions=False
146
  )
147
  if adaptive_norm == "single_scale":
 
148
  self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
149
 
150
  self.caption_projection = None
 
168
  for block in self.transformer_blocks:
169
  block.set_use_tpu_flash_attention(self.device.type)
170
 
171
+ def initialize(self, embedding_std: float, mode: Literal["xora", "legacy"]):
172
  def _basic_init(module):
173
  if isinstance(module, nn.Linear):
174
  torch.nn.init.xavier_uniform_(module.weight)
 
209
  nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
210
  nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
211
 
 
212
  for block in self.transformer_blocks:
213
  if mode.lower() == "xora":
214
  nn.init.constant_(block.attn1.to_out[0].weight, 0)
xora/pipelines/{pipeline_video_pixart_alpha.py → pipeline_xora_video.py} RENAMED
@@ -1,4 +1,4 @@
1
- # # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
2
  import html
3
  import inspect
4
  import math
@@ -19,7 +19,6 @@ from diffusers.utils import (
19
  is_bs4_available,
20
  is_ftfy_available,
21
  logging,
22
- replace_example_docstring,
23
  )
24
  from diffusers.utils.torch_utils import randn_tensor
25
  from einops import rearrange
@@ -44,22 +43,6 @@ if is_bs4_available():
44
  if is_ftfy_available():
45
  import ftfy
46
 
47
- EXAMPLE_DOC_STRING = """
48
- Examples:
49
- ```py
50
- >>> import torch
51
- >>> from diffusers import PixArtAlphaPipeline
52
-
53
- >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
54
- >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
55
- >>> # Enable memory optimizations.
56
- >>> pipe.enable_model_cpu_offload()
57
-
58
- >>> prompt = "A small cactus with a happy face in the Sahara desert."
59
- >>> image = pipe(prompt).images[0]
60
- ```
61
- """
62
-
63
  ASPECT_RATIO_1024_BIN = {
64
  "0.25": [512.0, 2048.0],
65
  "0.28": [512.0, 1856.0],
@@ -180,9 +163,9 @@ def retrieve_timesteps(
180
  return timesteps, num_inference_steps
181
 
182
 
183
- class VideoPixArtAlphaPipeline(DiffusionPipeline):
184
  r"""
185
- Pipeline for text-to-image generation using PixArt-Alpha.
186
 
187
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
188
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
@@ -191,7 +174,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
191
  vae ([`AutoencoderKL`]):
192
  Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
193
  text_encoder ([`T5EncoderModel`]):
194
- Frozen text-encoder. PixArt-Alpha uses
195
  [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
196
  [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
197
  tokenizer (`T5Tokenizer`):
@@ -247,7 +230,6 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
247
  )
248
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
249
 
250
- # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
251
  def mask_text_embeddings(self, emb, mask):
252
  if emb.shape[0] == 1:
253
  keep_index = mask.sum().item()
@@ -280,7 +262,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
280
  negative_prompt (`str` or `List[str]`, *optional*):
281
  The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
282
  instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
283
- PixArt-Alpha, this should be "".
284
  do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
285
  whether to use classifier free guidance or not
286
  num_images_per_prompt (`int`, *optional*, defaults to 1):
@@ -291,8 +273,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
291
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
292
  provided, text embeddings will be generated from `prompt` input argument.
293
  negative_prompt_embeds (`torch.FloatTensor`, *optional*):
294
- Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
295
- string.
296
  clean_caption (bool, defaults to `False`):
297
  If `True`, the function will preprocess and clean the provided caption before encoding.
298
  """
@@ -753,7 +734,6 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
753
  return samples
754
 
755
  @torch.no_grad()
756
- @replace_example_docstring(EXAMPLE_DOC_STRING)
757
  def __call__(
758
  self,
759
  height: int,
@@ -824,7 +804,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
824
  provided, text embeddings will be generated from `prompt` input argument.
825
  prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
826
  negative_prompt_embeds (`torch.FloatTensor`, *optional*):
827
- Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
828
  provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
829
  negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
830
  Pre-generated attention mask for negative text embeddings.
 
1
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
2
  import html
3
  import inspect
4
  import math
 
19
  is_bs4_available,
20
  is_ftfy_available,
21
  logging,
 
22
  )
23
  from diffusers.utils.torch_utils import randn_tensor
24
  from einops import rearrange
 
43
  if is_ftfy_available():
44
  import ftfy
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  ASPECT_RATIO_1024_BIN = {
47
  "0.25": [512.0, 2048.0],
48
  "0.28": [512.0, 1856.0],
 
163
  return timesteps, num_inference_steps
164
 
165
 
166
+ class XoraVideoPipeline(DiffusionPipeline):
167
  r"""
168
+ Pipeline for text-to-image generation using Xora.
169
 
170
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
171
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
174
  vae ([`AutoencoderKL`]):
175
  Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
176
  text_encoder ([`T5EncoderModel`]):
177
+ Frozen text-encoder. This uses
178
  [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
179
  [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
180
  tokenizer (`T5Tokenizer`):
 
230
  )
231
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
232
 
 
233
  def mask_text_embeddings(self, emb, mask):
234
  if emb.shape[0] == 1:
235
  keep_index = mask.sum().item()
 
262
  negative_prompt (`str` or `List[str]`, *optional*):
263
  The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
264
  instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
265
+ This should be "".
266
  do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
267
  whether to use classifier free guidance or not
268
  num_images_per_prompt (`int`, *optional*, defaults to 1):
 
273
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
274
  provided, text embeddings will be generated from `prompt` input argument.
275
  negative_prompt_embeds (`torch.FloatTensor`, *optional*):
276
+ Pre-generated negative text embeddings.
 
277
  clean_caption (bool, defaults to `False`):
278
  If `True`, the function will preprocess and clean the provided caption before encoding.
279
  """
 
734
  return samples
735
 
736
  @torch.no_grad()
 
737
  def __call__(
738
  self,
739
  height: int,
 
804
  provided, text embeddings will be generated from `prompt` input argument.
805
  prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
806
  negative_prompt_embeds (`torch.FloatTensor`, *optional*):
807
+ Pre-generated negative text embeddings. This negative prompt should be "". If not
808
  provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
809
  negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
810
  Pre-generated attention mask for negative text embeddings.