frankleeeee committed on
Commit 011cc63
1 Parent(s): f58f053

updated to opensora v1.1

Files changed (35)
  1. README copy.md +0 -13
  2. app.py +313 -61
  3. configs/dit/inference/16x256x256.py +2 -2
  4. configs/dit/inference/1x256x256-class.py +2 -2
  5. configs/dit/inference/1x256x256.py +2 -2
  6. configs/dit/train/16x256x256.py +9 -9
  7. configs/dit/train/1x256x256.py +9 -8
  8. configs/latte/inference/16x256x256-class.py +2 -2
  9. configs/latte/inference/16x256x256.py +2 -2
  10. configs/latte/train/16x256x256.py +8 -8
  11. configs/opensora-v1-1/inference/sample-ref.py +62 -0
  12. configs/opensora-v1-1/inference/sample.py +43 -0
  13. configs/opensora-v1-1/train/benchmark.py +101 -0
  14. configs/opensora-v1-1/train/image.py +65 -0
  15. configs/opensora-v1-1/train/stage1.py +77 -0
  16. configs/opensora-v1-1/train/stage2.py +79 -0
  17. configs/opensora-v1-1/train/stage3.py +79 -0
  18. configs/opensora-v1-1/train/video.py +67 -0
  19. configs/opensora/inference/16x256x256.py +11 -6
  20. configs/opensora/inference/16x512x512.py +6 -6
  21. configs/opensora/inference/64x512x512.py +7 -7
  22. configs/opensora/train/16x256x256-mask.py +60 -0
  23. configs/opensora/train/16x256x256-spee.py +60 -0
  24. configs/opensora/train/16x256x256.py +9 -9
  25. configs/opensora/train/16x512x512.py +10 -10
  26. configs/opensora/train/360x512x512.py +14 -8
  27. configs/opensora/train/64x512x512-sp.py +10 -10
  28. configs/opensora/train/64x512x512.py +9 -9
  29. configs/pixart/inference/16x256x256.py +3 -3
  30. configs/pixart/inference/1x1024MS.py +4 -4
  31. configs/pixart/inference/1x256x256.py +3 -3
  32. configs/pixart/inference/1x512x512.py +10 -4
  33. configs/pixart/train/16x256x256.py +10 -10
  34. configs/pixart/train/1x512x512.py +9 -9
  35. configs/pixart/train/64x512x512.py +10 -9
README copy.md DELETED
@@ -1,13 +0,0 @@
- ---
- title: Open Sora
- emoji: 📚
- colorFrom: yellow
- colorTo: indigo
- sdk: gradio
- sdk_version: 4.21.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -11,24 +11,146 @@ import importlib
11
  import os
12
  import subprocess
13
  import sys
 
 
 
 
14
  import spaces
15
- import gradio as gr
16
  import torch
17
 
 
18
 
19
 
20
- MODEL_TYPES = ["v1-16x256x256", "v1-HQ-16x256x256", "v1-HQ-16x512x512"]
21
  CONFIG_MAP = {
22
- "v1-16x256x256": "configs/opensora/inference/16x256x256.py",
23
- "v1-HQ-16x256x256": "configs/opensora/inference/16x256x256.py",
24
- "v1-HQ-16x512x512": "configs/opensora/inference/16x512x512.py",
25
  }
26
  HF_STDIT_MAP = {
27
- "v1-16x256x256": "hpcai-tech/OpenSora-STDiT-v1-16x256x256",
28
- "v1-HQ-16x256x256": "hpcai-tech/OpenSora-STDiT-v1-HQ-16x256x256",
29
- "v1-HQ-16x512x512": "hpcai-tech/OpenSora-STDiT-v1-HQ-16x512x512",
30
  }
32
  def install_dependencies(enable_optimization=False):
33
  """
34
  Install the required dependencies for the demo if they are not already installed.
@@ -70,14 +192,20 @@ def install_dependencies(enable_optimization=False):
70
  shell=True,
71
  )
72
 
 
 
 
 
73
  def read_config(config_path):
74
  """
75
  Read the configuration file.
76
  """
77
  from mmengine.config import Config
 
78
  return Config.fromfile(config_path)
79
 
80
- def build_models(model_type, config):
 
81
  """
82
  Build the models for the given model type and configuration.
83
  """
@@ -87,7 +215,7 @@ def build_models(model_type, config):
87
  vae = build_module(config.vae, MODELS).cuda()
88
 
89
  # build text encoder
90
- text_encoder = build_module(config.text_encoder, MODELS) # T5 must be fp32
91
  text_encoder.t5.model = text_encoder.t5.model.cuda()
92
 
93
  # build stdit
@@ -96,9 +224,8 @@ def build_models(model_type, config):
96
  from transformers import AutoModel
97
 
98
  stdit = AutoModel.from_pretrained(
99
- HF_STDIT_MAP[model_type],
100
- enable_flash_attn=False,
101
- enable_layernorm_kernel=False,
102
  trust_remote_code=True,
103
  ).cuda()
104
 
@@ -111,23 +238,20 @@ def build_models(model_type, config):
111
  text_encoder.y_embedder = stdit.y_embedder
112
 
113
  # move modelst to device
114
- vae = vae.to(torch.float16).eval()
115
  text_encoder.t5.model = text_encoder.t5.model.eval() # t5 must be in fp32
116
- stdit = stdit.to(torch.float16).eval()
117
- return vae, text_encoder, stdit, scheduler
118
-
119
 
120
- def get_latent_size(config, vae):
121
- input_size = (config.num_frames, *config.image_size)
122
- latent_size = vae.get_latent_size(input_size)
123
- return latent_size
124
 
125
 
126
  def parse_args():
127
  parser = argparse.ArgumentParser()
128
  parser.add_argument(
129
  "--model-type",
130
- default="v1-HQ-16x256x256",
131
  choices=MODEL_TYPES,
132
  help=f"The type of model to run for the Gradio App, can only be {MODEL_TYPES}",
133
  )
@@ -135,7 +259,11 @@ def parse_args():
135
  parser.add_argument("--port", default=None, type=int, help="The port to run the Gradio App on.")
136
  parser.add_argument("--host", default=None, type=str, help="The host to run the Gradio App on.")
137
  parser.add_argument("--share", action="store_true", help="Whether to share this gradio demo.")
138
- parser.add_argument("--enable-optimization", action="store_true", help="Whether to enable optimization such as flash attention and fused layernorm")
 
 
 
 
139
  return parser.parse_args()
140
 
141
 
@@ -160,26 +288,130 @@ torch.jit._state.disable()
160
  # set up
161
  install_dependencies(enable_optimization=args.enable_optimization)
162
 
 
 
 
 
 
 
 
 
163
  # build model
164
- vae, text_encoder, stdit, scheduler = build_models(args.model_type, config)
 
165
 
166
  @spaces.GPU(duration=200)
167
- def run_inference(prompt_text):
168
- from opensora.datasets import save_sample
169
-
170
- latent_size = get_latent_size(config, vae)
171
- samples = scheduler.sample(
172
- stdit,
173
- text_encoder,
174
- z_size=(vae.out_channels, *latent_size),
175
- prompts=[prompt_text],
176
- device="cuda",
177
- )
178
-
179
- samples = vae.decode(samples.to(torch.float16))
180
- filename = f"{args.output}/sample"
181
- saved_path = save_sample(samples[0], fps=config.fps, save_path=filename)
182
- return saved_path
183
 
184
  def main():
185
  # create demo
@@ -208,32 +440,52 @@ def main():
208
 
209
  with gr.Row():
210
  with gr.Column():
211
- prompt_text = gr.Textbox(show_label=False, placeholder="Describe your video here", lines=4)
212
- submit_button = gr.Button("Generate video")
213
 
 
 
 
 
214
  with gr.Column():
215
- output_video = gr.Video()
216
-
217
- submit_button.click(fn=run_inference, inputs=[prompt_text], outputs=output_video)
218
-
219
- gr.Examples(
220
- examples=[
221
- [
222
- "The video captures the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty.",
223
- ],
224
- ],
225
- fn=run_inference,
226
- inputs=[
227
- prompt_text,
228
- ],
229
- outputs=[output_video],
230
- cache_examples=True,
231
- )
232
 
233
  # launch
234
  demo.launch(server_port=args.port, server_name=args.host, share=args.share)
235
 
236
 
237
- if __name__ == '__main__':
238
  main()
239
-
 
11
  import os
12
  import subprocess
13
  import sys
14
+ import re
15
+ import json
16
+ import math
17
+
18
  import spaces
 
19
  import torch
20
 
21
+ import gradio as gr
22
 
23
 
24
+ MODEL_TYPES = ["v1.1"]
25
  CONFIG_MAP = {
26
+ "v1.1-stage2": "configs/opensora-v1-1/inference/sample-ref.py",
27
+ "v1.1-stage3": "configs/opensora-v1-1/inference/sample-ref.py",
 
28
  }
29
  HF_STDIT_MAP = {
30
+ "v1.1-stage2": "hpcai-tech/OpenSora-STDiT-v2-stage2",
31
+ "v1.1-stage3": "hpcai-tech/OpenSora-STDiT-v2-stage3",
 
32
  }
33
+ RESOLUTION_MAP = {
34
+ "360p": (360, 480),
35
+ "480p": (480, 858),
36
+ "720p": (720, 1280),
37
+ "1080p": (1080, 1920)
38
+ }
39
+
40
 
41
+ # ============================
42
+ # Utils
43
+ # ============================
44
+ def collect_references_batch(reference_paths, vae, image_size):
45
+ from opensora.datasets.utils import read_from_path
46
+
47
+ refs_x = []
48
+ for reference_path in reference_paths:
49
+ if reference_path is None:
50
+ refs_x.append([])
51
+ continue
52
+ ref_path = reference_path.split(";")
53
+ ref = []
54
+ for r_path in ref_path:
55
+ r = read_from_path(r_path, image_size, transform_name="resize_crop")
56
+ r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
57
+ r_x = r_x.squeeze(0)
58
+ ref.append(r_x)
59
+ refs_x.append(ref)
60
+ # refs_x: [batch, ref_num, C, T, H, W]
61
+ return refs_x
62
+
63
+
64
+ def process_mask_strategy(mask_strategy):
65
+ mask_batch = []
66
+ mask_strategy = mask_strategy.split(";")
67
+ for mask in mask_strategy:
68
+ mask_group = mask.split(",")
69
+ assert len(mask_group) >= 1 and len(mask_group) <= 6, f"Invalid mask strategy: {mask}"
70
+ if len(mask_group) == 1:
71
+ mask_group.extend(["0", "0", "0", "1", "0"])
72
+ elif len(mask_group) == 2:
73
+ mask_group.extend(["0", "0", "1", "0"])
74
+ elif len(mask_group) == 3:
75
+ mask_group.extend(["0", "1", "0"])
76
+ elif len(mask_group) == 4:
77
+ mask_group.extend(["1", "0"])
78
+ elif len(mask_group) == 5:
79
+ mask_group.append("0")
80
+ mask_batch.append(mask_group)
81
+ return mask_batch
82
+
83
+
84
+ def apply_mask_strategy(z, refs_x, mask_strategys, loop_i):
85
+ masks = []
86
+ for i, mask_strategy in enumerate(mask_strategys):
87
+ mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
88
+ if mask_strategy is None:
89
+ masks.append(mask)
90
+ continue
91
+ mask_strategy = process_mask_strategy(mask_strategy)
92
+ for mst in mask_strategy:
93
+ loop_id, m_id, m_ref_start, m_target_start, m_length, edit_ratio = mst
94
+ loop_id = int(loop_id)
95
+ if loop_id != loop_i:
96
+ continue
97
+ m_id = int(m_id)
98
+ m_ref_start = int(m_ref_start)
99
+ m_length = int(m_length)
100
+ m_target_start = int(m_target_start)
101
+ edit_ratio = float(edit_ratio)
102
+ ref = refs_x[i][m_id] # [C, T, H, W]
103
+ if m_ref_start < 0:
104
+ m_ref_start = ref.shape[1] + m_ref_start
105
+ if m_target_start < 0:
106
+ # z: [B, C, T, H, W]
107
+ m_target_start = z.shape[2] + m_target_start
108
+ z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
109
+ mask[m_target_start : m_target_start + m_length] = edit_ratio
110
+ masks.append(mask)
111
+ masks = torch.stack(masks)
112
+ return masks
113
+
114
+
115
+ def process_prompts(prompts, num_loop):
116
+ from opensora.models.text_encoder.t5 import text_preprocessing
117
+
118
+ ret_prompts = []
119
+ for prompt in prompts:
120
+ if prompt.startswith("|0|"):
121
+ prompt_list = prompt.split("|")[1:]
122
+ text_list = []
123
+ for i in range(0, len(prompt_list), 2):
124
+ start_loop = int(prompt_list[i])
125
+ text = prompt_list[i + 1]
126
+ text = text_preprocessing(text)
127
+ end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop
128
+ text_list.extend([text] * (end_loop - start_loop))
129
+ assert len(text_list) == num_loop, f"Prompt loop mismatch: {len(text_list)} != {num_loop}"
130
+ ret_prompts.append(text_list)
131
+ else:
132
+ prompt = text_preprocessing(prompt)
133
+ ret_prompts.append([prompt] * num_loop)
134
+ return ret_prompts
135
+
136
+
137
+ def extract_json_from_prompts(prompts):
138
+ additional_infos = []
139
+ ret_prompts = []
140
+ for prompt in prompts:
141
+ parts = re.split(r"(?=[{\[])", prompt)
142
+ assert len(parts) <= 2, f"Invalid prompt: {prompt}"
143
+ ret_prompts.append(parts[0])
144
+ if len(parts) == 1:
145
+ additional_infos.append({})
146
+ else:
147
+ additional_infos.append(json.loads(parts[1]))
148
+ return ret_prompts, additional_infos
149
+
150
+
151
+ # ============================
152
+ # Runtime Environment
153
+ # ============================
154
  def install_dependencies(enable_optimization=False):
155
  """
156
  Install the required dependencies for the demo if they are not already installed.
 
192
  shell=True,
193
  )
194
 
195
+
196
+ # ============================
197
+ # Model-related
198
+ # ============================
199
  def read_config(config_path):
200
  """
201
  Read the configuration file.
202
  """
203
  from mmengine.config import Config
204
+
205
  return Config.fromfile(config_path)
206
 
207
+
208
+ def build_models(model_type, config, enable_optimization=False):
209
  """
210
  Build the models for the given model type and configuration.
211
  """
 
215
  vae = build_module(config.vae, MODELS).cuda()
216
 
217
  # build text encoder
218
+ text_encoder = build_module(config.text_encoder, MODELS) # T5 must be fp32
219
  text_encoder.t5.model = text_encoder.t5.model.cuda()
220
 
221
  # build stdit
 
224
  from transformers import AutoModel
225
 
226
  stdit = AutoModel.from_pretrained(
227
+ HF_STDIT_MAP[model_type],
228
+ enable_flash_attn=enable_optimization,
 
229
  trust_remote_code=True,
230
  ).cuda()
231
 
 
238
  text_encoder.y_embedder = stdit.y_embedder
239
 
240
  # move modelst to device
241
+ vae = vae.to(torch.bfloat16).eval()
242
  text_encoder.t5.model = text_encoder.t5.model.eval() # t5 must be in fp32
243
+ stdit = stdit.to(torch.bfloat16).eval()
 
 
244
 
245
+ # clear cuda
246
+ torch.cuda.empty_cache()
247
+ return vae, text_encoder, stdit, scheduler
 
248
 
249
 
250
  def parse_args():
251
  parser = argparse.ArgumentParser()
252
  parser.add_argument(
253
  "--model-type",
254
+ default="v1.1-stage3",
255
  choices=MODEL_TYPES,
256
  help=f"The type of model to run for the Gradio App, can only be {MODEL_TYPES}",
257
  )
 
259
  parser.add_argument("--port", default=None, type=int, help="The port to run the Gradio App on.")
260
  parser.add_argument("--host", default=None, type=str, help="The host to run the Gradio App on.")
261
  parser.add_argument("--share", action="store_true", help="Whether to share this gradio demo.")
262
+ parser.add_argument(
263
+ "--enable-optimization",
264
+ action="store_true",
265
+ help="Whether to enable optimization such as flash attention and fused layernorm",
266
+ )
267
  return parser.parse_args()
268
 
269
 
 
288
  # set up
289
  install_dependencies(enable_optimization=args.enable_optimization)
290
 
291
+ # import after installation
292
+ from opensora.datasets import IMG_FPS, save_sample
293
+ from opensora.utils.misc import to_torch_dtype
294
+
295
+ # some global variables
296
+ dtype = to_torch_dtype(config.dtype)
297
+ device = torch.device("cuda")
298
+
299
  # build model
300
+ vae, text_encoder, stdit, scheduler = build_models(args.model_type, config, enable_optimization=args.enable_optimization)
301
+
302
 
303
  @spaces.GPU(duration=200)
304
+ def run_inference(mode, prompt_text, resolution, length, reference_image):
305
+ with torch.inference_mode():
306
+ # ======================
307
+ # 1. Preparation
308
+ # ======================
309
+ # parse the inputs
310
+ resolution = RESOLUTION_MAP[resolution]
311
+
312
+ # compute number of loops
313
+ num_seconds = int(length.rstrip('s'))
314
+ total_number_of_frames = num_seconds * config.fps / config.frame_interval
315
+ num_loop = math.ceil(total_number_of_frames / config.num_frames)
316
+
317
+ # prepare model args
318
+ model_args = dict()
319
+ height = torch.tensor([resolution[0]], device=device, dtype=dtype)
320
+ width = torch.tensor([resolution[1]], device=device, dtype=dtype)
321
+ num_frames = torch.tensor([config.num_frames], device=device, dtype=dtype)
322
+ ar = torch.tensor([resolution[0] / resolution[1]], device=device, dtype=dtype)
323
+ if config.num_frames == 1:
324
+ config.fps = IMG_FPS
325
+ fps = torch.tensor([config.fps], device=device, dtype=dtype)
326
+ model_args["height"] = height
327
+ model_args["width"] = width
328
+ model_args["num_frames"] = num_frames
329
+ model_args["ar"] = ar
330
+ model_args["fps"] = fps
331
+
332
+ # compute latent size
333
+ input_size = (config.num_frames, *resolution)
334
+ latent_size = vae.get_latent_size(input_size)
335
+
336
+ # process prompt
337
+ prompt_raw = [prompt_text]
338
+ prompt_raw, _ = extract_json_from_prompts(prompt_raw)
339
+ prompt_loops = process_prompts(prompt_raw, num_loop)
340
+ video_clips = []
341
+
342
+ # prepare mask strategy
343
+ if mode == "Text2Video":
344
+ mask_strategy = [None]
345
+ elif mode == "Image2Video":
346
+ mask_strategy = ['0']
347
+ else:
348
+ raise ValueError(f"Invalid mode: {mode}")
349
+
350
+ # =========================
351
+ # 2. Load reference images
352
+ # =========================
353
+ if mode == "Text2Video":
354
+ refs_x = collect_references_batch([None], vae, resolution)
355
+ elif mode == "Image2Video":
356
+ # save image to disk
357
+ from PIL import Image
358
+ im = Image.fromarray(reference_image)
359
+ im.save("test.jpg")
360
+ refs_x = collect_references_batch(["test.jpg"], vae, resolution)
361
+ else:
362
+ raise ValueError(f"Invalid mode: {mode}")
363
+
364
+ # 4.3. long video generation
365
+ for loop_i in range(num_loop):
366
+ # 4.4 sample in hidden space
367
+ batch_prompts = [prompt[loop_i] for prompt in prompt_loops]
368
+ z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
369
+
370
+ # 4.5. apply mask strategy
371
+ masks = None
372
+
373
+ # if cfg.reference_path is not None:
374
+ if loop_i > 0:
375
+ ref_x = vae.encode(video_clips[-1])
376
+ for j, refs in enumerate(refs_x):
377
+ if refs is None:
378
+ refs_x[j] = [ref_x[j]]
379
+ else:
380
+ refs.append(ref_x[j])
381
+ if mask_strategy[j] is None:
382
+ mask_strategy[j] = ""
383
+ else:
384
+ mask_strategy[j] += ";"
385
+ mask_strategy[
386
+ j
387
+ ] += f"{loop_i},{len(refs)-1},-{config.condition_frame_length},0,{config.condition_frame_length}"
388
+
389
+ masks = apply_mask_strategy(z, refs_x, mask_strategy, loop_i)
390
+
391
+ # 4.6. diffusion sampling
392
+ samples = scheduler.sample(
393
+ stdit,
394
+ text_encoder,
395
+ z=z,
396
+ prompts=batch_prompts,
397
+ device=device,
398
+ additional_args=model_args,
399
+ mask=masks, # scheduler must support mask
400
+ )
401
+ samples = vae.decode(samples.to(dtype))
402
+ video_clips.append(samples)
403
+
404
+ # 4.7. save video
405
+ if loop_i == num_loop - 1:
406
+ video_clips_list = [
407
+ video_clips[0][0]] + [video_clips[i][0][:, config.condition_frame_length :]
408
+ for i in range(1, num_loop)
409
+ ]
410
+ video = torch.cat(video_clips_list, dim=1)
411
+ save_path = f"{args.output}/sample"
412
+ saved_path = save_sample(video, fps=config.fps // config.frame_interval, save_path=save_path, force_video=True)
413
+ return saved_path
414
+
415
 
416
  def main():
417
  # create demo
 
440
 
441
  with gr.Row():
442
  with gr.Column():
443
+ mode = gr.Radio(
444
+ choices=["Text2Video", "Image2Video"],
445
+ value="Text2Video",
446
+ label="Usage",
447
+ info="Choose your usage scenario",
448
+ )
449
+ prompt_text = gr.Textbox(
450
+ label="Prompt",
451
+ placeholder="Describe your video here",
452
+ lines=4,
453
+ )
454
+ resolution = gr.Radio(
455
+ choices=["360p", "480p", "720p", "1080p"],
456
+ value="360p",
457
+ label="Resolution",
458
+ )
459
+ length = gr.Radio(
460
+ choices=["2s", "4s", "8s"],
461
+ value="2s",
462
+ label="Video Length",
463
+ info="8s may fail as Hugging Face ZeroGPU has the limitation of max 200 seconds inference time."
464
+ )
465
 
466
+ reference_image = gr.Image(
467
+ label="Reference Image (only used for Image2Video)",
468
+ )
469
+
470
  with gr.Column():
471
+ output_video = gr.Video(
472
+ label="Output Video",
473
+ height="100%"
474
+ )
475
+
476
+ with gr.Row():
477
+ submit_button = gr.Button("Generate video")
478
+
479
+
480
+ submit_button.click(
481
+ fn=run_inference,
482
+ inputs=[mode, prompt_text, resolution, length, reference_image],
483
+ outputs=output_video
484
+ )
 
 
 
485
 
486
  # launch
487
  demo.launch(server_port=args.port, server_name=args.host, share=args.share)
488
 
489
 
490
+ if __name__ == "__main__":
491
  main()
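Note: the reworked `run_inference` above derives the number of generation loops from the requested clip length and stitches the loops together, dropping the first `condition_frame_length` frames of every loop after the first. A minimal sketch of that arithmetic, assuming the defaults from `configs/opensora-v1-1/inference/sample-ref.py` (`fps=24`, `frame_interval=3`, `num_frames=16`, `condition_frame_length=4`):

```python
import math

# Assumed defaults taken from sample-ref.py; adjust to match the loaded config.
fps, frame_interval, num_frames = 24, 3, 16
condition_frame_length = 4  # frames re-used to condition the next loop

def loops_for(length: str) -> int:
    num_seconds = int(length.rstrip("s"))
    total_frames = num_seconds * fps / frame_interval  # frames after temporal subsampling
    return math.ceil(total_frames / num_frames)

for length in ("2s", "4s", "8s"):
    n = loops_for(length)
    # loop 0 is kept whole; each later loop drops its first condition_frame_length frames
    stitched = num_frames + (n - 1) * (num_frames - condition_frame_length)
    print(length, "->", n, "loop(s),", stitched, "frames")
```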
 
configs/dit/inference/16x256x256.py CHANGED
@@ -22,10 +22,10 @@ scheduler = dict(
22
  num_sampling_steps=20,
23
  cfg_scale=4.0,
24
  )
25
- dtype = "fp16"
26
 
27
  # Others
28
  batch_size = 2
29
  seed = 42
30
  prompt_path = "./assets/texts/ucf101_labels.txt"
31
- save_dir = "./outputs/samples/"
 
22
  num_sampling_steps=20,
23
  cfg_scale=4.0,
24
  )
25
+ dtype = "bf16"
26
 
27
  # Others
28
  batch_size = 2
29
  seed = 42
30
  prompt_path = "./assets/texts/ucf101_labels.txt"
31
+ save_dir = "./samples/samples/"
configs/dit/inference/1x256x256-class.py CHANGED
@@ -22,10 +22,10 @@ scheduler = dict(
22
  num_sampling_steps=20,
23
  cfg_scale=4.0,
24
  )
25
- dtype = "fp16"
26
 
27
  # Others
28
  batch_size = 2
29
  seed = 42
30
  prompt_path = "./assets/texts/imagenet_id.txt"
31
- save_dir = "./outputs/samples/"
 
22
  num_sampling_steps=20,
23
  cfg_scale=4.0,
24
  )
25
+ dtype = "bf16"
26
 
27
  # Others
28
  batch_size = 2
29
  seed = 42
30
  prompt_path = "./assets/texts/imagenet_id.txt"
31
+ save_dir = "./samples/samples/"
configs/dit/inference/1x256x256.py CHANGED
@@ -23,10 +23,10 @@ scheduler = dict(
23
  num_sampling_steps=20,
24
  cfg_scale=4.0,
25
  )
26
- dtype = "fp16"
27
 
28
  # Others
29
  batch_size = 2
30
  seed = 42
31
  prompt_path = "./assets/texts/imagenet_labels.txt"
32
- save_dir = "./outputs/samples/"
 
23
  num_sampling_steps=20,
24
  cfg_scale=4.0,
25
  )
26
+ dtype = "bf16"
27
 
28
  # Others
29
  batch_size = 2
30
  seed = 42
31
  prompt_path = "./assets/texts/imagenet_labels.txt"
32
+ save_dir = "./samples/samples/"
configs/dit/train/16x256x256.py CHANGED
@@ -1,16 +1,16 @@
1
- num_frames = 16
2
- frame_interval = 3
3
- image_size = (256, 256)
4
-
5
  # Define dataset
6
- root = None
7
- data_path = "CSV_PATH"
8
- use_image_transform = False
9
- num_workers = 4
 
 
 
10
 
11
  # Define acceleration
 
12
  dtype = "bf16"
13
- grad_checkpoint = False
14
  plugin = "zero2"
15
  sp_size = 1
16
 
 
 
 
 
 
1
  # Define dataset
2
+ dataset = dict(
3
+ type="VideoTextDataset",
4
+ data_path=None,
5
+ num_frames=16,
6
+ frame_interval=3,
7
+ image_size=(256, 256),
8
+ )
9
 
10
  # Define acceleration
11
+ num_workers = 4
12
  dtype = "bf16"
13
+ grad_checkpoint = True
14
  plugin = "zero2"
15
  sp_size = 1
16
 
configs/dit/train/1x256x256.py CHANGED
@@ -1,14 +1,15 @@
1
- num_frames = 1
2
- frame_interval = 1
3
- image_size = (256, 256)
4
-
5
  # Define dataset
6
- root = None
7
- data_path = "CSV_PATH"
8
- use_image_transform = True
9
- num_workers = 4
 
 
 
 
10
 
11
  # Define acceleration
 
12
  dtype = "bf16"
13
  grad_checkpoint = False
14
  plugin = "zero2"
 
 
 
 
 
1
  # Define dataset
2
+ dataset = dict(
3
+ type="VideoTextDataset",
4
+ data_path=None,
5
+ num_frames=1,
6
+ frame_interval=1,
7
+ image_size=(256, 256),
8
+ transform_name="center",
9
+ )
10
 
11
  # Define acceleration
12
+ num_workers = 4
13
  dtype = "bf16"
14
  grad_checkpoint = False
15
  plugin = "zero2"
configs/latte/inference/16x256x256-class.py CHANGED
@@ -21,10 +21,10 @@ scheduler = dict(
21
  num_sampling_steps=20,
22
  cfg_scale=4.0,
23
  )
24
- dtype = "fp16"
25
 
26
  # Others
27
  batch_size = 2
28
  seed = 42
29
  prompt_path = "./assets/texts/ucf101_id.txt"
30
- save_dir = "./outputs/samples/"
 
21
  num_sampling_steps=20,
22
  cfg_scale=4.0,
23
  )
24
+ dtype = "bf16"
25
 
26
  # Others
27
  batch_size = 2
28
  seed = 42
29
  prompt_path = "./assets/texts/ucf101_id.txt"
30
+ save_dir = "./samples/samples/"
configs/latte/inference/16x256x256.py CHANGED
@@ -22,10 +22,10 @@ scheduler = dict(
22
  num_sampling_steps=20,
23
  cfg_scale=4.0,
24
  )
25
- dtype = "fp16"
26
 
27
  # Others
28
  batch_size = 2
29
  seed = 42
30
  prompt_path = "./assets/texts/ucf101_labels.txt"
31
- save_dir = "./outputs/samples/"
 
22
  num_sampling_steps=20,
23
  cfg_scale=4.0,
24
  )
25
+ dtype = "bf16"
26
 
27
  # Others
28
  batch_size = 2
29
  seed = 42
30
  prompt_path = "./assets/texts/ucf101_labels.txt"
31
+ save_dir = "./samples/samples/"
configs/latte/train/16x256x256.py CHANGED
@@ -1,14 +1,14 @@
1
- num_frames = 16
2
- frame_interval = 3
3
- image_size = (256, 256)
4
-
5
  # Define dataset
6
- root = None
7
- data_path = "CSV_PATH"
8
- use_image_transform = False
9
- num_workers = 4
 
 
 
10
 
11
  # Define acceleration
 
12
  dtype = "bf16"
13
  grad_checkpoint = True
14
  plugin = "zero2"
 
 
 
 
 
1
  # Define dataset
2
+ dataset = dict(
3
+ type="VideoTextDataset",
4
+ data_path=None,
5
+ num_frames=16,
6
+ frame_interval=3,
7
+ image_size=(256, 256),
8
+ )
9
 
10
  # Define acceleration
11
+ num_workers = 4
12
  dtype = "bf16"
13
  grad_checkpoint = True
14
  plugin = "zero2"
configs/opensora-v1-1/inference/sample-ref.py ADDED
@@ -0,0 +1,62 @@
1
+ num_frames = 16
2
+ frame_interval = 3
3
+ fps = 24
4
+ image_size = (240, 426)
5
+ multi_resolution = "STDiT2"
6
+
7
+ # Condition
8
+ prompt_path = None
9
+ prompt = [
10
+ "A car driving on the ocean.",
11
+ 'Drone view of waves crashing against the rugged cliffs along Big Sur\'s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff\'s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff\'s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.{"reference_path": "assets/images/condition/cliff.png", "mask_strategy": "0"}',
12
+ "In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave.",
13
+ ]
14
+
15
+ loop = 2
16
+ condition_frame_length = 4
17
+ reference_path = [
18
+ "https://cdn.openai.com/tmp/s/interp/d0.mp4",
19
+ None,
20
+ "assets/images/condition/wave.png",
21
+ ]
22
+ # valid when reference_path is not None
23
+ # (loop id, ref id, ref start, length, target start)
24
+ mask_strategy = [
25
+ "0,0,0,0,8,0.3",
26
+ None,
27
+ "0",
28
+ ]
29
+
30
+ # Define model
31
+ model = dict(
32
+ type="STDiT2-XL/2",
33
+ from_pretrained=None,
34
+ input_sq_size=512,
35
+ qk_norm=True,
36
+ enable_flashattn=True,
37
+ enable_layernorm_kernel=True,
38
+ )
39
+ vae = dict(
40
+ type="VideoAutoencoderKL",
41
+ from_pretrained="stabilityai/sd-vae-ft-ema",
42
+ cache_dir=None, # "/mnt/hdd/cached_models",
43
+ micro_batch_size=4,
44
+ )
45
+ text_encoder = dict(
46
+ type="t5",
47
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
48
+ cache_dir=None, # "/mnt/hdd/cached_models",
49
+ model_max_length=200,
50
+ )
51
+ scheduler = dict(
52
+ type="iddpm",
53
+ num_sampling_steps=100,
54
+ cfg_scale=7.0,
55
+ cfg_channel=3, # or None
56
+ )
57
+ dtype = "bf16"
58
+
59
+ # Others
60
+ batch_size = 1
61
+ seed = 42
62
+ save_dir = "./samples/samples/"
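Note: the `mask_strategy` strings above are compact; `process_mask_strategy` in the updated `app.py` pads each `;`-separated group out to six fields, which `apply_mask_strategy` then reads as (loop id, ref id, ref start, target start, length, edit ratio). A minimal sketch of that expansion — the helper name `expand_mask_strategy` is illustrative, not part of the repo:

```python
def expand_mask_strategy(mask_strategy: str):
    # Default values mirror the padding in process_mask_strategy.
    defaults = ["0", "0", "0", "0", "1", "0"]
    groups = []
    for group in mask_strategy.split(";"):
        fields = group.split(",")
        assert 1 <= len(fields) <= 6, f"Invalid mask strategy: {group}"
        groups.append(fields + defaults[len(fields):])
    return groups

print(expand_mask_strategy("0"))              # [['0', '0', '0', '0', '1', '0']]
print(expand_mask_strategy("0,0,0,0,8,0.3"))  # [['0', '0', '0', '0', '8', '0.3']]
```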
configs/opensora-v1-1/inference/sample.py ADDED
@@ -0,0 +1,43 @@
1
+ num_frames = 16
2
+ frame_interval = 3
3
+ fps = 24
4
+ image_size = (240, 426)
5
+ multi_resolution = "STDiT2"
6
+
7
+ # Define model
8
+ model = dict(
9
+ type="STDiT2-XL/2",
10
+ from_pretrained=None,
11
+ input_sq_size=512,
12
+ qk_norm=True,
13
+ enable_flashattn=True,
14
+ enable_layernorm_kernel=True,
15
+ )
16
+ vae = dict(
17
+ type="VideoAutoencoderKL",
18
+ from_pretrained="stabilityai/sd-vae-ft-ema",
19
+ cache_dir=None, # "/mnt/hdd/cached_models",
20
+ micro_batch_size=4,
21
+ )
22
+ text_encoder = dict(
23
+ type="t5",
24
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
25
+ cache_dir=None, # "/mnt/hdd/cached_models",
26
+ model_max_length=200,
27
+ )
28
+ scheduler = dict(
29
+ type="iddpm",
30
+ num_sampling_steps=100,
31
+ cfg_scale=7.0,
32
+ cfg_channel=3, # or None
33
+ )
34
+ dtype = "bf16"
35
+
36
+ # Condition
37
+ prompt_path = "./assets/texts/t2v_samples.txt"
38
+ prompt = None # prompt has higher priority than prompt_path
39
+
40
+ # Others
41
+ batch_size = 1
42
+ seed = 42
43
+ save_dir = "./samples/samples/"
configs/opensora-v1-1/train/benchmark.py ADDED
@@ -0,0 +1,101 @@
1
+ # this file is only for batch size search and is not used for training
2
+
3
+ # Define dataset
4
+ dataset = dict(
5
+ type="VariableVideoTextDataset",
6
+ data_path=None,
7
+ num_frames=None,
8
+ frame_interval=3,
9
+ image_size=(None, None),
10
+ transform_name="resize_crop",
11
+ )
12
+
13
+ # bucket config format:
14
+ # 1. { resolution: {num_frames: (prob, batch_size)} }, in this case batch_size is ignored when searching
15
+ # 2. { resolution: {num_frames: (prob, (max_batch_size, ))} }, batch_size is searched in the range [batch_size_start, max_batch_size), batch_size_start is configured via CLI
16
+ # 3. { resolution: {num_frames: (prob, (min_batch_size, max_batch_size))} }, batch_size is searched in the range [min_batch_size, max_batch_size)
17
+ # 4. { resolution: {num_frames: (prob, (min_batch_size, max_batch_size, step_size))} }, batch_size is searched in the range [min_batch_size, max_batch_size) with step_size (grid search)
18
+ # 5. { resolution: {num_frames: (0.0, None)} }, this bucket will not be used
19
+
20
+ bucket_config = {
21
+ # == manual search ==
22
+ # "240p": {128: (1.0, 2)}, # 4.28s/it
23
+ # "240p": {64: (1.0, 4)},
24
+ # "240p": {32: (1.0, 8)}, # 4.6s/it
25
+ # "240p": {16: (1.0, 16)}, # 4.6s/it
26
+ # "480p": {16: (1.0, 4)}, # 4.6s/it
27
+ # "720p": {16: (1.0, 2)}, # 5.89s/it
28
+ # "256": {1: (1.0, 256)}, # 4.5s/it
29
+ # "512": {1: (1.0, 96)}, # 4.7s/it
30
+ # "512": {1: (1.0, 128)}, # 6.3s/it
31
+ # "480p": {1: (1.0, 50)}, # 4.0s/it
32
+ # "1024": {1: (1.0, 32)}, # 6.8s/it
33
+ # "1024": {1: (1.0, 20)}, # 4.3s/it
34
+ # "1080p": {1: (1.0, 16)}, # 8.6s/it
35
+ # "1080p": {1: (1.0, 8)}, # 4.4s/it
36
+ # == stage 2 ==
37
+ # "240p": {
38
+ # 16: (1.0, (2, 32)),
39
+ # 32: (1.0, (2, 16)),
40
+ # 64: (1.0, (2, 8)),
41
+ # 128: (1.0, (2, 6)),
42
+ # },
43
+ # "256": {1: (1.0, (128, 300))},
44
+ # "512": {1: (0.5, (64, 128))},
45
+ # "480p": {1: (0.4, (32, 128)), 16: (0.4, (2, 32)), 32: (0.0, None)},
46
+ # "720p": {16: (0.1, (2, 16)), 32: (0.0, None)}, # No examples now
47
+ # "1024": {1: (0.3, (8, 64))},
48
+ # "1080p": {1: (0.3, (2, 32))},
49
+ # == stage 3 ==
50
+ "720p": {1: (20, 40), 32: (0.5, (2, 4)), 64: (0.5, (1, 1))},
51
+ }
52
+
53
+
54
+ # Define acceleration
55
+ num_workers = 4
56
+ num_bucket_build_workers = 16
57
+ dtype = "bf16"
58
+ grad_checkpoint = True
59
+ plugin = "zero2"
60
+ sp_size = 1
61
+
62
+ # Define model
63
+ model = dict(
64
+ type="STDiT2-XL/2",
65
+ from_pretrained=None,
66
+ input_sq_size=512, # pretrained model is trained on 512x512
67
+ qk_norm=True,
68
+ enable_flashattn=True,
69
+ enable_layernorm_kernel=True,
70
+ )
71
+ vae = dict(
72
+ type="VideoAutoencoderKL",
73
+ from_pretrained="stabilityai/sd-vae-ft-ema",
74
+ micro_batch_size=4,
75
+ local_files_only=True,
76
+ )
77
+ text_encoder = dict(
78
+ type="t5",
79
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
80
+ model_max_length=200,
81
+ shardformer=True,
82
+ local_files_only=True,
83
+ )
84
+ scheduler = dict(
85
+ type="iddpm",
86
+ timestep_respacing="",
87
+ )
88
+
89
+ # Others
90
+ seed = 42
91
+ outputs = "outputs"
92
+ wandb = False
93
+
94
+ epochs = 1000
95
+ log_every = 10
96
+ ckpt_every = 1000
97
+ load = None
98
+
99
+ batch_size = None
100
+ lr = 2e-5
101
+ grad_clip = 1.0
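Note: the bucket-config comments above describe five accepted forms for the batch-size spec. The sketch below shows one way those forms could be interpreted as candidate batch sizes during the search; `candidate_batch_sizes` and `batch_size_start` are illustrative stand-ins, not functions from the repo:

```python
def candidate_batch_sizes(spec, batch_size_start=1):
    prob, bs = spec
    if bs is None:            # form 5: (0.0, None), bucket not used
        return []
    if isinstance(bs, int):   # form 1: fixed batch size, nothing to search
        return [bs]
    if len(bs) == 1:          # form 2: search [batch_size_start, max_batch_size)
        return list(range(batch_size_start, bs[0]))
    return list(range(*bs))   # forms 3/4: [min, max) with an optional grid-search step

print(candidate_batch_sizes((1.0, (2, 32))))     # [2, 3, ..., 31]
print(candidate_batch_sizes((0.5, (2, 16, 4))))  # [2, 6, 10, 14]
print(candidate_batch_sizes((0.0, None)))        # []
```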
configs/opensora-v1-1/train/image.py ADDED
@@ -0,0 +1,65 @@
1
+ # Define dataset
2
+ dataset = dict(
3
+ type="VariableVideoTextDataset",
4
+ data_path=None,
5
+ num_frames=None,
6
+ frame_interval=3,
7
+ image_size=(None, None),
8
+ transform_name="resize_crop",
9
+ )
10
+ bucket_config = { # 6s/it
11
+ "256": {1: (1.0, 256)},
12
+ "512": {1: (1.0, 80)},
13
+ "480p": {1: (1.0, 52)},
14
+ "1024": {1: (1.0, 20)},
15
+ "1080p": {1: (1.0, 8)},
16
+ }
17
+
18
+ # Define acceleration
19
+ num_workers = 4
20
+ num_bucket_build_workers = 16
21
+ dtype = "bf16"
22
+ grad_checkpoint = True
23
+ plugin = "zero2"
24
+ sp_size = 1
25
+
26
+ # Define model
27
+ model = dict(
28
+ type="STDiT2-XL/2",
29
+ from_pretrained=None,
30
+ input_sq_size=512, # pretrained model is trained on 512x512
31
+ qk_norm=True,
32
+ enable_flashattn=True,
33
+ enable_layernorm_kernel=True,
34
+ )
35
+ vae = dict(
36
+ type="VideoAutoencoderKL",
37
+ from_pretrained="stabilityai/sd-vae-ft-ema",
38
+ micro_batch_size=4,
39
+ local_files_only=True,
40
+ )
41
+ text_encoder = dict(
42
+ type="t5",
43
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
44
+ model_max_length=200,
45
+ shardformer=True,
46
+ local_files_only=True,
47
+ )
48
+ scheduler = dict(
49
+ type="iddpm",
50
+ timestep_respacing="",
51
+ )
52
+
53
+ # Others
54
+ seed = 42
55
+ outputs = "outputs"
56
+ wandb = False
57
+
58
+ epochs = 1000
59
+ log_every = 10
60
+ ckpt_every = 500
61
+ load = None
62
+
63
+ batch_size = 10 # only for logging
64
+ lr = 2e-5
65
+ grad_clip = 1.0
configs/opensora-v1-1/train/stage1.py ADDED
@@ -0,0 +1,77 @@
1
+ # Define dataset
2
+ dataset = dict(
3
+ type="VariableVideoTextDataset",
4
+ data_path=None,
5
+ num_frames=None,
6
+ frame_interval=3,
7
+ image_size=(None, None),
8
+ transform_name="resize_crop",
9
+ )
10
+ # IMG: 1024 (20%) 512 (30%) 256 (50%) drop (50%)
11
+ bucket_config = { # 1s/it
12
+ "144p": {1: (0.5, 48), 16: (1.0, 6), 32: (1.0, 3), 96: (1.0, 1)},
13
+ "256": {1: (0.5, 24), 16: (0.5, 3), 48: (0.5, 1), 64: (0.0, None)},
14
+ "240p": {16: (0.3, 2), 32: (0.3, 1), 64: (0.0, None)},
15
+ "512": {1: (0.4, 12)},
16
+ "1024": {1: (0.3, 3)},
17
+ }
18
+ mask_ratios = {
19
+ "mask_no": 0.75,
20
+ "mask_quarter_random": 0.025,
21
+ "mask_quarter_head": 0.025,
22
+ "mask_quarter_tail": 0.025,
23
+ "mask_quarter_head_tail": 0.05,
24
+ "mask_image_random": 0.025,
25
+ "mask_image_head": 0.025,
26
+ "mask_image_tail": 0.025,
27
+ "mask_image_head_tail": 0.05,
28
+ }
29
+
30
+ # Define acceleration
31
+ num_workers = 8
32
+ num_bucket_build_workers = 16
33
+ dtype = "bf16"
34
+ grad_checkpoint = False
35
+ plugin = "zero2"
36
+ sp_size = 1
37
+
38
+ # Define model
39
+ model = dict(
40
+ type="STDiT2-XL/2",
41
+ from_pretrained=None,
42
+ input_sq_size=512, # pretrained model is trained on 512x512
43
+ qk_norm=True,
44
+ enable_flashattn=True,
45
+ enable_layernorm_kernel=True,
46
+ )
47
+ vae = dict(
48
+ type="VideoAutoencoderKL",
49
+ from_pretrained="stabilityai/sd-vae-ft-ema",
50
+ micro_batch_size=4,
51
+ local_files_only=True,
52
+ )
53
+ text_encoder = dict(
54
+ type="t5",
55
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
56
+ model_max_length=200,
57
+ shardformer=True,
58
+ local_files_only=True,
59
+ )
60
+ scheduler = dict(
61
+ type="iddpm",
62
+ timestep_respacing="",
63
+ )
64
+
65
+ # Others
66
+ seed = 42
67
+ outputs = "outputs"
68
+ wandb = False
69
+
70
+ epochs = 1000
71
+ log_every = 10
72
+ ckpt_every = 500
73
+ load = None
74
+
75
+ batch_size = None
76
+ lr = 2e-5
77
+ grad_clip = 1.0
configs/opensora-v1-1/train/stage2.py ADDED
@@ -0,0 +1,79 @@
1
+ # Define dataset
2
+ dataset = dict(
3
+ type="VariableVideoTextDataset",
4
+ data_path=None,
5
+ num_frames=None,
6
+ frame_interval=3,
7
+ image_size=(None, None),
8
+ transform_name="resize_crop",
9
+ )
10
+ bucket_config = { # 7s/it
11
+ "144p": {1: (1.0, 48), 16: (1.0, 17), 32: (1.0, 9), 64: (1.0, 4), 128: (1.0, 1)},
12
+ "256": {1: (0.8, 254), 16: (0.5, 17), 32: (0.5, 9), 64: (0.5, 4), 128: (0.5, 1)},
13
+ "240p": {1: (0.1, 20), 16: (0.9, 17), 32: (0.8, 9), 64: (0.8, 4), 128: (0.8, 2)},
14
+ "512": {1: (0.5, 86), 16: (0.2, 4), 32: (0.2, 2), 64: (0.2, 1), 128: (0.0, None)},
15
+ "480p": {1: (0.4, 54), 16: (0.4, 4), 32: (0.0, None)},
16
+ "720p": {1: (0.1, 20), 16: (0.1, 2), 32: (0.0, None)},
17
+ "1024": {1: (0.3, 20)},
18
+ "1080p": {1: (0.4, 8)},
19
+ }
20
+ mask_ratios = {
21
+ "mask_no": 0.75,
22
+ "mask_quarter_random": 0.025,
23
+ "mask_quarter_head": 0.025,
24
+ "mask_quarter_tail": 0.025,
25
+ "mask_quarter_head_tail": 0.05,
26
+ "mask_image_random": 0.025,
27
+ "mask_image_head": 0.025,
28
+ "mask_image_tail": 0.025,
29
+ "mask_image_head_tail": 0.05,
30
+ }
31
+
32
+ # Define acceleration
33
+ num_workers = 8
34
+ num_bucket_build_workers = 16
35
+ dtype = "bf16"
36
+ grad_checkpoint = True
37
+ plugin = "zero2"
38
+ sp_size = 1
39
+
40
+ # Define model
41
+ model = dict(
42
+ type="STDiT2-XL/2",
43
+ from_pretrained=None,
44
+ input_sq_size=512, # pretrained model is trained on 512x512
45
+ qk_norm=True,
46
+ enable_flashattn=True,
47
+ enable_layernorm_kernel=True,
48
+ )
49
+ vae = dict(
50
+ type="VideoAutoencoderKL",
51
+ from_pretrained="stabilityai/sd-vae-ft-ema",
52
+ micro_batch_size=4,
53
+ local_files_only=True,
54
+ )
55
+ text_encoder = dict(
56
+ type="t5",
57
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
58
+ model_max_length=200,
59
+ shardformer=True,
60
+ local_files_only=True,
61
+ )
62
+ scheduler = dict(
63
+ type="iddpm",
64
+ timestep_respacing="",
65
+ )
66
+
67
+ # Others
68
+ seed = 42
69
+ outputs = "outputs"
70
+ wandb = False
71
+
72
+ epochs = 1000
73
+ log_every = 10
74
+ ckpt_every = 500
75
+ load = None
76
+
77
+ batch_size = None
78
+ lr = 2e-5
79
+ grad_clip = 1.0
configs/opensora-v1-1/train/stage3.py ADDED
@@ -0,0 +1,79 @@
1
+ # Define dataset
2
+ dataset = dict(
3
+ type="VariableVideoTextDataset",
4
+ data_path=None,
5
+ num_frames=None,
6
+ frame_interval=3,
7
+ image_size=(None, None),
8
+ transform_name="resize_crop",
9
+ )
10
+ bucket_config = { # 13s/it
11
+ "144p": {1: (1.0, 200), 16: (1.0, 36), 32: (1.0, 18), 64: (1.0, 9), 128: (1.0, 4)},
12
+ "256": {1: (0.8, 200), 16: (0.5, 22), 32: (0.5, 11), 64: (0.5, 6), 128: (0.8, 4)},
13
+ "240p": {1: (0.8, 200), 16: (0.5, 22), 32: (0.5, 10), 64: (0.5, 6), 128: (0.5, 3)},
14
+ "360p": {1: (0.5, 120), 16: (0.5, 9), 32: (0.5, 4), 64: (0.5, 2), 128: (0.5, 1)},
15
+ "512": {1: (0.5, 120), 16: (0.5, 9), 32: (0.5, 4), 64: (0.5, 2), 128: (0.8, 1)},
16
+ "480p": {1: (0.4, 80), 16: (0.6, 6), 32: (0.6, 3), 64: (0.6, 1), 128: (0.0, None)},
17
+ "720p": {1: (0.4, 40), 16: (0.6, 3), 32: (0.6, 1), 96: (0.0, None)},
18
+ "1024": {1: (0.3, 40)},
19
+ }
20
+ mask_ratios = {
21
+ "mask_no": 0.75,
22
+ "mask_quarter_random": 0.025,
23
+ "mask_quarter_head": 0.025,
24
+ "mask_quarter_tail": 0.025,
25
+ "mask_quarter_head_tail": 0.05,
26
+ "mask_image_random": 0.025,
27
+ "mask_image_head": 0.025,
28
+ "mask_image_tail": 0.025,
29
+ "mask_image_head_tail": 0.05,
30
+ }
31
+
32
+ # Define acceleration
33
+ num_workers = 8
34
+ num_bucket_build_workers = 16
35
+ dtype = "bf16"
36
+ grad_checkpoint = True
37
+ plugin = "zero2"
38
+ sp_size = 1
39
+
40
+ # Define model
41
+ model = dict(
42
+ type="STDiT2-XL/2",
43
+ from_pretrained=None,
44
+ input_sq_size=512, # pretrained model is trained on 512x512
45
+ qk_norm=True,
46
+ enable_flashattn=True,
47
+ enable_layernorm_kernel=True,
48
+ )
49
+ vae = dict(
50
+ type="VideoAutoencoderKL",
51
+ from_pretrained="stabilityai/sd-vae-ft-ema",
52
+ micro_batch_size=4,
53
+ local_files_only=True,
54
+ )
55
+ text_encoder = dict(
56
+ type="t5",
57
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
58
+ model_max_length=200,
59
+ shardformer=True,
60
+ local_files_only=True,
61
+ )
62
+ scheduler = dict(
63
+ type="iddpm",
64
+ timestep_respacing="",
65
+ )
66
+
67
+ # Others
68
+ seed = 42
69
+ outputs = "outputs"
70
+ wandb = False
71
+
72
+ epochs = 1000
73
+ log_every = 10
74
+ ckpt_every = 500
75
+ load = None
76
+
77
+ batch_size = None
78
+ lr = 2e-5
79
+ grad_clip = 1.0
configs/opensora-v1-1/train/video.py ADDED
@@ -0,0 +1,67 @@
1
+ # Define dataset
2
+ dataset = dict(
3
+ type="VariableVideoTextDataset",
4
+ data_path=None,
5
+ num_frames=None,
6
+ frame_interval=3,
7
+ image_size=(None, None),
8
+ transform_name="resize_crop",
9
+ )
10
+ bucket_config = { # 6s/it
11
+ "240p": {16: (1.0, 16), 32: (1.0, 8), 64: (1.0, 4), 128: (1.0, 2)},
12
+ "256": {1: (1.0, 256)},
13
+ "512": {1: (0.5, 80)},
14
+ "480p": {1: (0.4, 52), 16: (0.4, 4), 32: (0.0, None)},
15
+ "720p": {16: (0.1, 2), 32: (0.0, None)}, # No examples now
16
+ "1024": {1: (0.3, 20)},
17
+ "1080p": {1: (0.3, 8)},
18
+ }
19
+
20
+ # Define acceleration
21
+ num_workers = 4
22
+ num_bucket_build_workers = 16
23
+ dtype = "bf16"
24
+ grad_checkpoint = True
25
+ plugin = "zero2"
26
+ sp_size = 1
27
+
28
+ # Define model
29
+ model = dict(
30
+ type="STDiT2-XL/2",
31
+ from_pretrained=None,
32
+ input_sq_size=512, # pretrained model is trained on 512x512
33
+ qk_norm=True,
34
+ enable_flashattn=True,
35
+ enable_layernorm_kernel=True,
36
+ )
37
+ vae = dict(
38
+ type="VideoAutoencoderKL",
39
+ from_pretrained="stabilityai/sd-vae-ft-ema",
40
+ micro_batch_size=4,
41
+ local_files_only=True,
42
+ )
43
+ text_encoder = dict(
44
+ type="t5",
45
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
46
+ model_max_length=200,
47
+ shardformer=True,
48
+ local_files_only=True,
49
+ )
50
+ scheduler = dict(
51
+ type="iddpm",
52
+ timestep_respacing="",
53
+ )
54
+
55
+ # Others
56
+ seed = 42
57
+ outputs = "outputs"
58
+ wandb = False
59
+
60
+ epochs = 1000
61
+ log_every = 10
62
+ ckpt_every = 500
63
+ load = None
64
+
65
+ batch_size = 10 # only for logging
66
+ lr = 2e-5
67
+ grad_clip = 1.0
configs/opensora/inference/16x256x256.py CHANGED
@@ -7,13 +7,14 @@ model = dict(
7
  type="STDiT-XL/2",
8
  space_scale=0.5,
9
  time_scale=1.0,
10
- enable_flashattn=False,
11
- enable_layernorm_kernel=False,
12
  from_pretrained="PRETRAINED_MODEL",
13
  )
14
  vae = dict(
15
  type="VideoAutoencoderKL",
16
  from_pretrained="stabilityai/sd-vae-ft-ema",
 
17
  )
18
  text_encoder = dict(
19
  type="t5",
@@ -24,11 +25,15 @@ scheduler = dict(
24
  type="iddpm",
25
  num_sampling_steps=100,
26
  cfg_scale=7.0,
 
27
  )
28
- dtype = "fp16"
 
 
 
 
29
 
30
  # Others
31
- batch_size = 2
32
  seed = 42
33
- prompt_path = "./assets/texts/t2v_samples.txt"
34
- save_dir = "./outputs/samples/"
 
7
  type="STDiT-XL/2",
8
  space_scale=0.5,
9
  time_scale=1.0,
10
+ enable_flashattn=True,
11
+ enable_layernorm_kernel=True,
12
  from_pretrained="PRETRAINED_MODEL",
13
  )
14
  vae = dict(
15
  type="VideoAutoencoderKL",
16
  from_pretrained="stabilityai/sd-vae-ft-ema",
17
+ micro_batch_size=4,
18
  )
19
  text_encoder = dict(
20
  type="t5",
 
25
  type="iddpm",
26
  num_sampling_steps=100,
27
  cfg_scale=7.0,
28
+ cfg_channel=3, # or None
29
  )
30
+ dtype = "bf16"
31
+
32
+ # Condition
33
+ prompt_path = "./assets/texts/t2v_samples.txt"
34
+ prompt = None # prompt has higher priority than prompt_path
35
 
36
  # Others
37
+ batch_size = 1
38
  seed = 42
39
+ save_dir = "./samples/samples/"
 
configs/opensora/inference/16x512x512.py CHANGED
@@ -7,14 +7,14 @@ model = dict(
7
  type="STDiT-XL/2",
8
  space_scale=1.0,
9
  time_scale=1.0,
10
- enable_flashattn=False,
11
- enable_layernorm_kernel=False,
12
- from_pretrained="PRETRAINED_MODEL"
13
  )
14
  vae = dict(
15
  type="VideoAutoencoderKL",
16
  from_pretrained="stabilityai/sd-vae-ft-ema",
17
- micro_batch_size=128,
18
  )
19
  text_encoder = dict(
20
  type="t5",
@@ -26,10 +26,10 @@ scheduler = dict(
26
  num_sampling_steps=100,
27
  cfg_scale=7.0,
28
  )
29
- dtype = "fp16"
30
 
31
  # Others
32
  batch_size = 2
33
  seed = 42
34
  prompt_path = "./assets/texts/t2v_samples.txt"
35
- save_dir = "./outputs/samples/"
 
7
  type="STDiT-XL/2",
8
  space_scale=1.0,
9
  time_scale=1.0,
10
+ enable_flashattn=True,
11
+ enable_layernorm_kernel=True,
12
+ from_pretrained="PRETRAINED_MODEL",
13
  )
14
  vae = dict(
15
  type="VideoAutoencoderKL",
16
  from_pretrained="stabilityai/sd-vae-ft-ema",
17
+ micro_batch_size=2,
18
  )
19
  text_encoder = dict(
20
  type="t5",
 
26
  num_sampling_steps=100,
27
  cfg_scale=7.0,
28
  )
29
+ dtype = "bf16"
30
 
31
  # Others
32
  batch_size = 2
33
  seed = 42
34
  prompt_path = "./assets/texts/t2v_samples.txt"
35
+ save_dir = "./samples/samples/"
configs/opensora/inference/64x512x512.py CHANGED
@@ -1,5 +1,5 @@
1
- num_frames = 16
2
- fps = 24 //4
3
  image_size = (512, 512)
4
 
5
  # Define model
@@ -7,8 +7,8 @@ model = dict(
7
  type="STDiT-XL/2",
8
  space_scale=1.0,
9
  time_scale=2 / 3,
10
- enable_flashattn=False,
11
- enable_layernorm_kernel=False,
12
  from_pretrained="PRETRAINED_MODEL",
13
  )
14
  vae = dict(
@@ -23,13 +23,13 @@ text_encoder = dict(
23
  )
24
  scheduler = dict(
25
  type="iddpm",
26
- num_sampling_steps=50,
27
  cfg_scale=7.0,
28
  )
29
- dtype = "fp16"
30
 
31
  # Others
32
  batch_size = 1
33
  seed = 42
34
  prompt_path = "./assets/texts/t2v_samples.txt"
35
- save_dir = "./outputs/samples/"
 
1
+ num_frames = 64
2
+ fps = 24 // 2
3
  image_size = (512, 512)
4
 
5
  # Define model
 
7
  type="STDiT-XL/2",
8
  space_scale=1.0,
9
  time_scale=2 / 3,
10
+ enable_flashattn=True,
11
+ enable_layernorm_kernel=True,
12
  from_pretrained="PRETRAINED_MODEL",
13
  )
14
  vae = dict(
 
23
  )
24
  scheduler = dict(
25
  type="iddpm",
26
+ num_sampling_steps=100,
27
  cfg_scale=7.0,
28
  )
29
+ dtype = "bf16"
30
 
31
  # Others
32
  batch_size = 1
33
  seed = 42
34
  prompt_path = "./assets/texts/t2v_samples.txt"
35
+ save_dir = "./samples/samples/"
configs/opensora/train/16x256x256-mask.py ADDED
@@ -0,0 +1,60 @@
1
+ # Define dataset
2
+ dataset = dict(
3
+ type="VideoTextDataset",
4
+ data_path=None,
5
+ num_frames=16,
6
+ frame_interval=3,
7
+ image_size=(256, 256),
8
+ )
9
+
10
+ # Define acceleration
11
+ num_workers = 4
12
+ dtype = "bf16"
13
+ grad_checkpoint = True
14
+ plugin = "zero2"
15
+ sp_size = 1
16
+
17
+ # Define model
18
+ model = dict(
19
+ type="STDiT-XL/2",
20
+ space_scale=0.5,
21
+ time_scale=1.0,
22
+ from_pretrained="PixArt-XL-2-512x512.pth",
23
+ enable_flashattn=True,
24
+ enable_layernorm_kernel=True,
25
+ )
26
+ mask_ratios = {
27
+ "mask_no": 0.7,
28
+ "mask_random": 0.15,
29
+ "mask_head": 0.05,
30
+ "mask_tail": 0.05,
31
+ "mask_head_tail": 0.05,
32
+ }
33
+ vae = dict(
34
+ type="VideoAutoencoderKL",
35
+ from_pretrained="stabilityai/sd-vae-ft-ema",
36
+ )
37
+ text_encoder = dict(
38
+ type="t5",
39
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
40
+ model_max_length=120,
41
+ shardformer=True,
42
+ )
43
+ scheduler = dict(
44
+ type="iddpm",
45
+ timestep_respacing="",
46
+ )
47
+
48
+ # Others
49
+ seed = 42
50
+ outputs = "outputs"
51
+ wandb = False
52
+
53
+ epochs = 1000
54
+ log_every = 10
55
+ ckpt_every = 1000
56
+ load = None
57
+
58
+ batch_size = 8
59
+ lr = 2e-5
60
+ grad_clip = 1.0
configs/opensora/train/16x256x256-spee.py ADDED
@@ -0,0 +1,60 @@
1
+ # Define dataset
2
+ dataset = dict(
3
+ type="VideoTextDataset",
4
+ data_path=None,
5
+ num_frames=16,
6
+ frame_interval=3,
7
+ image_size=(256, 256),
8
+ )
9
+
10
+ # Define acceleration
11
+ num_workers = 4
12
+ dtype = "bf16"
13
+ grad_checkpoint = True
14
+ plugin = "zero2"
15
+ sp_size = 1
16
+
17
+ # Define model
18
+ model = dict(
19
+ type="STDiT-XL/2",
20
+ space_scale=0.5,
21
+ time_scale=1.0,
22
+ from_pretrained="PixArt-XL-2-512x512.pth",
23
+ enable_flashattn=True,
24
+ enable_layernorm_kernel=True,
25
+ )
26
+ mask_ratios = {
27
+ "mask_no": 0.5,
28
+ "mask_random": 0.29,
29
+ "mask_head": 0.07,
30
+ "mask_tail": 0.07,
31
+ "mask_head_tail": 0.07,
32
+ }
33
+ vae = dict(
34
+ type="VideoAutoencoderKL",
35
+ from_pretrained="stabilityai/sd-vae-ft-ema",
36
+ )
37
+ text_encoder = dict(
38
+ type="t5",
39
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
40
+ model_max_length=120,
41
+ shardformer=True,
42
+ )
43
+ scheduler = dict(
44
+ type="iddpm-speed",
45
+ timestep_respacing="",
46
+ )
47
+
48
+ # Others
49
+ seed = 42
50
+ outputs = "outputs"
51
+ wandb = False
52
+
53
+ epochs = 1000
54
+ log_every = 10
55
+ ckpt_every = 1000
56
+ load = None
57
+
58
+ batch_size = 8
59
+ lr = 2e-5
60
+ grad_clip = 1.0
configs/opensora/train/16x256x256.py CHANGED
@@ -1,14 +1,14 @@
1
- num_frames = 16
2
- frame_interval = 3
3
- image_size = (256, 256)
4
-
5
  # Define dataset
6
- root = None
7
- data_path = "CSV_PATH"
8
- use_image_transform = False
9
- num_workers = 4
 
 
 
10
 
11
  # Define acceleration
 
12
  dtype = "bf16"
13
  grad_checkpoint = True
14
  plugin = "zero2"
@@ -29,7 +29,7 @@ vae = dict(
29
  )
30
  text_encoder = dict(
31
  type="t5",
32
- from_pretrained="./pretrained_models/t5_ckpts",
33
  model_max_length=120,
34
  shardformer=True,
35
  )
 
 
 
 
 
1
  # Define dataset
2
+ dataset = dict(
3
+ type="VideoTextDataset",
4
+ data_path=None,
5
+ num_frames=16,
6
+ frame_interval=3,
7
+ image_size=(256, 256),
8
+ )
9
 
10
  # Define acceleration
11
+ num_workers = 4
12
  dtype = "bf16"
13
  grad_checkpoint = True
14
  plugin = "zero2"
 
29
  )
30
  text_encoder = dict(
31
  type="t5",
32
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
33
  model_max_length=120,
34
  shardformer=True,
35
  )
configs/opensora/train/16x512x512.py CHANGED
@@ -1,16 +1,16 @@
1
- num_frames = 16
2
- frame_interval = 3
3
- image_size = (512, 512)
4
-
5
  # Define dataset
6
- root = None
7
- data_path = "CSV_PATH"
8
- use_image_transform = False
9
- num_workers = 4
 
 
 
10
 
11
  # Define acceleration
 
12
  dtype = "bf16"
13
- grad_checkpoint = False
14
  plugin = "zero2"
15
  sp_size = 1
16
 
@@ -30,7 +30,7 @@ vae = dict(
30
  )
31
  text_encoder = dict(
32
  type="t5",
33
- from_pretrained="./pretrained_models/t5_ckpts",
34
  model_max_length=120,
35
  shardformer=True,
36
  )
 
 
 
 
 
1
  # Define dataset
2
+ dataset = dict(
3
+ type="VideoTextDataset",
4
+ data_path=None,
5
+ num_frames=16,
6
+ frame_interval=3,
7
+ image_size=(512, 512),
8
+ )
9
 
10
  # Define acceleration
11
+ num_workers = 4
12
  dtype = "bf16"
13
+ grad_checkpoint = True
14
  plugin = "zero2"
15
  sp_size = 1
16
 
 
30
  )
31
  text_encoder = dict(
32
  type="t5",
33
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
34
  model_max_length=120,
35
  shardformer=True,
36
  )
configs/opensora/train/360x512x512.py CHANGED
@@ -1,12 +1,18 @@
1
- num_frames = 360
2
- frame_interval = 1
3
- image_size = (512, 512)
4
-
5
  # Define dataset
6
- root = None
7
- data_path = "CSV_PATH"
8
- use_image_transform = False
 
 
 
 
 
 
9
  num_workers = 4
 
 
 
 
10
 
11
  # Define acceleration
12
  dtype = "bf16"
@@ -31,7 +37,7 @@ vae = dict(
31
  )
32
  text_encoder = dict(
33
  type="t5",
34
- from_pretrained="./pretrained_models/t5_ckpts",
35
  model_max_length=120,
36
  shardformer=True,
37
  )
 
 
 
 
 
1
  # Define dataset
2
+ dataset = dict(
3
+ type="VideoTextDataset",
4
+ data_path=None,
5
+ num_frames=360,
6
+ frame_interval=3,
7
+ image_size=(512, 512),
8
+ )
9
+
10
+ # Define acceleration
11
  num_workers = 4
12
+ dtype = "bf16"
13
+ grad_checkpoint = True
14
+ plugin = "zero2"
15
+ sp_size = 1
16
 
17
  # Define acceleration
18
  dtype = "bf16"
 
37
  )
38
  text_encoder = dict(
39
  type="t5",
40
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
41
  model_max_length=120,
42
  shardformer=True,
43
  )
configs/opensora/train/64x512x512-sp.py CHANGED
@@ -1,17 +1,17 @@
1
- num_frames = 64
2
- frame_interval = 2
3
- image_size = (512, 512)
4
-
5
  # Define dataset
6
- root = None
7
- data_path = "CSV_PATH"
8
- use_image_transform = False
9
- num_workers = 4
 
 
 
10
 
11
  # Define acceleration
 
12
  dtype = "bf16"
13
  grad_checkpoint = True
14
- plugin = "zero2-seq"
15
  sp_size = 2
16
 
17
  # Define model
@@ -30,7 +30,7 @@ vae = dict(
30
  )
31
  text_encoder = dict(
32
  type="t5",
33
- from_pretrained="./pretrained_models/t5_ckpts",
34
  model_max_length=120,
35
  shardformer=True,
36
  )
 
 
 
 
 
1
  # Define dataset
2
+ dataset = dict(
3
+ type="VideoTextDataset",
4
+ data_path=None,
5
+ num_frames=16,
6
+ frame_interval=3,
7
+ image_size=(512, 512),
8
+ )
9
 
10
  # Define acceleration
11
+ num_workers = 4
12
  dtype = "bf16"
13
  grad_checkpoint = True
14
+ plugin = "zero2"
15
  sp_size = 2
16
 
17
  # Define model
 
30
  )
31
  text_encoder = dict(
32
  type="t5",
33
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
34
  model_max_length=120,
35
  shardformer=True,
36
  )
configs/opensora/train/64x512x512.py CHANGED
@@ -1,14 +1,14 @@
1
- num_frames = 64
2
- frame_interval = 2
3
- image_size = (512, 512)
4
-
5
  # Define dataset
6
- root = None
7
- data_path = "CSV_PATH"
8
- use_image_transform = False
9
- num_workers = 4
 
 
 
10
 
11
  # Define acceleration
 
12
  dtype = "bf16"
13
  grad_checkpoint = True
14
  plugin = "zero2"
@@ -30,7 +30,7 @@ vae = dict(
30
  )
31
  text_encoder = dict(
32
  type="t5",
33
- from_pretrained="./pretrained_models/t5_ckpts",
34
  model_max_length=120,
35
  shardformer=True,
36
  )
 
 
 
 
 
1
  # Define dataset
2
+ dataset = dict(
3
+ type="VideoTextDataset",
4
+ data_path=None,
5
+ num_frames=64,
6
+ frame_interval=3,
7
+ image_size=(512, 512),
8
+ )
9
 
10
  # Define acceleration
11
+ num_workers = 4
12
  dtype = "bf16"
13
  grad_checkpoint = True
14
  plugin = "zero2"
 
30
  )
31
  text_encoder = dict(
32
  type="t5",
33
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
34
  model_max_length=120,
35
  shardformer=True,
36
  )
configs/pixart/inference/16x256x256.py CHANGED
@@ -15,7 +15,7 @@ vae = dict(
15
  )
16
  text_encoder = dict(
17
  type="t5",
18
- from_pretrained="./pretrained_models/t5_ckpts",
19
  model_max_length=120,
20
  )
21
  scheduler = dict(
@@ -23,10 +23,10 @@ scheduler = dict(
23
  num_sampling_steps=20,
24
  cfg_scale=7.0,
25
  )
26
- dtype = "fp16"
27
 
28
  # Others
29
  batch_size = 2
30
  seed = 42
31
  prompt_path = "./assets/texts/t2v_samples.txt"
32
- save_dir = "./outputs/samples/"
 
15
  )
16
  text_encoder = dict(
17
  type="t5",
18
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
19
  model_max_length=120,
20
  )
21
  scheduler = dict(
 
23
  num_sampling_steps=20,
24
  cfg_scale=7.0,
25
  )
26
+ dtype = "bf16"
27
 
28
  # Others
29
  batch_size = 2
30
  seed = 42
31
  prompt_path = "./assets/texts/t2v_samples.txt"
32
+ save_dir = "./samples/samples/"
configs/pixart/inference/1x1024MS.py CHANGED
@@ -1,7 +1,7 @@
1
  num_frames = 1
2
  fps = 1
3
  image_size = (1920, 512)
4
- multi_resolution = True
5
 
6
  # Define model
7
  model = dict(
@@ -17,7 +17,7 @@ vae = dict(
17
  )
18
  text_encoder = dict(
19
  type="t5",
20
- from_pretrained="./pretrained_models/t5_ckpts",
21
  model_max_length=120,
22
  )
23
  scheduler = dict(
@@ -25,10 +25,10 @@ scheduler = dict(
25
  num_sampling_steps=20,
26
  cfg_scale=7.0,
27
  )
28
- dtype = "fp16"
29
 
30
  # Others
31
  batch_size = 2
32
  seed = 42
33
  prompt_path = "./assets/texts/t2i_samples.txt"
34
- save_dir = "./outputs/samples/"
 
1
  num_frames = 1
2
  fps = 1
3
  image_size = (1920, 512)
4
+ multi_resolution = "PixArtMS"
5
 
6
  # Define model
7
  model = dict(
 
17
  )
18
  text_encoder = dict(
19
  type="t5",
20
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
21
  model_max_length=120,
22
  )
23
  scheduler = dict(
 
25
  num_sampling_steps=20,
26
  cfg_scale=7.0,
27
  )
28
+ dtype = "bf16"
29
 
30
  # Others
31
  batch_size = 2
32
  seed = 42
33
  prompt_path = "./assets/texts/t2i_samples.txt"
34
+ save_dir = "./samples/samples/"
configs/pixart/inference/1x256x256.py CHANGED
@@ -16,7 +16,7 @@ vae = dict(
16
  )
17
  text_encoder = dict(
18
  type="t5",
19
- from_pretrained="./pretrained_models/t5_ckpts",
20
  model_max_length=120,
21
  )
22
  scheduler = dict(
@@ -24,10 +24,10 @@ scheduler = dict(
24
  num_sampling_steps=20,
25
  cfg_scale=7.0,
26
  )
27
- dtype = "fp16"
28
 
29
  # Others
30
  batch_size = 2
31
  seed = 42
32
  prompt_path = "./assets/texts/t2i_samples.txt"
33
- save_dir = "./outputs/samples/"
 
16
  )
17
  text_encoder = dict(
18
  type="t5",
19
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
20
  model_max_length=120,
21
  )
22
  scheduler = dict(
 
24
  num_sampling_steps=20,
25
  cfg_scale=7.0,
26
  )
27
+ dtype = "bf16"
28
 
29
  # Others
30
  batch_size = 2
31
  seed = 42
32
  prompt_path = "./assets/texts/t2i_samples.txt"
33
+ save_dir = "./samples/samples/"
configs/pixart/inference/1x512x512.py CHANGED
@@ -16,7 +16,7 @@ vae = dict(
16
  )
17
  text_encoder = dict(
18
  type="t5",
19
- from_pretrained="./pretrained_models/t5_ckpts",
20
  model_max_length=120,
21
  )
22
  scheduler = dict(
@@ -24,10 +24,16 @@ scheduler = dict(
24
  num_sampling_steps=20,
25
  cfg_scale=7.0,
26
  )
27
- dtype = "fp16"
 
 
 
 
 
 
 
28
 
29
  # Others
30
  batch_size = 2
31
  seed = 42
32
- prompt_path = "./assets/texts/t2i_samples.txt"
33
- save_dir = "./outputs/samples/"
 
16
  )
17
  text_encoder = dict(
18
  type="t5",
19
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
20
  model_max_length=120,
21
  )
22
  scheduler = dict(
 
24
  num_sampling_steps=20,
25
  cfg_scale=7.0,
26
  )
27
+ dtype = "bf16"
28
+
29
+ # prompt_path = "./assets/texts/t2i_samples.txt"
30
+ prompt = [
31
+ "Pirate ship trapped in a cosmic maelstrom nebula.",
32
+ "A small cactus with a happy face in the Sahara desert.",
33
+ "A small cactus with a sad face in the Sahara desert.",
34
+ ]
35
 
36
  # Others
37
  batch_size = 2
38
  seed = 42
39
+ save_dir = "./samples/samples/"
 
configs/pixart/train/16x256x256.py CHANGED
@@ -1,16 +1,16 @@
1
- num_frames = 16
2
- frame_interval = 3
3
- image_size = (256, 256)
4
-
5
  # Define dataset
6
- root = None
7
- data_path = "CSV_PATH"
8
- use_image_transform = False
9
- num_workers = 4
 
 
 
10
 
11
  # Define acceleration
 
12
  dtype = "bf16"
13
- grad_checkpoint = False
14
  plugin = "zero2"
15
  sp_size = 1
16
 
@@ -29,7 +29,7 @@ vae = dict(
29
  )
30
  text_encoder = dict(
31
  type="t5",
32
- from_pretrained="./pretrained_models/t5_ckpts",
33
  model_max_length=120,
34
  shardformer=True,
35
  )
 
 
 
 
 
1
  # Define dataset
2
+ dataset = dict(
3
+ type="VideoTextDataset",
4
+ data_path=None,
5
+ num_frames=16,
6
+ frame_interval=3,
7
+ image_size=(256, 256),
8
+ )
9
 
10
  # Define acceleration
11
+ num_workers = 4
12
  dtype = "bf16"
13
+ grad_checkpoint = True
14
  plugin = "zero2"
15
  sp_size = 1
16
 
 
29
  )
30
  text_encoder = dict(
31
  type="t5",
32
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
33
  model_max_length=120,
34
  shardformer=True,
35
  )
configs/pixart/train/1x512x512.py CHANGED
@@ -1,14 +1,14 @@
1
- num_frames = 1
2
- frame_interval = 1
3
- image_size = (512, 512)
4
-
5
  # Define dataset
6
- root = None
7
- data_path = "CSV_PATH"
8
- use_image_transform = True
9
- num_workers = 4
 
 
 
10
 
11
  # Define acceleration
 
12
  dtype = "bf16"
13
  grad_checkpoint = True
14
  plugin = "zero2"
@@ -30,7 +30,7 @@ vae = dict(
30
  )
31
  text_encoder = dict(
32
  type="t5",
33
- from_pretrained="./pretrained_models/t5_ckpts",
34
  model_max_length=120,
35
  shardformer=True,
36
  )
 
 
 
 
 
1
  # Define dataset
2
+ dataset = dict(
3
+ type="VideoTextDataset",
4
+ data_path=None,
5
+ num_frames=1,
6
+ frame_interval=3,
7
+ image_size=(512, 512),
8
+ )
9
 
10
  # Define acceleration
11
+ num_workers = 4
12
  dtype = "bf16"
13
  grad_checkpoint = True
14
  plugin = "zero2"
 
30
  )
31
  text_encoder = dict(
32
  type="t5",
33
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
34
  model_max_length=120,
35
  shardformer=True,
36
  )
configs/pixart/train/64x512x512.py CHANGED
@@ -1,19 +1,20 @@
1
- num_frames = 64
2
- frame_interval = 2
3
- image_size = (512, 512)
4
-
5
  # Define dataset
6
- root = None
7
- data_path = "CSV_PATH"
8
- use_image_transform = False
9
- num_workers = 4
 
 
 
10
 
11
  # Define acceleration
 
12
  dtype = "bf16"
13
  grad_checkpoint = True
14
  plugin = "zero2"
15
  sp_size = 1
16
 
 
17
  # Define model
18
  model = dict(
19
  type="PixArt-XL/2",
@@ -30,7 +31,7 @@ vae = dict(
30
  )
31
  text_encoder = dict(
32
  type="t5",
33
- from_pretrained="./pretrained_models/t5_ckpts",
34
  model_max_length=120,
35
  shardformer=True,
36
  )
 
 
 
 
 
1
  # Define dataset
2
+ dataset = dict(
3
+ type="VideoTextDataset",
4
+ data_path=None,
5
+ num_frames=64,
6
+ frame_interval=3,
7
+ image_size=(256, 256),
8
+ )
9
 
10
  # Define acceleration
11
+ num_workers = 4
12
  dtype = "bf16"
13
  grad_checkpoint = True
14
  plugin = "zero2"
15
  sp_size = 1
16
 
17
+
18
  # Define model
19
  model = dict(
20
  type="PixArt-XL/2",
 
31
  )
32
  text_encoder = dict(
33
  type="t5",
34
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
35
  model_max_length=120,
36
  shardformer=True,
37
  )