Sapir commited on
Commit
b6c994f
1 Parent(s): de2eaeb

README: added inference + installation guidelines, inference clearer.

Browse files
README.md CHANGED
@@ -1 +1,70 @@
1
- # xora-core
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # Xora️
4
+ </div>
5
+
6
+ This is the official repository for Xora.
7
+
8
+ ## Table of Contents
9
+
10
+ * [Introduction](#introduction)
11
+ * [Installation](#installation)
12
+ * [Inference](#inference)
13
+ * [Inference Code](#inference-code)
14
+ * [Acknowledgement](#acknowledgement)
15
+
16
+ ## Introduction
17
+
18
+ The performance of Diffusion Transformers is heavily influenced by the number of generated latent pixels (or tokens). In video generation, the token count becomes substantial as the number of frames increases. To address this, we designed a carefully optimized VAE that compresses videos into a smaller number of tokens while utilizing a deeper latent space. This approach enables our model to generate high-quality 768x512 videos at 24 FPS, achieving near real-time speeds.
19
+
20
+ ## Installation
21
+
22
+ # Setup
23
+ The codebase currently uses Python 3.10.5, CUDA version 12.2, and supports PyTorch >= 2.1.2.
24
+
25
+
26
+ ```bash
27
+ git clone https://github.com/LightricksResearch/xora-core.git
28
+ cd xora-core
29
+
30
+ # create env
31
+ python -m venv env
32
+ source env/bin/activate
33
+ python -m pip install -e .\[inference-script\]
34
+ ```
35
+
36
+ Then, download the model from [Hugging Face](https://huggingface.co/Lightricks/Xora)
37
+
38
+ ```python
39
+ from huggingface_hub import snapshot_download
40
+
41
+ model_path = 'PATH' # The local directory to save downloaded checkpoint
42
+ snapshot_download("Lightricks/Orah", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
43
+ ```
44
+
45
+ ## Inference
46
+
47
+ ### Inference Code
48
+
49
+ To use our model, please follow the inference code in `inference.py` at [https://github.com/LightricksResearch/xora-core/blob/main/inference.py]():
50
+
51
+ For text-to-video generation:
52
+
53
+ ```bash
54
+ python inference.py --ckpt_dir 'PATH' --prompt "PROMPT" --height HEIGHT --width WIDTH
55
+ ```
56
+
57
+ For image-to-video generation:
58
+
59
+ ```python
60
+ python inference.py --ckpt_dir 'PATH' --prompt "PROMPT" --input_image_path IMAGE_PATH --height HEIGHT --width WIDTH
61
+
62
+ ```
63
+
64
+ ## Acknowledgement
65
+
66
+ We are grateful for the following awesome projects when implementing Xora:
67
+ * [DiT](https://github.com/facebookresearch/DiT) and [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): vision transformers for image generation.
68
+
69
+
70
+ [//]: # (## Citation)
xora/examples/image_to_video.py → inference.py RENAMED
@@ -16,9 +16,39 @@ import cv2
16
  from PIL import Image
17
  import random
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def load_vae(vae_dir):
21
- vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
22
  vae_config_path = vae_dir / "config.json"
23
  with open(vae_config_path, "r") as f:
24
  vae_config = json.load(f)
@@ -29,7 +59,7 @@ def load_vae(vae_dir):
29
 
30
 
31
  def load_unet(unet_dir):
32
- unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
33
  unet_config_path = unet_dir / "config.json"
34
  transformer_config = Transformer3DModel.load_config(unet_config_path)
35
  transformer = Transformer3DModel.from_config(transformer_config)
@@ -60,7 +90,7 @@ def center_crop_and_resize(frame, target_height, target_width):
60
  return frame_resized
61
 
62
 
63
- def load_video_to_tensor_with_resize(video_path, target_height=512, target_width=768):
64
  cap = cv2.VideoCapture(video_path)
65
  frames = []
66
  while True:
@@ -68,7 +98,12 @@ def load_video_to_tensor_with_resize(video_path, target_height=512, target_width
68
  if not ret:
69
  break
70
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
71
- frame_resized = center_crop_and_resize(frame_rgb, target_height, target_width)
 
 
 
 
 
72
  frames.append(frame_resized)
73
  cap.release()
74
  video_np = (np.array(frames) / 127.5) - 1.0
@@ -99,9 +134,19 @@ def main():
99
  help="Path to the directory containing unet, vae, and scheduler subdirectories",
100
  )
101
  parser.add_argument(
102
- "--video_path", type=str, help="Path to the input video file (first frame used)"
 
 
 
 
 
 
 
 
 
 
 
103
  )
104
- parser.add_argument("--image_path", type=str, help="Path to the input image file")
105
  parser.add_argument("--seed", type=int, default="171198")
106
 
107
  # Pipeline parameters
@@ -121,10 +166,16 @@ def main():
121
  help="Guidance scale for the pipeline",
122
  )
123
  parser.add_argument(
124
- "--height", type=int, default=512, help="Height of the output video frames"
 
 
 
125
  )
126
  parser.add_argument(
127
- "--width", type=int, default=768, help="Width of the output video frames"
 
 
 
128
  )
129
  parser.add_argument(
130
  "--num_frames",
@@ -136,12 +187,6 @@ def main():
136
  "--frame_rate", type=int, default=25, help="Frame rate for the output video"
137
  )
138
 
139
- parser.add_argument(
140
- "--mixed_precision",
141
- action="store_true",
142
- help="Mixed precision in float32 and bfloat16",
143
- )
144
-
145
  parser.add_argument(
146
  "--bfloat16",
147
  action="store_true",
@@ -152,7 +197,6 @@ def main():
152
  parser.add_argument(
153
  "--prompt",
154
  type=str,
155
- default='A man wearing a black leather jacket and blue jeans is riding a Harley Davidson motorcycle down a paved road. The man has short brown hair and is wearing a black helmet. The motorcycle is a dark red color with a large front fairing. The road is surrounded by green grass and trees. There is a gas station on the left side of the road with a red and white sign that says "Oil" and "Diner".',
156
  help="Text prompt to guide generation",
157
  )
158
  parser.add_argument(
@@ -161,9 +205,42 @@ def main():
161
  default="worst quality, inconsistent motion, blurry, jittery, distorted",
162
  help="Negative prompt for undesired features",
163
  )
 
 
 
 
 
 
164
 
165
  args = parser.parse_args()
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  # Paths for the separate mode directories
168
  ckpt_dir = Path(args.ckpt_dir)
169
  unet_dir = ckpt_dir / "unet"
@@ -197,18 +274,6 @@ def main():
197
 
198
  pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")
199
 
200
- # Load media (video or image)
201
- if args.video_path:
202
- media_items = load_video_to_tensor_with_resize(
203
- args.video_path, args.height, args.width
204
- ).unsqueeze(0)
205
- elif args.image_path:
206
- media_items = load_image_to_tensor_with_resize(
207
- args.image_path, args.height, args.width
208
- )
209
- else:
210
- raise ValueError("Either --video_path or --image_path must be provided.")
211
-
212
  # Prepare input for the pipeline
213
  sample = {
214
  "prompt": args.prompt,
@@ -231,15 +296,19 @@ def main():
231
  generator=generator,
232
  output_type="pt",
233
  callback_on_step_end=None,
234
- height=args.height,
235
- width=args.width,
236
  num_frames=args.num_frames,
237
  frame_rate=args.frame_rate,
238
  **sample,
239
  is_video=True,
240
  vae_per_channel_normalize=True,
241
- conditioning_method=ConditioningMethod.FIRST_FRAME,
242
- mixed_precision=args.mixed_precision,
 
 
 
 
243
  ).images
244
 
245
  # Save output video
@@ -257,16 +326,29 @@ def main():
257
  video_np = (video_np * 255).astype(np.uint8)
258
  fps = args.frame_rate
259
  height, width = video_np.shape[1:3]
260
- output_filename = get_unique_filename(f"video_output_{i}", ".mp4", ".")
261
-
262
- out = cv2.VideoWriter(
263
- output_filename, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
264
- )
265
-
266
- for frame in video_np[..., ::-1]:
267
- out.write(frame)
268
-
269
- out.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
 
272
  if __name__ == "__main__":
 
16
  from PIL import Image
17
  import random
18
 
19
+ RECOMMENDED_RESOLUTIONS = [
20
+ (704, 1216, 41),
21
+ (704, 1088, 49),
22
+ (640, 1056, 57),
23
+ (608, 992, 65),
24
+ (608, 896, 73),
25
+ (544, 896, 81),
26
+ (544, 832, 89),
27
+ (512, 800, 97),
28
+ (512, 768, 97),
29
+ (480, 800, 105),
30
+ (480, 736, 113),
31
+ (480, 704, 121),
32
+ (448, 704, 129),
33
+ (448, 672, 137),
34
+ (416, 640, 153),
35
+ (384, 672, 161),
36
+ (384, 640, 169),
37
+ (384, 608, 177),
38
+ (384, 576, 185),
39
+ (352, 608, 193),
40
+ (352, 576, 201),
41
+ (352, 544, 209),
42
+ (352, 512, 225),
43
+ (352, 512, 233),
44
+ (320, 544, 241),
45
+ (320, 512, 249),
46
+ (320, 512, 257),
47
+ ]
48
+
49
 
50
  def load_vae(vae_dir):
51
+ vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
52
  vae_config_path = vae_dir / "config.json"
53
  with open(vae_config_path, "r") as f:
54
  vae_config = json.load(f)
 
59
 
60
 
61
  def load_unet(unet_dir):
62
+ unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
63
  unet_config_path = unet_dir / "config.json"
64
  transformer_config = Transformer3DModel.load_config(unet_config_path)
65
  transformer = Transformer3DModel.from_config(transformer_config)
 
90
  return frame_resized
91
 
92
 
93
+ def load_video_to_tensor_with_resize(video_path, target_height, target_width):
94
  cap = cv2.VideoCapture(video_path)
95
  frames = []
96
  while True:
 
98
  if not ret:
99
  break
100
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
101
+ if target_height is not None:
102
+ frame_resized = center_crop_and_resize(
103
+ frame_rgb, target_height, target_width
104
+ )
105
+ else:
106
+ frame_resized = frame_rgb
107
  frames.append(frame_resized)
108
  cap.release()
109
  video_np = (np.array(frames) / 127.5) - 1.0
 
134
  help="Path to the directory containing unet, vae, and scheduler subdirectories",
135
  )
136
  parser.add_argument(
137
+ "--input_video_path",
138
+ type=str,
139
+ help="Path to the input video file (first frame used)",
140
+ )
141
+ parser.add_argument(
142
+ "--input_image_path", type=str, help="Path to the input image file"
143
+ )
144
+ parser.add_argument(
145
+ "--output_path",
146
+ type=str,
147
+ default=None,
148
+ help="Path to save output video, if None will save in working directory.",
149
  )
 
150
  parser.add_argument("--seed", type=int, default="171198")
151
 
152
  # Pipeline parameters
 
166
  help="Guidance scale for the pipeline",
167
  )
168
  parser.add_argument(
169
+ "--height",
170
+ type=int,
171
+ default=None,
172
+ help="Height of the output video frames. Optional if an input image provided.",
173
  )
174
  parser.add_argument(
175
+ "--width",
176
+ type=int,
177
+ default=None,
178
+ help="Width of the output video frames. If None will infer from input image.",
179
  )
180
  parser.add_argument(
181
  "--num_frames",
 
187
  "--frame_rate", type=int, default=25, help="Frame rate for the output video"
188
  )
189
 
 
 
 
 
 
 
190
  parser.add_argument(
191
  "--bfloat16",
192
  action="store_true",
 
197
  parser.add_argument(
198
  "--prompt",
199
  type=str,
 
200
  help="Text prompt to guide generation",
201
  )
202
  parser.add_argument(
 
205
  default="worst quality, inconsistent motion, blurry, jittery, distorted",
206
  help="Negative prompt for undesired features",
207
  )
208
+ parser.add_argument(
209
+ "--custom_resolution",
210
+ action="store_true",
211
+ default=False,
212
+ help="Enable custom resolution (not in recommneded resolutions) if specified (default: False)",
213
+ )
214
 
215
  args = parser.parse_args()
216
 
217
+ if args.input_image_path is None and args.input_video_path is None:
218
+ assert (
219
+ args.height is not None and args.width is not None
220
+ ), "Must enter height and width for text to image generation."
221
+
222
+ # Load media (video or image)
223
+ if args.input_video_path:
224
+ media_items = load_video_to_tensor_with_resize(
225
+ args.input_video_path, args.height, args.width
226
+ ).unsqueeze(0)
227
+ elif args.input_image_path:
228
+ media_items = load_image_to_tensor_with_resize(
229
+ args.input_image_path, args.height, args.width
230
+ )
231
+ else:
232
+ media_items = None
233
+
234
+ height = args.height if args.height else media_items.shape[-2]
235
+ width = args.width if args.width else media_items.shape[-1]
236
+ assert height % 32 == 0, f"Height ({height}) should be divisible by 32."
237
+ assert width % 32 == 0, f"Width ({width}) should be divisible by 32."
238
+ assert (
239
+ height,
240
+ width,
241
+ args.num_frames,
242
+ ) in RECOMMENDED_RESOLUTIONS or args.custom_resolution, f"The selected resolution + num frames combination is not supported, results would be suboptimal. Supported (h,w,f) are: {RECOMMENDED_RESOLUTIONS}. Use --custom_resolution to enable working with this resolution."
243
+
244
  # Paths for the separate mode directories
245
  ckpt_dir = Path(args.ckpt_dir)
246
  unet_dir = ckpt_dir / "unet"
 
274
 
275
  pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")
276
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  # Prepare input for the pipeline
278
  sample = {
279
  "prompt": args.prompt,
 
296
  generator=generator,
297
  output_type="pt",
298
  callback_on_step_end=None,
299
+ height=height,
300
+ width=width,
301
  num_frames=args.num_frames,
302
  frame_rate=args.frame_rate,
303
  **sample,
304
  is_video=True,
305
  vae_per_channel_normalize=True,
306
+ conditioning_method=(
307
+ ConditioningMethod.FIRST_FRAME
308
+ if media_items is not None
309
+ else ConditioningMethod.UNCONDITIONAL
310
+ ),
311
+ mixed_precision=not args.bfloat16,
312
  ).images
313
 
314
  # Save output video
 
326
  video_np = (video_np * 255).astype(np.uint8)
327
  fps = args.frame_rate
328
  height, width = video_np.shape[1:3]
329
+ if video_np.shape[0] == 1:
330
+ output_filename = (
331
+ args.output_path
332
+ if args.output_path is not None
333
+ else get_unique_filename(f"image_output_{i}", ".png", ".")
334
+ )
335
+ cv2.imwrite(
336
+ output_filename, video_np[0][..., ::-1]
337
+ ) # Save single frame as image
338
+ else:
339
+ output_filename = (
340
+ args.output_path
341
+ if args.output_path is not None
342
+ else get_unique_filename(f"video_output_{i}", ".mp4", ".")
343
+ )
344
+
345
+ out = cv2.VideoWriter(
346
+ output_filename, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
347
+ )
348
+
349
+ for frame in video_np[..., ::-1]:
350
+ out.write(frame)
351
+ out.release()
352
 
353
 
354
  if __name__ == "__main__":
scripts/to_safetensors.py CHANGED
@@ -100,10 +100,10 @@ def main(
100
 
101
  # Save unet and vae safetensors with the name diffusion_pytorch_model.safetensors
102
  safetensors.torch.save_file(
103
- unet, unet_dir / "diffusion_pytorch_model.safetensors"
104
  )
105
  safetensors.torch.save_file(
106
- vae, vae_dir / "diffusion_pytorch_model.safetensors"
107
  )
108
 
109
  # Save config files for unet, vae, and scheduler
 
100
 
101
  # Save unet and vae safetensors with the name diffusion_pytorch_model.safetensors
102
  safetensors.torch.save_file(
103
+ unet, unet_dir / "unet_diffusion_pytorch_model.safetensors"
104
  )
105
  safetensors.torch.save_file(
106
+ vae, vae_dir / "vae_diffusion_pytorch_model.safetensors"
107
  )
108
 
109
  # Save config files for unet, vae, and scheduler
xora/examples/text_to_video.py DELETED
@@ -1,138 +0,0 @@
1
- import torch
2
- from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
3
- from xora.models.transformers.transformer3d import Transformer3DModel
4
- from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
5
- from xora.schedulers.rf import RectifiedFlowScheduler
6
- from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
7
- from pathlib import Path
8
- from transformers import T5EncoderModel, T5Tokenizer
9
- import safetensors.torch
10
- import json
11
- import argparse
12
-
13
-
14
- def load_vae(vae_dir):
15
- vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
16
- vae_config_path = vae_dir / "config.json"
17
- with open(vae_config_path, "r") as f:
18
- vae_config = json.load(f)
19
- vae = CausalVideoAutoencoder.from_config(vae_config)
20
- vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
21
- vae.load_state_dict(vae_state_dict)
22
- return vae.cuda().to(torch.bfloat16)
23
-
24
-
25
- def load_unet(unet_dir):
26
- unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
27
- unet_config_path = unet_dir / "config.json"
28
- transformer_config = Transformer3DModel.load_config(unet_config_path)
29
- transformer = Transformer3DModel.from_config(transformer_config)
30
- unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
31
- transformer.load_state_dict(unet_state_dict, strict=True)
32
- return transformer.cuda()
33
-
34
-
35
- def load_scheduler(scheduler_dir):
36
- scheduler_config_path = scheduler_dir / "scheduler_config.json"
37
- scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
38
- return RectifiedFlowScheduler.from_config(scheduler_config)
39
-
40
-
41
- def main():
42
- # Parse command line arguments
43
- parser = argparse.ArgumentParser(
44
- description="Load models from separate directories"
45
- )
46
- parser.add_argument(
47
- "--separate_dir",
48
- type=str,
49
- required=True,
50
- help="Path to the directory containing unet, vae, and scheduler subdirectories",
51
- )
52
- parser.add_argument(
53
- "--mixed_precision",
54
- action="store_true",
55
- help="Mixed precision in float32 and bfloat16",
56
- )
57
- parser.add_argument(
58
- "--bfloat16",
59
- action="store_true",
60
- help="Denoise in bfloat16",
61
- )
62
- args = parser.parse_args()
63
-
64
- # Paths for the separate mode directories
65
- separate_dir = Path(args.separate_dir)
66
- unet_dir = separate_dir / "unet"
67
- vae_dir = separate_dir / "vae"
68
- scheduler_dir = separate_dir / "scheduler"
69
-
70
- # Load models
71
- vae = load_vae(vae_dir)
72
- unet = load_unet(unet_dir)
73
- scheduler = load_scheduler(scheduler_dir)
74
-
75
- # Patchifier (remains the same)
76
- patchifier = SymmetricPatchifier(patch_size=1)
77
-
78
- text_encoder = T5EncoderModel.from_pretrained(
79
- "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
80
- ).to("cuda")
81
- tokenizer = T5Tokenizer.from_pretrained(
82
- "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
83
- )
84
-
85
- if args.bfloat16 and unet.dtype != torch.bfloat16:
86
- unet = unet.to(torch.bfloat16)
87
-
88
- # Use submodels for the pipeline
89
- submodel_dict = {
90
- "transformer": unet, # using unet for transformer
91
- "patchifier": patchifier,
92
- "scheduler": scheduler,
93
- "text_encoder": text_encoder,
94
- "tokenizer": tokenizer,
95
- "vae": vae,
96
- }
97
-
98
- pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")
99
-
100
- # Sample input
101
- num_inference_steps = 20
102
- num_images_per_prompt = 2
103
- guidance_scale = 3
104
- height = 512
105
- width = 768
106
- num_frames = 57
107
- frame_rate = 25
108
- sample = {
109
- "prompt": "A middle-aged man with glasses and a salt-and-pepper beard is driving a car and talking, gesturing with his right hand. "
110
- "The man is wearing a dark blue zip-up jacket and a light blue collared shirt. He is sitting in the driver's seat of a car with a black interior. The car is moving on a road with trees and bushes on either side. The man has a serious expression on his face and is looking straight ahead.",
111
- "prompt_attention_mask": None, # Adjust attention masks as needed
112
- "negative_prompt": "Ugly deformed",
113
- "negative_prompt_attention_mask": None,
114
- }
115
-
116
- # Generate images (video frames)
117
- _ = pipeline(
118
- num_inference_steps=num_inference_steps,
119
- num_images_per_prompt=num_images_per_prompt,
120
- guidance_scale=guidance_scale,
121
- generator=None,
122
- output_type="pt",
123
- callback_on_step_end=None,
124
- height=height,
125
- width=width,
126
- num_frames=num_frames,
127
- frame_rate=frame_rate,
128
- **sample,
129
- is_video=True,
130
- vae_per_channel_normalize=True,
131
- mixed_precision=args.mixed_precision,
132
- ).images
133
-
134
- print("Generated images (video frames).")
135
-
136
-
137
- if __name__ == "__main__":
138
- main()