kunhaokhliu committed
Commit 5d2a97a · 1 parent: 21a626f

Add application file

This view is limited to 50 files because the commit contains too many changes.
.gitignore ADDED
@@ -0,0 +1,8 @@
1
+ __pycache__
2
+ *.egg-info
3
+ .cache
4
+
5
+ wan_models
6
+ checkpoints
7
+ videos
8
+ logs
LICENSE ADDED
@@ -0,0 +1,81 @@
1
+ Tencent is pleased to support the community by making RollingForcing available.
2
+
3
+ Copyright (C) 2025 Tencent. All rights reserved.
4
+
5
+ The open-source software and/or models included in this distribution may have been modified by Tencent (“Tencent Modifications”). All Tencent Modifications are Copyright (C) Tencent.
6
+
7
+ RollingForcing is licensed under the License Terms of RollingForcing, except for the third-party components listed below, which remain licensed under their respective original terms. RollingForcing does not impose any additional restrictions beyond those specified in the original licenses of these third-party components. Users are required to comply with all applicable terms and conditions of the original licenses and to ensure that the use of these third-party components conforms to all relevant laws and regulations.
8
+
9
+ For the avoidance of doubt, RollingForcing refers solely to training code, inference code, parameters, and weights made publicly available by Tencent in accordance with the License Terms of RollingForcing.
10
+
11
+ Terms of the License Terms of RollingForcing:
12
+ --------------------------------------------------------------------
13
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
14
+
15
+ - You agree to use RollingForcing only for academic purposes, and refrain from using it for any commercial or production purposes under any circumstances.
16
+
17
+ - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
18
+
19
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20
+
21
+
22
+
23
+ Dependencies and Licenses:
24
+
25
+ This open-source project, RollingForcing, builds upon the following open-source models and/or software components, each of which remains licensed under its original license. Certain models or software may include modifications made by Tencent (“Tencent Modifications”), which are Copyright (C) Tencent.
26
+
27
+ In case you believe there have been errors in the attribution below, you may submit the concerns to us for review and correction.
28
+
29
+ Open Source Model Licensed under the Apache-2.0:
30
+ --------------------------------------------------------------------
31
+ 1. Wan-AI/Wan2.1-T2V-1.3B
32
+ Copyright (c) 2025 Wan Team
33
+
34
+ Terms of the Apache-2.0:
35
+ --------------------------------------------------------------------
36
+ Apache License
37
+ Version 2.0, January 2004
38
+ http://www.apache.org/licenses/
39
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
40
+
41
+ Definitions.
42
+
43
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
44
+
45
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
46
+
47
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
48
+
49
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
50
+
51
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
52
+
53
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
54
+
55
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
56
+
57
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
58
+
59
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
60
+
61
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
62
+
63
+ Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
64
+
65
+ Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
66
+
67
+ Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
68
+
69
+ (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
70
+
71
+ Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
72
+
73
+ Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
74
+
75
+ Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
76
+
77
+ Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
78
+
79
+ Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
80
+
81
+ END OF TERMS AND CONDITIONS
README.md CHANGED
@@ -8,7 +8,7 @@ sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
10
  license: other
11
- short_description: 'Rolling Forcing: Autoregressive Long Video Diffusion in Real'
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  license: other
11
+ short_description: 'Rolling Forcing: Autoregressive Long Video Diffusion in Real Time'
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,187 @@
1
+ import os
2
+ import argparse
3
+ import time
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from torchvision.io import write_video
8
+ from omegaconf import OmegaConf
9
+ from einops import rearrange
10
+ import gradio as gr
11
+
12
+ from pipeline import CausalInferencePipeline
13
+ from huggingface_hub import snapshot_download, hf_hub_download
14
+
15
+
16
+ # -----------------------------
17
+ # Globals (loaded once per process)
18
+ # -----------------------------
19
+ _PIPELINE: Optional[torch.nn.Module] = None
20
+ _DEVICE: Optional[torch.device] = None
21
+
22
+
23
+ def _ensure_gpu():
24
+ if not torch.cuda.is_available():
25
+ raise gr.Error("CUDA GPU is required to run this demo. Please run on a machine with an NVIDIA GPU.")
26
+ # Bind to GPU:0 by default
27
+ torch.cuda.set_device(0)
28
+
29
+
30
+ def _load_pipeline(config_path: str, checkpoint_path: Optional[str], use_ema: bool) -> torch.nn.Module:
31
+ global _PIPELINE, _DEVICE
32
+ if _PIPELINE is not None:
33
+ return _PIPELINE
34
+
35
+ _ensure_gpu()
36
+ _DEVICE = torch.device("cuda:0")
37
+
38
+ # Load and merge configs
39
+ config = OmegaConf.load(config_path)
40
+ default_config = OmegaConf.load("configs/default_config.yaml")
41
+ config = OmegaConf.merge(default_config, config)
42
+
43
+ # Choose pipeline type based on config
44
+ pipeline = CausalInferencePipeline(config, device=_DEVICE)
45
+
46
+
47
+ # Load checkpoint if provided
48
+ if checkpoint_path and os.path.exists(checkpoint_path):
49
+ state_dict = torch.load(checkpoint_path, map_location="cpu")
50
+ if use_ema and 'generator_ema' in state_dict:
51
+ state_dict_to_load = state_dict['generator_ema']
52
+ # Remove possible FSDP prefix
53
+ from collections import OrderedDict
54
+ new_state_dict = OrderedDict()
55
+ for k, v in state_dict_to_load.items():
56
+ new_state_dict[k.replace("_fsdp_wrapped_module.", "")] = v
57
+ state_dict_to_load = new_state_dict
58
+ else:
59
+ state_dict_to_load = state_dict.get('generator', state_dict)
60
+ pipeline.generator.load_state_dict(state_dict_to_load, strict=False)
61
+
62
+ # The codebase assumes bfloat16 on GPU
63
+ pipeline = pipeline.to(device=_DEVICE, dtype=torch.bfloat16)
64
+ pipeline.eval()
65
+
66
+ # Quick sanity check of the Wan model path so missing weights give a friendly error
67
+ wan_dir = os.path.join('wan_models', 'Wan2.1-T2V-1.3B')
68
+ if not os.path.isdir(wan_dir):
69
+ raise gr.Error(
70
+ "Wan2.1-T2V-1.3B not found at 'wan_models/Wan2.1-T2V-1.3B'.\n"
71
+ "Please download it first, e.g.:\n"
72
+ "huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir wan_models/Wan2.1-T2V-1.3B"
73
+ )
74
+
75
+ _PIPELINE = pipeline
76
+ return _PIPELINE
77
+
78
+
79
+ def build_predict(config_path: str, checkpoint_path: Optional[str], output_dir: str, use_ema: bool):
80
+ os.makedirs(output_dir, exist_ok=True)
81
+
82
+ def predict(prompt: str, num_frames: int) -> str:
83
+ if not prompt or not prompt.strip():
84
+ raise gr.Error("Please enter a non-empty text prompt.")
85
+
86
+ num_frames = int(num_frames)
87
+ if num_frames % 3 != 0 or not (21 <= num_frames <= 252):
88
+ raise gr.Error("Number of frames must be a multiple of 3 between 21 and 252.")
89
+
90
+ pipeline = _load_pipeline(config_path, checkpoint_path, use_ema)
91
+
92
+ # Prepare inputs
93
+ prompts = [prompt.strip()]
94
+ noise = torch.randn([1, num_frames, 16, 60, 104], device=_DEVICE, dtype=torch.bfloat16)
95
+
96
+ torch.set_grad_enabled(False)
97
+ with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
98
+ video = pipeline.inference_rolling_forcing(
99
+ noise=noise,
100
+ text_prompts=prompts,
101
+ return_latents=False,
102
+ initial_latent=None,
103
+ )
104
+
105
+ # video: [B=1, T, C, H, W] in [0,1]
106
+ video = rearrange(video, 'b t c h w -> b t h w c')[0]
107
+ video_uint8 = (video * 255.0).clamp(0, 255).to(torch.uint8).cpu()
108
+
109
+ # Save to a unique filepath
110
+ safe_stub = prompt[:60].replace(' ', '_').replace('/', '_')
111
+ ts = int(time.time())
112
+ filepath = os.path.join(output_dir, f"{safe_stub or 'video'}_{ts}.mp4")
113
+ write_video(filepath, video_uint8, fps=16)
114
+ print(f"Saved generated video to {filepath}")
115
+
116
+ return filepath
117
+
118
+ return predict
119
+
120
+
121
+ def main():
122
+ parser = argparse.ArgumentParser()
123
+ parser.add_argument('--config_path', type=str, default='configs/rolling_forcing_dmd.yaml',
124
+ help='Path to the model config')
125
+ parser.add_argument('--checkpoint_path', type=str, default='checkpoints/rolling_forcing_dmd.pt',
126
+ help='Path to rolling forcing checkpoint (.pt). If missing, will run with base weights only if available.')
127
+ parser.add_argument('--output_dir', type=str, default='videos/gradio', help='Where to save generated videos')
128
+ parser.add_argument('--no_ema', action='store_true', help='Disable EMA weights when loading checkpoint')
129
+ args = parser.parse_args()
130
+
131
+
132
+ # Download checkpoint from HuggingFace if not present
133
+ # 1️⃣ Equivalent to:
134
+ # huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
135
+ wan_model_dir = snapshot_download(
136
+ repo_id="Wan-AI/Wan2.1-T2V-1.3B",
137
+ local_dir="wan_models/Wan2.1-T2V-1.3B",
138
+ local_dir_use_symlinks=False, # same as --local-dir-use-symlinks False
139
+ )
140
+ print("Wan model downloaded to:", wan_model_dir)
141
+
142
+ # 2️⃣ Equivalent to:
143
+ # huggingface-cli download TencentARC/RollingForcing checkpoints/rolling_forcing_dmd.pt --local-dir .
144
+ rolling_ckpt_path = hf_hub_download(
145
+ repo_id="TencentARC/RollingForcing",
146
+ filename="checkpoints/rolling_forcing_dmd.pt",
147
+ local_dir=".", # where to store it
148
+ local_dir_use_symlinks=False,
149
+ )
150
+ print("RollingForcing checkpoint downloaded to:", rolling_ckpt_path)
151
+
152
+ predict = build_predict(
153
+ config_path=args.config_path,
154
+ checkpoint_path=args.checkpoint_path,
155
+ output_dir=args.output_dir,
156
+ use_ema=not args.no_ema,
157
+ )
158
+
159
+ demo = gr.Interface(
160
+ fn=predict,
161
+ inputs=[
162
+ gr.Textbox(label="Text Prompt", lines=2, placeholder="A cinematic shot of a girl dancing in the sunset."),
163
+ gr.Slider(label="Number of Latent Frames", minimum=21, maximum=252, step=3, value=21),
164
+ ],
165
+ outputs=gr.Video(label="Generated Video", format="mp4"),
166
+ title="Rolling Forcing: Autoregressive Long Video Diffusion in Real Time",
167
+ description=(
168
+ "Enter a prompt and generate a video using the Rolling Forcing pipeline.\n"
169
+ "**Note:** although Rolling Forcing generates videos autoregressivelty, current Gradio demo does not support streaming outputs, so the entire video will be generated before it is displayed.\n"
170
+ "\n"
171
+ "If you find this demo useful, please consider giving it a ⭐ star on [GitHub](https://github.com/TencentARC/RollingForcing)--your support is crucial for sustaining this open-source project. "
172
+ "You can also dive deeper by reading the [paper](https://arxiv.org/abs/2509.25161) or exploring the [project page](https://kunhao-liu.github.io/Rolling_Forcing_Webpage) for more details."
173
+ ),
174
+ allow_flagging='never',
175
+ )
176
+
177
+ try:
178
+ # Gradio <= 3.x
179
+ demo.queue(concurrency_count=1, max_size=2)
180
+ except TypeError:
181
+ # Gradio >= 4.x
182
+ demo.queue(max_size=2)
183
+ demo.launch(show_error=True)
184
+
185
+
186
+ if __name__ == "__main__":
187
+ main()
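
For quick local testing, the `predict` closure returned by `build_predict` can also be driven without the Gradio UI. A minimal sketch, assuming the Wan weights and the Rolling Forcing checkpoint are already at the default paths used in `main()` above (this snippet is not part of the committed file):

```python
# Minimal sketch: call the demo's predict function directly
# (assumes a CUDA GPU and that wan_models/ and checkpoints/ are populated as above).
from app import build_predict

predict = build_predict(
    config_path="configs/rolling_forcing_dmd.yaml",
    checkpoint_path="checkpoints/rolling_forcing_dmd.pt",
    output_dir="videos/gradio",
    use_ema=True,
)
# 21 latent frames is the slider default; the value must be a multiple of 3 in [21, 252].
video_path = predict("A cinematic shot of a girl dancing in the sunset.", 21)
print(video_path)
```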
configs/default_config.yaml ADDED
@@ -0,0 +1,20 @@
1
+ independent_first_frame: false
2
+ warp_denoising_step: false
3
+ weight_decay: 0.01
4
+ same_step_across_blocks: true
5
+ discriminator_lr_multiplier: 1.0
6
+ last_step_only: false
7
+ i2v: false
8
+ num_training_frames: 27
9
+ gc_interval: 100
10
+ context_noise: 0
11
+ causal: true
12
+
13
+ ckpt_step: 0
14
+ prompt_name: MovieGenVideoBench
15
+ prompt_path: prompts/MovieGenVideoBench.txt
16
+ eval_first_n: 64
17
+ num_samples: 1
18
+ height: 480
19
+ width: 832
20
+ num_frames: 81
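
The sizes above are consistent with the latent shapes used elsewhere in this commit (e.g. the `[*, 21, 16, 60, 104]` noise tensors in app.py and the config below). A small sanity sketch of that bookkeeping, assuming the usual Wan2.1 VAE compression factors of 8x in space and 4x in time (an assumption inferred from these shapes, not stated in this diff):

```python
# Hypothetical shape check; the VAE compression factors are assumed, not taken from this diff.
latent_frames, latent_h, latent_w = 21, 60, 104
assert (latent_frames - 1) * 4 + 1 == 81             # num_frames in this config
assert latent_h * 8 == 480 and latent_w * 8 == 832   # height / width in this config
```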
configs/rolling_forcing_dmd.yaml ADDED
@@ -0,0 +1,48 @@
1
+ generator_ckpt: checkpoints/ode_init.pt
2
+ generator_fsdp_wrap_strategy: size
3
+ real_score_fsdp_wrap_strategy: size
4
+ fake_score_fsdp_wrap_strategy: size
5
+ real_name: Wan2.1-T2V-14B
6
+ text_encoder_fsdp_wrap_strategy: size
7
+ denoising_step_list:
8
+ - 1000
9
+ - 800
10
+ - 600
11
+ - 400
12
+ - 200
13
+ warp_denoising_step: true # remove the "- 0" entry from denoising_step_list when warp_denoising_step is true
14
+ ts_schedule: false
15
+ num_train_timestep: 1000
16
+ timestep_shift: 5.0
17
+ guidance_scale: 3.0
18
+ denoising_loss_type: flow
19
+ mixed_precision: true
20
+ seed: 0
21
+ sharding_strategy: hybrid_full
22
+ lr: 1.5e-06
23
+ lr_critic: 4.0e-07
24
+ beta1: 0.0
25
+ beta2: 0.999
26
+ beta1_critic: 0.0
27
+ beta2_critic: 0.999
28
+ data_path: prompts/vidprom_filtered_extended.txt
29
+ batch_size: 1
30
+ ema_weight: 0.99
31
+ ema_start_step: 200
32
+ total_batch_size: 64
33
+ log_iters: 100
34
+ negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
35
+ dfake_gen_update_ratio: 5
36
+ image_or_video_shape:
37
+ - 1
38
+ - 21
39
+ - 16
40
+ - 60
41
+ - 104
42
+ distribution_loss: dmd
43
+ trainer: score_distillation
44
+ gradient_checkpointing: true
45
+ num_frame_per_block: 3
46
+ load_raw_video: false
47
+ model_kwargs:
48
+ timestep_shift: 5.0
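
Both app.py and inference.py consume this file by merging it over `configs/default_config.yaml`, so any key in the default config can be overridden here. A minimal sketch of that merge (mirroring the loading code in those scripts, not new functionality):

```python
# Sketch: how the two YAML files above are combined (same OmegaConf calls as app.py / inference.py).
from omegaconf import OmegaConf

default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, OmegaConf.load("configs/rolling_forcing_dmd.yaml"))
print(list(config.denoising_step_list))  # [1000, 800, 600, 400, 200]
print(config.num_frame_per_block)        # 3
```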
inference.py ADDED
@@ -0,0 +1,197 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ from omegaconf import OmegaConf
5
+ from collections import OrderedDict
6
+ from tqdm import tqdm
7
+ from torchvision import transforms
8
+ from torchvision.io import write_video
9
+ from einops import rearrange
10
+ import torch.distributed as dist
11
+ import imageio
12
+ from torch.utils.data import DataLoader, SequentialSampler
13
+ from torch.utils.data.distributed import DistributedSampler
14
+
15
+ from pipeline import (
16
+ CausalDiffusionInferencePipeline,
17
+ CausalInferencePipeline
18
+ )
19
+ from utils.dataset import TextDataset, TextImagePairDataset
20
+ from utils.misc import set_seed
21
+
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--config_path", type=str, help="Path to the config file")
24
+ parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint folder")
25
+ parser.add_argument("--data_path", type=str, help="Path to the dataset")
26
+ parser.add_argument("--extended_prompt_path", type=str, help="Path to the extended prompt")
27
+ parser.add_argument("--output_folder", type=str, help="Output folder")
28
+ parser.add_argument("--num_output_frames", type=int, default=21,
29
+ help="Number of overlap frames between sliding windows")
30
+ parser.add_argument("--i2v", action="store_true", help="Whether to perform I2V (or T2V by default)")
31
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA parameters")
32
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
33
+ parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate per prompt")
34
+ parser.add_argument("--save_with_index", action="store_true",
35
+ help="Whether to save the video using the index or prompt as the filename")
36
+ args = parser.parse_args()
37
+
38
+ # Initialize distributed inference
39
+ if "LOCAL_RANK" in os.environ:
40
+ dist.init_process_group(backend='nccl')
41
+ local_rank = int(os.environ["LOCAL_RANK"])
42
+ torch.cuda.set_device(local_rank)
43
+ device = torch.device(f"cuda:{local_rank}")
44
+ world_size = dist.get_world_size()
45
+ set_seed(args.seed + local_rank)
46
+ else:
47
+ device = torch.device("cuda")
48
+ local_rank = 0
49
+ world_size = 1
50
+ set_seed(args.seed)
51
+
52
+ torch.set_grad_enabled(False)
53
+
54
+ config = OmegaConf.load(args.config_path)
55
+ default_config = OmegaConf.load("configs/default_config.yaml")
56
+ config = OmegaConf.merge(default_config, config)
57
+
58
+ # Initialize pipeline
59
+ if hasattr(config, 'denoising_step_list'):
60
+ # Few-step inference
61
+ pipeline = CausalInferencePipeline(config, device=device)
62
+ else:
63
+ # Multi-step diffusion inference
64
+ pipeline = CausalDiffusionInferencePipeline(config, device=device)
65
+
66
+ if args.checkpoint_path:
67
+ state_dict = torch.load(args.checkpoint_path, map_location="cpu")
68
+ if args.use_ema:
69
+ state_dict_to_load = state_dict['generator_ema']
70
+ def remove_fsdp_prefix(state_dict):
71
+ new_state_dict = OrderedDict()
72
+ for key, value in state_dict.items():
73
+ if "_fsdp_wrapped_module." in key:
74
+ new_key = key.replace("_fsdp_wrapped_module.", "")
75
+ new_state_dict[new_key] = value
76
+ else:
77
+ new_state_dict[key] = value
78
+ return new_state_dict
79
+ state_dict_to_load = remove_fsdp_prefix(state_dict_to_load)
80
+ else:
81
+ state_dict_to_load = state_dict['generator']
82
+ pipeline.generator.load_state_dict(state_dict_to_load)
83
+
84
+ pipeline = pipeline.to(device=device, dtype=torch.bfloat16)
85
+
86
+ # Create dataset
87
+ if args.i2v:
88
+ assert not dist.is_initialized(), "I2V does not support distributed inference yet"
89
+ transform = transforms.Compose([
90
+ transforms.Resize((480, 832)),
91
+ transforms.ToTensor(),
92
+ transforms.Normalize([0.5], [0.5])
93
+ ])
94
+ dataset = TextImagePairDataset(args.data_path, transform=transform)
95
+ else:
96
+ dataset = TextDataset(prompt_path=args.data_path, extended_prompt_path=args.extended_prompt_path)
97
+ num_prompts = len(dataset)
98
+ print(f"Number of prompts: {num_prompts}")
99
+
100
+ if dist.is_initialized():
101
+ sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
102
+ else:
103
+ sampler = SequentialSampler(dataset)
104
+ dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)
105
+
106
+ # Create output directory (only on main process to avoid race conditions)
107
+ if local_rank == 0:
108
+ os.makedirs(args.output_folder, exist_ok=True)
109
+
110
+ if dist.is_initialized():
111
+ dist.barrier()
112
+
113
+
114
+ def encode(self, videos: torch.Tensor) -> torch.Tensor:
115
+ device, dtype = videos[0].device, videos[0].dtype
116
+ scale = [self.mean.to(device=device, dtype=dtype),
117
+ 1.0 / self.std.to(device=device, dtype=dtype)]
118
+ output = [
119
+ self.model.encode(u.unsqueeze(0), scale).float().squeeze(0)
120
+ for u in videos
121
+ ]
122
+
123
+ output = torch.stack(output, dim=0)
124
+ return output
125
+
126
+
127
+ for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)):
128
+ idx = batch_data['idx'].item()
129
+
130
+ # For DataLoader batch_size=1, the batch_data is already a single item, but in a batch container
131
+ # Unpack the batch data for convenience
132
+ if isinstance(batch_data, dict):
133
+ batch = batch_data
134
+ elif isinstance(batch_data, list):
135
+ batch = batch_data[0] # First (and only) item in the batch
136
+
137
+ all_video = []
138
+ num_generated_frames = 0 # Number of generated (latent) frames
139
+
140
+ if args.i2v:
141
+ # For image-to-video, batch contains image and caption
142
+ prompt = batch['prompts'][0] # Get caption from batch
143
+ prompts = [prompt] * args.num_samples
144
+
145
+ # Process the image
146
+ image = batch['image'].squeeze(0).unsqueeze(0).unsqueeze(2).to(device=device, dtype=torch.bfloat16)
147
+
148
+ # Encode the input image as the first latent
149
+ initial_latent = pipeline.vae.encode_to_latent(image).to(device=device, dtype=torch.bfloat16)
150
+ initial_latent = initial_latent.repeat(args.num_samples, 1, 1, 1, 1)
151
+
152
+ sampled_noise = torch.randn(
153
+ [args.num_samples, args.num_output_frames - 1, 16, 60, 104], device=device, dtype=torch.bfloat16
154
+ )
155
+ else:
156
+ # For text-to-video, batch is just the text prompt
157
+ prompt = batch['prompts'][0]
158
+ extended_prompt = batch['extended_prompts'][0] if 'extended_prompts' in batch else None
159
+ if extended_prompt is not None:
160
+ prompts = [extended_prompt] * args.num_samples
161
+ else:
162
+ prompts = [prompt] * args.num_samples
163
+ initial_latent = None
164
+
165
+ sampled_noise = torch.randn(
166
+ [args.num_samples, args.num_output_frames, 16, 60, 104], device=device, dtype=torch.bfloat16
167
+ )
168
+
169
+ # Generate 81 frames
170
+ video, latents = pipeline.inference_rolling_forcing(
171
+ noise=sampled_noise,
172
+ text_prompts=prompts,
173
+ return_latents=True,
174
+ initial_latent=initial_latent,
175
+ )
176
+ current_video = rearrange(video, 'b t c h w -> b t h w c').cpu()
177
+ all_video.append(current_video)
178
+ num_generated_frames += latents.shape[1]
179
+
180
+ # Final output video
181
+ video = 255.0 * torch.cat(all_video, dim=1)
182
+
183
+ # Clear VAE cache
184
+ pipeline.vae.model.clear_cache()
185
+
186
+ # Save the video if the current prompt is not a dummy prompt
187
+ if idx < num_prompts:
188
+ model = "regular" if not args.use_ema else "ema"
189
+ for seed_idx in range(args.num_samples):
190
+ # All processes save their videos
191
+ if args.save_with_index:
192
+ output_path = os.path.join(args.output_folder, f'{idx}-{seed_idx}_{model}.mp4')
193
+ else:
194
+ output_path = os.path.join(args.output_folder, f'{prompt[:100]}-{seed_idx}.mp4')
195
+ write_video(output_path, video[seed_idx], fps=16)
196
+ # imageio.mimwrite(output_path, video[seed_idx], fps=16, quality=8, output_params=["-loglevel", "error"])
197
+
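
One hedged observation on the script above: the tensor passed to `write_video` is still floating point in `[0, 255]`, whereas app.py clamps and casts to `uint8` before saving. If the same explicit conversion were wanted here, a sketch would be (an optional hardening, not part of the committed script):

```python
# Optional sketch: mirror app.py's explicit uint8 conversion before write_video.
video_uint8 = video.clamp(0, 255).to(torch.uint8)
write_video(output_path, video_uint8[seed_idx], fps=16)
```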
model/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ from .diffusion import CausalDiffusion
2
+ from .causvid import CausVid
3
+ from .dmd import DMD
4
+ from .gan import GAN
5
+ from .sid import SiD
6
+ from .ode_regression import ODERegression
7
+ __all__ = [
8
+ "CausalDiffusion",
9
+ "CausVid",
10
+ "DMD",
11
+ "GAN",
12
+ "SiD",
13
+ "ODERegression"
14
+ ]
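
These re-exports let downstream code import the trainers from the package root, e.g. (a usage note, not part of the commit):

```python
# Usage sketch: the package root re-exports every class listed in __all__.
from model import DMD, CausalDiffusion  # equivalent to importing from model.dmd / model.diffusion
```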
model/base.py ADDED
@@ -0,0 +1,230 @@
1
+ from typing import Tuple
2
+ from einops import rearrange
3
+ from torch import nn
4
+ import torch.distributed as dist
5
+ import torch
6
+
7
+ from pipeline import RollingForcingTrainingPipeline
8
+ from utils.loss import get_denoising_loss
9
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
10
+
11
+
12
+ class BaseModel(nn.Module):
13
+ def __init__(self, args, device):
14
+ super().__init__()
15
+ self._initialize_models(args, device)
16
+
17
+ self.device = device
18
+ self.args = args
19
+ self.dtype = torch.bfloat16 if args.mixed_precision else torch.float32
20
+ if hasattr(args, "denoising_step_list"):
21
+ self.denoising_step_list = torch.tensor(args.denoising_step_list, dtype=torch.long)
22
+ if args.warp_denoising_step:
23
+ timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
24
+ self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
25
+
26
+ def _initialize_models(self, args, device):
27
+ self.real_model_name = getattr(args, "real_name", "Wan2.1-T2V-1.3B")
28
+ self.fake_model_name = getattr(args, "fake_name", "Wan2.1-T2V-1.3B")
29
+
30
+ self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
31
+ self.generator.model.requires_grad_(True)
32
+
33
+ self.real_score = WanDiffusionWrapper(model_name=self.real_model_name, is_causal=False)
34
+ self.real_score.model.requires_grad_(False)
35
+
36
+ self.fake_score = WanDiffusionWrapper(model_name=self.fake_model_name, is_causal=False)
37
+ self.fake_score.model.requires_grad_(True)
38
+
39
+ self.text_encoder = WanTextEncoder()
40
+ self.text_encoder.requires_grad_(False)
41
+
42
+ self.vae = WanVAEWrapper()
43
+ self.vae.requires_grad_(False)
44
+
45
+ self.scheduler = self.generator.get_scheduler()
46
+ self.scheduler.timesteps = self.scheduler.timesteps.to(device)
47
+
48
+ def _get_timestep(
49
+ self,
50
+ min_timestep: int,
51
+ max_timestep: int,
52
+ batch_size: int,
53
+ num_frame: int,
54
+ num_frame_per_block: int,
55
+ uniform_timestep: bool = False
56
+ ) -> torch.Tensor:
57
+ """
58
+ Randomly generate a timestep tensor based on the generator's task type. It uniformly samples a timestep
59
+ from the range [min_timestep, max_timestep], and returns a tensor of shape [batch_size, num_frame].
60
+ - If uniform_timestep, it will use the same timestep for all frames.
61
+ - If not uniform_timestep, it will use a different timestep for each block.
62
+ """
63
+ if uniform_timestep:
64
+ timestep = torch.randint(
65
+ min_timestep,
66
+ max_timestep,
67
+ [batch_size, 1],
68
+ device=self.device,
69
+ dtype=torch.long
70
+ ).repeat(1, num_frame)
71
+ return timestep
72
+ else:
73
+ timestep = torch.randint(
74
+ min_timestep,
75
+ max_timestep,
76
+ [batch_size, num_frame],
77
+ device=self.device,
78
+ dtype=torch.long
79
+ )
80
+ # make the noise level the same within every block
81
+ if self.independent_first_frame:
82
+ # the first frame is always kept the same
83
+ timestep_from_second = timestep[:, 1:]
84
+ timestep_from_second = timestep_from_second.reshape(
85
+ timestep_from_second.shape[0], -1, num_frame_per_block)
86
+ timestep_from_second[:, :, 1:] = timestep_from_second[:, :, 0:1]
87
+ timestep_from_second = timestep_from_second.reshape(
88
+ timestep_from_second.shape[0], -1)
89
+ timestep = torch.cat([timestep[:, 0:1], timestep_from_second], dim=1)
90
+ else:
91
+ timestep = timestep.reshape(
92
+ timestep.shape[0], -1, num_frame_per_block)
93
+ timestep[:, :, 1:] = timestep[:, :, 0:1]
94
+ timestep = timestep.reshape(timestep.shape[0], -1)
95
+ return timestep
96
+
97
+
98
+ class RollingForcingModel(BaseModel):
99
+ def __init__(self, args, device):
100
+ super().__init__(args, device)
101
+ self.denoising_loss_func = get_denoising_loss(args.denoising_loss_type)()
102
+
103
+ def _run_generator(
104
+ self,
105
+ image_or_video_shape,
106
+ conditional_dict: dict,
107
+ initial_latent: torch.tensor = None
108
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
109
+ """
110
+ Optionally simulate the generator's input from noise using backward simulation
111
+ and then run the generator for one-step.
112
+ Input:
113
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
114
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
115
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
116
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
117
+ - initial_latent: a tensor containing the initial latents [B, F, C, H, W].
118
+ Output:
119
+ - pred_image: a tensor with shape [B, F, C, H, W].
120
+ - denoised_timestep: an integer
121
+ """
122
+ # Step 1: Sample noise and backward simulate the generator's input
123
+ assert getattr(self.args, "backward_simulation", True), "Backward simulation needs to be enabled"
124
+ if initial_latent is not None:
125
+ conditional_dict["initial_latent"] = initial_latent
126
+ if self.args.i2v:
127
+ noise_shape = [image_or_video_shape[0], image_or_video_shape[1] - 1, *image_or_video_shape[2:]]
128
+ else:
129
+ noise_shape = image_or_video_shape.copy()
130
+
131
+ # During training, the number of generated frames should be uniformly sampled from
132
+ # [21, self.num_training_frames], but still being a multiple of self.num_frame_per_block
133
+ min_num_frames = 20 if self.args.independent_first_frame else 21
134
+ max_num_frames = self.num_training_frames - 1 if self.args.independent_first_frame else self.num_training_frames
135
+ assert max_num_frames % self.num_frame_per_block == 0
136
+ assert min_num_frames % self.num_frame_per_block == 0
137
+ max_num_blocks = max_num_frames // self.num_frame_per_block
138
+ min_num_blocks = min_num_frames // self.num_frame_per_block
139
+ num_generated_blocks = torch.randint(min_num_blocks, max_num_blocks + 1, (1,), device=self.device)
140
+ dist.broadcast(num_generated_blocks, src=0)
141
+ num_generated_blocks = num_generated_blocks.item()
142
+ num_generated_frames = num_generated_blocks * self.num_frame_per_block
143
+ if self.args.independent_first_frame and initial_latent is None:
144
+ num_generated_frames += 1
145
+ min_num_frames += 1
146
+ # Sync num_generated_frames across all processes
147
+ noise_shape[1] = num_generated_frames
148
+
149
+ pred_image_or_video, denoised_timestep_from, denoised_timestep_to = self._consistency_backward_simulation(
150
+ noise=torch.randn(noise_shape,
151
+ device=self.device, dtype=self.dtype),
152
+ **conditional_dict,
153
+ )
154
+ # Slice last 21 frames
155
+ if pred_image_or_video.shape[1] > 21:
156
+ with torch.no_grad():
157
+ # Reencode to get image latent
158
+ latent_to_decode = pred_image_or_video[:, :-20, ...]
159
+ # Decode to video
160
+ pixels = self.vae.decode_to_pixel(latent_to_decode)
161
+ frame = pixels[:, -1:, ...].to(self.dtype)
162
+ frame = rearrange(frame, "b t c h w -> b c t h w")
163
+ # Encode frame to get image latent
164
+ image_latent = self.vae.encode_to_latent(frame).to(self.dtype)
165
+ pred_image_or_video_last_21 = torch.cat([image_latent, pred_image_or_video[:, -20:, ...]], dim=1)
166
+ else:
167
+ pred_image_or_video_last_21 = pred_image_or_video
168
+
169
+ if num_generated_frames != min_num_frames:
170
+ # Currently, we do not use gradient for the first chunk, since it contains image latents
171
+ gradient_mask = torch.ones_like(pred_image_or_video_last_21, dtype=torch.bool)
172
+ if self.args.independent_first_frame:
173
+ gradient_mask[:, :1] = False
174
+ else:
175
+ gradient_mask[:, :self.num_frame_per_block] = False
176
+ else:
177
+ gradient_mask = None
178
+
179
+ pred_image_or_video_last_21 = pred_image_or_video_last_21.to(self.dtype)
180
+ return pred_image_or_video_last_21, gradient_mask, denoised_timestep_from, denoised_timestep_to
181
+
182
+ def _consistency_backward_simulation(
183
+ self,
184
+ noise: torch.Tensor,
185
+ **conditional_dict: dict
186
+ ) -> torch.Tensor:
187
+ """
188
+ Simulate the generator's input from noise to avoid training/inference mismatch.
189
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
190
+ Here we use the consistency sampler (https://arxiv.org/abs/2303.01469)
191
+ Input:
192
+ - noise: a tensor sampled from N(0, 1) with shape [B, F, C, H, W] where the number of frame is 1 for images.
193
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
194
+ Output:
195
+ - output: a tensor with shape [B, T, F, C, H, W].
196
+ T is the total number of timesteps. output[0] is a pure noise and output[i] and i>0
197
+ represents the x0 prediction at each timestep.
198
+ """
199
+ if self.inference_pipeline is None:
200
+ self._initialize_inference_pipeline()
201
+
202
+ infer_w_rolling = torch.rand(1, device=self.device) > 0.5
203
+ dist.broadcast(infer_w_rolling, src=0)
204
+
205
+ if infer_w_rolling:
206
+ return self.inference_pipeline.inference_with_rolling_forcing(
207
+ noise=noise, **conditional_dict
208
+ )
209
+ else:
210
+ return self.inference_pipeline.inference_with_self_forcing(
211
+ noise=noise, **conditional_dict
212
+ )
213
+
214
+ def _initialize_inference_pipeline(self):
215
+ """
216
+ Lazy initialize the inference pipeline during the first backward simulation run.
217
+ Here we encapsulate the inference code with a model-dependent outside function.
218
+ We pass our FSDP-wrapped modules into the pipeline to save memory.
219
+ """
220
+ self.inference_pipeline = RollingForcingTrainingPipeline(
221
+ denoising_step_list=self.denoising_step_list,
222
+ scheduler=self.scheduler,
223
+ generator=self.generator,
224
+ num_frame_per_block=self.num_frame_per_block,
225
+ independent_first_frame=self.args.independent_first_frame,
226
+ same_step_across_blocks=self.args.same_step_across_blocks,
227
+ last_step_only=self.args.last_step_only,
228
+ num_max_frames=self.num_training_frames,
229
+ context_noise=self.args.context_noise
230
+ )
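
To make the block-wise noise schedule in `_get_timestep` concrete, here is a toy reproduction of its non-uniform branch with `independent_first_frame=False`; the numbers are illustrative only and this sketch is not part of the committed file:

```python
# Toy sketch of the per-block timestep sharing performed in _get_timestep
# (uniform_timestep=False, independent_first_frame=False).
import torch

batch_size, num_frame, num_frame_per_block = 1, 9, 3
timestep = torch.randint(0, 1000, [batch_size, num_frame])
timestep = timestep.reshape(batch_size, -1, num_frame_per_block)
timestep[:, :, 1:] = timestep[:, :, 0:1]   # all frames in a block share the block's first timestep
timestep = timestep.reshape(batch_size, -1)
# e.g. tensor([[712, 712, 712,  88,  88,  88, 430, 430, 430]])
```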
model/causvid.py ADDED
@@ -0,0 +1,391 @@
1
+ import torch.nn.functional as F
2
+ from typing import Tuple
3
+ import torch
4
+
5
+ from model.base import BaseModel
6
+
7
+
8
+ class CausVid(BaseModel):
9
+ def __init__(self, args, device):
10
+ """
11
+ Initialize the DMD (Distribution Matching Distillation) module.
12
+ This class is self-contained and computes the generator and fake-score losses
13
+ in the forward pass.
14
+ """
15
+ super().__init__(args, device)
16
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
17
+ self.num_training_frames = getattr(args, "num_training_frames", 21)
18
+
19
+ if self.num_frame_per_block > 1:
20
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
21
+
22
+ self.independent_first_frame = getattr(args, "independent_first_frame", False)
23
+ if self.independent_first_frame:
24
+ self.generator.model.independent_first_frame = True
25
+ if args.gradient_checkpointing:
26
+ self.generator.enable_gradient_checkpointing()
27
+ self.fake_score.enable_gradient_checkpointing()
28
+
29
+ # Step 2: Initialize all dmd hyperparameters
30
+ self.num_train_timestep = args.num_train_timestep
31
+ self.min_step = int(0.02 * self.num_train_timestep)
32
+ self.max_step = int(0.98 * self.num_train_timestep)
33
+ if hasattr(args, "real_guidance_scale"):
34
+ self.real_guidance_scale = args.real_guidance_scale
35
+ self.fake_guidance_scale = args.fake_guidance_scale
36
+ else:
37
+ self.real_guidance_scale = args.guidance_scale
38
+ self.fake_guidance_scale = 0.0
39
+ self.timestep_shift = getattr(args, "timestep_shift", 1.0)
40
+ self.teacher_forcing = getattr(args, "teacher_forcing", False)
41
+
42
+ if getattr(self.scheduler, "alphas_cumprod", None) is not None:
43
+ self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
44
+ else:
45
+ self.scheduler.alphas_cumprod = None
46
+
47
+ def _compute_kl_grad(
48
+ self, noisy_image_or_video: torch.Tensor,
49
+ estimated_clean_image_or_video: torch.Tensor,
50
+ timestep: torch.Tensor,
51
+ conditional_dict: dict, unconditional_dict: dict,
52
+ normalization: bool = True
53
+ ) -> Tuple[torch.Tensor, dict]:
54
+ """
55
+ Compute the KL grad (eq 7 in https://arxiv.org/abs/2311.18828).
56
+ Input:
57
+ - noisy_image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
58
+ - estimated_clean_image_or_video: a tensor with shape [B, F, C, H, W] representing the estimated clean image or video.
59
+ - timestep: a tensor with shape [B, F] containing the randomly generated timestep.
60
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
61
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
62
+ - normalization: a boolean indicating whether to normalize the gradient.
63
+ Output:
64
+ - kl_grad: a tensor representing the KL grad.
65
+ - kl_log_dict: a dictionary containing the intermediate tensors for logging.
66
+ """
67
+ # Step 1: Compute the fake score
68
+ _, pred_fake_image_cond = self.fake_score(
69
+ noisy_image_or_video=noisy_image_or_video,
70
+ conditional_dict=conditional_dict,
71
+ timestep=timestep
72
+ )
73
+
74
+ if self.fake_guidance_scale != 0.0:
75
+ _, pred_fake_image_uncond = self.fake_score(
76
+ noisy_image_or_video=noisy_image_or_video,
77
+ conditional_dict=unconditional_dict,
78
+ timestep=timestep
79
+ )
80
+ pred_fake_image = pred_fake_image_cond + (
81
+ pred_fake_image_cond - pred_fake_image_uncond
82
+ ) * self.fake_guidance_scale
83
+ else:
84
+ pred_fake_image = pred_fake_image_cond
85
+
86
+ # Step 2: Compute the real score
87
+ # We compute the conditional and unconditional prediction
88
+ # and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
89
+ _, pred_real_image_cond = self.real_score(
90
+ noisy_image_or_video=noisy_image_or_video,
91
+ conditional_dict=conditional_dict,
92
+ timestep=timestep
93
+ )
94
+
95
+ _, pred_real_image_uncond = self.real_score(
96
+ noisy_image_or_video=noisy_image_or_video,
97
+ conditional_dict=unconditional_dict,
98
+ timestep=timestep
99
+ )
100
+
101
+ pred_real_image = pred_real_image_cond + (
102
+ pred_real_image_cond - pred_real_image_uncond
103
+ ) * self.real_guidance_scale
104
+
105
+ # Step 3: Compute the DMD gradient (DMD paper eq. 7).
106
+ grad = (pred_fake_image - pred_real_image)
107
+
108
+ # TODO: Change the normalizer for causal teacher
109
+ if normalization:
110
+ # Step 4: Gradient normalization (DMD paper eq. 8).
111
+ p_real = (estimated_clean_image_or_video - pred_real_image)
112
+ normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
113
+ grad = grad / normalizer
114
+ grad = torch.nan_to_num(grad)
115
+
116
+ return grad, {
117
+ "dmdtrain_gradient_norm": torch.mean(torch.abs(grad)).detach(),
118
+ "timestep": timestep.detach()
119
+ }
120
+
121
+ def compute_distribution_matching_loss(
122
+ self,
123
+ image_or_video: torch.Tensor,
124
+ conditional_dict: dict,
125
+ unconditional_dict: dict,
126
+ gradient_mask: torch.Tensor = None,
127
+ ) -> Tuple[torch.Tensor, dict]:
128
+ """
129
+ Compute the DMD loss (eq 7 in https://arxiv.org/abs/2311.18828).
130
+ Input:
131
+ - image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
132
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
133
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
134
+ - gradient_mask: a boolean tensor with the same shape as image_or_video indicating which pixels to compute the loss on.
135
+ Output:
136
+ - dmd_loss: a scalar tensor representing the DMD loss.
137
+ - dmd_log_dict: a dictionary containing the intermediate tensors for logging.
138
+ """
139
+ original_latent = image_or_video
140
+
141
+ batch_size, num_frame = image_or_video.shape[:2]
142
+
143
+ with torch.no_grad():
144
+ # Step 1: Randomly sample timestep based on the given schedule and corresponding noise
145
+ timestep = self._get_timestep(
146
+ 0,
147
+ self.num_train_timestep,
148
+ batch_size,
149
+ num_frame,
150
+ self.num_frame_per_block,
151
+ uniform_timestep=True
152
+ )
153
+
154
+ if self.timestep_shift > 1:
155
+ timestep = self.timestep_shift * \
156
+ (timestep / 1000) / \
157
+ (1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
158
+ timestep = timestep.clamp(self.min_step, self.max_step)
159
+
160
+ noise = torch.randn_like(image_or_video)
161
+ noisy_latent = self.scheduler.add_noise(
162
+ image_or_video.flatten(0, 1),
163
+ noise.flatten(0, 1),
164
+ timestep.flatten(0, 1)
165
+ ).detach().unflatten(0, (batch_size, num_frame))
166
+
167
+ # Step 2: Compute the KL grad
168
+ grad, dmd_log_dict = self._compute_kl_grad(
169
+ noisy_image_or_video=noisy_latent,
170
+ estimated_clean_image_or_video=original_latent,
171
+ timestep=timestep,
172
+ conditional_dict=conditional_dict,
173
+ unconditional_dict=unconditional_dict
174
+ )
175
+
176
+ if gradient_mask is not None:
177
+ dmd_loss = 0.5 * F.mse_loss(original_latent.double(
178
+ )[gradient_mask], (original_latent.double() - grad.double()).detach()[gradient_mask], reduction="mean")
179
+ else:
180
+ dmd_loss = 0.5 * F.mse_loss(original_latent.double(
181
+ ), (original_latent.double() - grad.double()).detach(), reduction="mean")
182
+ return dmd_loss, dmd_log_dict
183
+
184
+ def _run_generator(
185
+ self,
186
+ image_or_video_shape,
187
+ conditional_dict: dict,
188
+ clean_latent: torch.tensor
189
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
190
+ """
191
+ Optionally simulate the generator's input from noise using backward simulation
192
+ and then run the generator for one-step.
193
+ Input:
194
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
195
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
196
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
197
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
198
+ - initial_latent: a tensor containing the initial latents [B, F, C, H, W].
199
+ Output:
200
+ - pred_image: a tensor with shape [B, F, C, H, W].
201
+ """
202
+ simulated_noisy_input = []
203
+ for timestep in self.denoising_step_list:
204
+ noise = torch.randn(
205
+ image_or_video_shape, device=self.device, dtype=self.dtype)
206
+
207
+ noisy_timestep = timestep * torch.ones(
208
+ image_or_video_shape[:2], device=self.device, dtype=torch.long)
209
+
210
+ if timestep != 0:
211
+ noisy_image = self.scheduler.add_noise(
212
+ clean_latent.flatten(0, 1),
213
+ noise.flatten(0, 1),
214
+ noisy_timestep.flatten(0, 1)
215
+ ).unflatten(0, image_or_video_shape[:2])
216
+ else:
217
+ noisy_image = clean_latent
218
+
219
+ simulated_noisy_input.append(noisy_image)
220
+
221
+ simulated_noisy_input = torch.stack(simulated_noisy_input, dim=1)
222
+
223
+ # Step 2: Randomly sample a timestep and pick the corresponding input
224
+ index = self._get_timestep(
225
+ 0,
226
+ len(self.denoising_step_list),
227
+ image_or_video_shape[0],
228
+ image_or_video_shape[1],
229
+ self.num_frame_per_block,
230
+ uniform_timestep=False
231
+ )
232
+
233
+ # select the corresponding timestep's noisy input from the stacked tensor [B, T, F, C, H, W]
234
+ noisy_input = torch.gather(
235
+ simulated_noisy_input, dim=1,
236
+ index=index.reshape(index.shape[0], 1, index.shape[1], 1, 1, 1).expand(
237
+ -1, -1, -1, *image_or_video_shape[2:]).to(self.device)
238
+ ).squeeze(1)
239
+
240
+ timestep = self.denoising_step_list[index].to(self.device)
241
+
242
+ _, pred_image_or_video = self.generator(
243
+ noisy_image_or_video=noisy_input,
244
+ conditional_dict=conditional_dict,
245
+ timestep=timestep,
246
+ clean_x=clean_latent if self.teacher_forcing else None,
247
+ )
248
+
249
+ gradient_mask = None # timestep != 0
250
+
251
+ pred_image_or_video = pred_image_or_video.type_as(noisy_input)
252
+
253
+ return pred_image_or_video, gradient_mask
254
+
255
+ def generator_loss(
256
+ self,
257
+ image_or_video_shape,
258
+ conditional_dict: dict,
259
+ unconditional_dict: dict,
260
+ clean_latent: torch.Tensor,
261
+ initial_latent: torch.Tensor = None
262
+ ) -> Tuple[torch.Tensor, dict]:
263
+ """
264
+ Generate image/videos from noise and compute the DMD loss.
265
+ The noisy input to the generator is backward simulated.
266
+ This removes the need of any datasets during distillation.
267
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
268
+ Input:
269
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
270
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
271
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
272
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
273
+ Output:
274
+ - loss: a scalar tensor representing the generator loss.
275
+ - generator_log_dict: a dictionary containing the intermediate tensors for logging.
276
+ """
277
+ # Step 1: Run generator on backward simulated noisy input
278
+ pred_image, gradient_mask = self._run_generator(
279
+ image_or_video_shape=image_or_video_shape,
280
+ conditional_dict=conditional_dict,
281
+ clean_latent=clean_latent
282
+ )
283
+
284
+ # Step 2: Compute the DMD loss
285
+ dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
286
+ image_or_video=pred_image,
287
+ conditional_dict=conditional_dict,
288
+ unconditional_dict=unconditional_dict,
289
+ gradient_mask=gradient_mask
290
+ )
291
+
292
+ # Step 3: TODO: Implement the GAN loss
293
+
294
+ return dmd_loss, dmd_log_dict
295
+
296
+ def critic_loss(
297
+ self,
298
+ image_or_video_shape,
299
+ conditional_dict: dict,
300
+ unconditional_dict: dict,
301
+ clean_latent: torch.Tensor,
302
+ initial_latent: torch.Tensor = None
303
+ ) -> Tuple[torch.Tensor, dict]:
304
+ """
305
+ Generate image/videos from noise and train the critic with generated samples.
306
+ The noisy input to the generator is backward simulated.
307
+ This removes the need of any datasets during distillation.
308
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
309
+ Input:
310
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
311
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
312
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
313
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. It needs to be passed when no backward simulation is used.
314
+ Output:
315
+ - loss: a scalar tensor representing the critic's denoising loss.
316
+ - critic_log_dict: a dictionary containing the intermediate tensors for logging.
317
+ """
318
+
319
+ # Step 1: Run generator on backward simulated noisy input
320
+ with torch.no_grad():
321
+ generated_image, _ = self._run_generator(
322
+ image_or_video_shape=image_or_video_shape,
323
+ conditional_dict=conditional_dict,
324
+ clean_latent=clean_latent
325
+ )
326
+
327
+ # Step 2: Compute the fake prediction
328
+ critic_timestep = self._get_timestep(
329
+ 0,
330
+ self.num_train_timestep,
331
+ image_or_video_shape[0],
332
+ image_or_video_shape[1],
333
+ self.num_frame_per_block,
334
+ uniform_timestep=True
335
+ )
336
+
337
+ if self.timestep_shift > 1:
338
+ critic_timestep = self.timestep_shift * \
339
+ (critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
340
+
341
+ critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
342
+
343
+ critic_noise = torch.randn_like(generated_image)
344
+ noisy_generated_image = self.scheduler.add_noise(
345
+ generated_image.flatten(0, 1),
346
+ critic_noise.flatten(0, 1),
347
+ critic_timestep.flatten(0, 1)
348
+ ).unflatten(0, image_or_video_shape[:2])
349
+
350
+ _, pred_fake_image = self.fake_score(
351
+ noisy_image_or_video=noisy_generated_image,
352
+ conditional_dict=conditional_dict,
353
+ timestep=critic_timestep
354
+ )
355
+
356
+ # Step 3: Compute the denoising loss for the fake critic
357
+ if self.args.denoising_loss_type == "flow":
358
+ from utils.wan_wrapper import WanDiffusionWrapper
359
+ flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
360
+ scheduler=self.scheduler,
361
+ x0_pred=pred_fake_image.flatten(0, 1),
362
+ xt=noisy_generated_image.flatten(0, 1),
363
+ timestep=critic_timestep.flatten(0, 1)
364
+ )
365
+ pred_fake_noise = None
366
+ else:
367
+ flow_pred = None
368
+ pred_fake_noise = self.scheduler.convert_x0_to_noise(
369
+ x0=pred_fake_image.flatten(0, 1),
370
+ xt=noisy_generated_image.flatten(0, 1),
371
+ timestep=critic_timestep.flatten(0, 1)
372
+ ).unflatten(0, image_or_video_shape[:2])
373
+
374
+ denoising_loss = self.denoising_loss_func(
375
+ x=generated_image.flatten(0, 1),
376
+ x_pred=pred_fake_image.flatten(0, 1),
377
+ noise=critic_noise.flatten(0, 1),
378
+ noise_pred=pred_fake_noise,
379
+ alphas_cumprod=self.scheduler.alphas_cumprod,
380
+ timestep=critic_timestep.flatten(0, 1),
381
+ flow_pred=flow_pred
382
+ )
383
+
384
+ # Step 4: TODO: Compute the GAN loss
385
+
386
+ # Step 5: Debugging Log
387
+ critic_log_dict = {
388
+ "critic_timestep": critic_timestep.detach()
389
+ }
390
+
391
+ return denoising_loss, critic_log_dict
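Note on the timestep shift used in `critic_loss` above: a minimal, standalone sketch of the same warp, where `shift`, `min_step`, and `max_step` are illustrative stand-ins for `self.timestep_shift`, `self.min_step`, and `self.max_step` (not values taken from the training configs).

```python
# Standalone sketch of the timestep-shift warp applied to critic_timestep above.
import torch

def shift_timestep(t: torch.Tensor, shift: float = 5.0,
                   min_step: int = 20, max_step: int = 980) -> torch.Tensor:
    """Warp timesteps in [0, 1000] toward the noisy end, then clamp."""
    if shift > 1:
        u = t / 1000.0
        t = shift * u / (1 + (shift - 1) * u) * 1000.0
    return t.clamp(min_step, max_step)

print(shift_timestep(torch.linspace(0, 1000, 6)))  # larger `shift` pushes samples toward t=1000
```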
model/diffusion.py ADDED
@@ -0,0 +1,125 @@
1
+ from typing import Tuple
2
+ import torch
3
+
4
+ from model.base import BaseModel
5
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
6
+
7
+
8
+ class CausalDiffusion(BaseModel):
9
+ def __init__(self, args, device):
10
+ """
11
+ Initialize the Diffusion loss module.
12
+ """
13
+ super().__init__(args, device)
14
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
15
+ if self.num_frame_per_block > 1:
16
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
17
+ self.independent_first_frame = getattr(args, "independent_first_frame", False)
18
+ if self.independent_first_frame:
19
+ self.generator.model.independent_first_frame = True
20
+
21
+ if args.gradient_checkpointing:
22
+ self.generator.enable_gradient_checkpointing()
23
+
24
+ # Step 2: Initialize all hyperparameters
25
+ self.num_train_timestep = args.num_train_timestep
26
+ self.min_step = int(0.02 * self.num_train_timestep)
27
+ self.max_step = int(0.98 * self.num_train_timestep)
28
+ self.guidance_scale = args.guidance_scale
29
+ self.timestep_shift = getattr(args, "timestep_shift", 1.0)
30
+ self.teacher_forcing = getattr(args, "teacher_forcing", False)
31
+ # Noise augmentation in teacher forcing, we add small noise to clean context latents
32
+ self.noise_augmentation_max_timestep = getattr(args, "noise_augmentation_max_timestep", 0)
33
+
34
+ def _initialize_models(self, args):
35
+ self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
36
+ self.generator.model.requires_grad_(True)
37
+
38
+ self.text_encoder = WanTextEncoder()
39
+ self.text_encoder.requires_grad_(False)
40
+
41
+ self.vae = WanVAEWrapper()
42
+ self.vae.requires_grad_(False)
43
+
44
+ def generator_loss(
45
+ self,
46
+ image_or_video_shape,
47
+ conditional_dict: dict,
48
+ unconditional_dict: dict,
49
+ clean_latent: torch.Tensor,
50
+ initial_latent: torch.Tensor = None
51
+ ) -> Tuple[torch.Tensor, dict]:
52
+ """
53
+ Add noise to the clean latents and compute the causal diffusion (flow-matching) loss.
54
+ Unlike the distillation objectives, no backward simulation is performed here;
55
+ the loss is supervised directly by the provided clean latents,
56
+ optionally with teacher forcing on noise-augmented context latents.
57
+ Input:
58
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
59
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
60
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
61
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. It must always be provided for this loss.
62
+ Output:
63
+ - loss: a scalar tensor representing the generator loss.
64
+ - generator_log_dict: a dictionary containing the intermediate tensors for logging.
65
+ """
66
+ noise = torch.randn_like(clean_latent)
67
+ batch_size, num_frame = image_or_video_shape[:2]
68
+
69
+ # Step 2: Randomly sample a timestep and add noise to denoiser inputs
70
+ index = self._get_timestep(
71
+ 0,
72
+ self.scheduler.num_train_timesteps,
73
+ image_or_video_shape[0],
74
+ image_or_video_shape[1],
75
+ self.num_frame_per_block,
76
+ uniform_timestep=False
77
+ )
78
+ timestep = self.scheduler.timesteps[index].to(dtype=self.dtype, device=self.device)
79
+ noisy_latents = self.scheduler.add_noise(
80
+ clean_latent.flatten(0, 1),
81
+ noise.flatten(0, 1),
82
+ timestep.flatten(0, 1)
83
+ ).unflatten(0, (batch_size, num_frame))
84
+ training_target = self.scheduler.training_target(clean_latent, noise, timestep)
85
+
86
+ # Step 3: Noise augmentation, also add small noise to clean context latents
87
+ if self.noise_augmentation_max_timestep > 0:
88
+ index_clean_aug = self._get_timestep(
89
+ 0,
90
+ self.noise_augmentation_max_timestep,
91
+ image_or_video_shape[0],
92
+ image_or_video_shape[1],
93
+ self.num_frame_per_block,
94
+ uniform_timestep=False
95
+ )
96
+ timestep_clean_aug = self.scheduler.timesteps[index_clean_aug].to(dtype=self.dtype, device=self.device)
97
+ clean_latent_aug = self.scheduler.add_noise(
98
+ clean_latent.flatten(0, 1),
99
+ noise.flatten(0, 1),
100
+ timestep_clean_aug.flatten(0, 1)
101
+ ).unflatten(0, (batch_size, num_frame))
102
+ else:
103
+ clean_latent_aug = clean_latent
104
+ timestep_clean_aug = None
105
+
106
+ # Compute loss
107
+ flow_pred, x0_pred = self.generator(
108
+ noisy_image_or_video=noisy_latents,
109
+ conditional_dict=conditional_dict,
110
+ timestep=timestep,
111
+ clean_x=clean_latent_aug if self.teacher_forcing else None,
112
+ aug_t=timestep_clean_aug if self.teacher_forcing else None
113
+ )
114
+ # loss = torch.nn.functional.mse_loss(flow_pred.float(), training_target.float())
115
+ loss = torch.nn.functional.mse_loss(
116
+ flow_pred.float(), training_target.float(), reduction='none'
117
+ ).mean(dim=(2, 3, 4))
118
+ loss = loss * self.scheduler.training_weight(timestep).unflatten(0, (batch_size, num_frame))
119
+ loss = loss.mean()
120
+
121
+ log_dict = {
122
+ "x0": clean_latent.detach(),
123
+ "x0_pred": x0_pred.detach()
124
+ }
125
+ return loss, log_dict
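The loss above is a per-frame MSE against `scheduler.training_target`, reweighted by `scheduler.training_weight`. The sketch below illustrates the same per-frame reduction under the assumption of a rectified-flow target `noise - x0`; the real target and weighting come from the Wan scheduler, which is not part of this diff.

```python
# Hedged sketch: per-frame flow-matching MSE with optional per-frame weighting,
# assuming a rectified-flow target (noise - x0). Shapes are illustrative only.
import torch

def flow_matching_loss(x0, noise, flow_pred, frame_weight=None):
    """x0, noise, flow_pred: [B, F, C, H, W]; frame_weight: optional [B, F]."""
    target = noise - x0                                                           # assumed velocity target
    per_frame = (flow_pred.float() - target.float()).pow(2).mean(dim=(2, 3, 4))   # [B, F]
    if frame_weight is not None:                                                  # stands in for training_weight(t)
        per_frame = per_frame * frame_weight
    return per_frame.mean()

x0 = torch.randn(1, 3, 4, 8, 8)
noise = torch.randn_like(x0)
print(float(flow_matching_loss(x0, noise, flow_pred=torch.randn_like(x0))))
```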
model/dmd.py ADDED
@@ -0,0 +1,332 @@
1
+ from pipeline import RollingForcingTrainingPipeline
2
+ import torch.nn.functional as F
3
+ from typing import Optional, Tuple
4
+ import torch
5
+
6
+ from model.base import RollingForcingModel
7
+
8
+
9
+ class DMD(RollingForcingModel):
10
+ def __init__(self, args, device):
11
+ """
12
+ Initialize the DMD (Distribution Matching Distillation) module.
13
+ This class is self-contained and computes generator and fake score losses
14
+ in the forward pass.
15
+ """
16
+ super().__init__(args, device)
17
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
18
+ self.same_step_across_blocks = getattr(args, "same_step_across_blocks", True)
19
+ self.num_training_frames = getattr(args, "num_training_frames", 21)
20
+
21
+ if self.num_frame_per_block > 1:
22
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
23
+
24
+ self.independent_first_frame = getattr(args, "independent_first_frame", False)
25
+ if self.independent_first_frame:
26
+ self.generator.model.independent_first_frame = True
27
+ if args.gradient_checkpointing:
28
+ self.generator.enable_gradient_checkpointing()
29
+ self.fake_score.enable_gradient_checkpointing()
30
+
31
+ # this will be init later with fsdp-wrapped modules
32
+ self.inference_pipeline: RollingForcingTrainingPipeline = None
33
+
34
+ # Step 2: Initialize all dmd hyperparameters
35
+ self.num_train_timestep = args.num_train_timestep
36
+ self.min_step = int(0.02 * self.num_train_timestep)
37
+ self.max_step = int(0.98 * self.num_train_timestep)
38
+ if hasattr(args, "real_guidance_scale"):
39
+ self.real_guidance_scale = args.real_guidance_scale
40
+ self.fake_guidance_scale = args.fake_guidance_scale
41
+ else:
42
+ self.real_guidance_scale = args.guidance_scale
43
+ self.fake_guidance_scale = 0.0
44
+ self.timestep_shift = getattr(args, "timestep_shift", 1.0)
45
+ self.ts_schedule = getattr(args, "ts_schedule", True)
46
+ self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
47
+ self.min_score_timestep = getattr(args, "min_score_timestep", 0)
48
+
49
+ if getattr(self.scheduler, "alphas_cumprod", None) is not None:
50
+ self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
51
+ else:
52
+ self.scheduler.alphas_cumprod = None
53
+
54
+ def _compute_kl_grad(
55
+ self, noisy_image_or_video: torch.Tensor,
56
+ estimated_clean_image_or_video: torch.Tensor,
57
+ timestep: torch.Tensor,
58
+ conditional_dict: dict, unconditional_dict: dict,
59
+ normalization: bool = True
60
+ ) -> Tuple[torch.Tensor, dict]:
61
+ """
62
+ Compute the KL grad (eq 7 in https://arxiv.org/abs/2311.18828).
63
+ Input:
64
+ - noisy_image_or_video: a tensor with shape [B, F, C, H, W] where the number of frames is 1 for images.
65
+ - estimated_clean_image_or_video: a tensor with shape [B, F, C, H, W] representing the estimated clean image or video.
66
+ - timestep: a tensor with shape [B, F] containing the randomly generated timestep.
67
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
68
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
69
+ - normalization: a boolean indicating whether to normalize the gradient.
70
+ Output:
71
+ - kl_grad: a tensor representing the KL grad.
72
+ - kl_log_dict: a dictionary containing the intermediate tensors for logging.
73
+ """
74
+ # Step 1: Compute the fake score
75
+ _, pred_fake_image_cond = self.fake_score(
76
+ noisy_image_or_video=noisy_image_or_video,
77
+ conditional_dict=conditional_dict,
78
+ timestep=timestep
79
+ )
80
+
81
+ if self.fake_guidance_scale != 0.0:
82
+ _, pred_fake_image_uncond = self.fake_score(
83
+ noisy_image_or_video=noisy_image_or_video,
84
+ conditional_dict=unconditional_dict,
85
+ timestep=timestep
86
+ )
87
+ pred_fake_image = pred_fake_image_cond + (
88
+ pred_fake_image_cond - pred_fake_image_uncond
89
+ ) * self.fake_guidance_scale
90
+ else:
91
+ pred_fake_image = pred_fake_image_cond
92
+
93
+ # Step 2: Compute the real score
94
+ # We compute the conditional and unconditional prediction
95
+ # and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
96
+ _, pred_real_image_cond = self.real_score(
97
+ noisy_image_or_video=noisy_image_or_video,
98
+ conditional_dict=conditional_dict,
99
+ timestep=timestep
100
+ )
101
+
102
+ _, pred_real_image_uncond = self.real_score(
103
+ noisy_image_or_video=noisy_image_or_video,
104
+ conditional_dict=unconditional_dict,
105
+ timestep=timestep
106
+ )
107
+
108
+ pred_real_image = pred_real_image_cond + (
109
+ pred_real_image_cond - pred_real_image_uncond
110
+ ) * self.real_guidance_scale
111
+
112
+ # Step 3: Compute the DMD gradient (DMD paper eq. 7).
113
+ grad = (pred_fake_image - pred_real_image)
114
+
115
+ # TODO: Change the normalizer for causal teacher
116
+ if normalization:
117
+ # Step 4: Gradient normalization (DMD paper eq. 8).
118
+ p_real = (estimated_clean_image_or_video - pred_real_image)
119
+ normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
120
+ grad = grad / normalizer
121
+ grad = torch.nan_to_num(grad)
122
+
123
+ return grad, {
124
+ "dmdtrain_gradient_norm": torch.mean(torch.abs(grad)).detach(),
125
+ "timestep": timestep.detach()
126
+ }
127
+
128
+ def compute_distribution_matching_loss(
129
+ self,
130
+ image_or_video: torch.Tensor,
131
+ conditional_dict: dict,
132
+ unconditional_dict: dict,
133
+ gradient_mask: Optional[torch.Tensor] = None,
134
+ denoised_timestep_from: int = 0,
135
+ denoised_timestep_to: int = 0
136
+ ) -> Tuple[torch.Tensor, dict]:
137
+ """
138
+ Compute the DMD loss (eq 7 in https://arxiv.org/abs/2311.18828).
139
+ Input:
140
+ - image_or_video: a tensor with shape [B, F, C, H, W] where the number of frames is 1 for images.
141
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
142
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
143
+ - gradient_mask: a boolean tensor with the same shape as image_or_video indicating which pixels to compute the loss on.
144
+ Output:
145
+ - dmd_loss: a scalar tensor representing the DMD loss.
146
+ - dmd_log_dict: a dictionary containing the intermediate tensors for logging.
147
+ """
148
+ original_latent = image_or_video
149
+
150
+ batch_size, num_frame = image_or_video.shape[:2]
151
+
152
+ with torch.no_grad():
153
+ # Step 1: Randomly sample timestep based on the given schedule and corresponding noise
154
+ min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
155
+ max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
156
+ timestep = self._get_timestep(
157
+ min_timestep,
158
+ max_timestep,
159
+ batch_size,
160
+ num_frame,
161
+ self.num_frame_per_block,
162
+ uniform_timestep=True
163
+ )
164
+
165
+ # TODO:should we change it to `timestep = self.scheduler.timesteps[timestep]`?
166
+ if self.timestep_shift > 1:
167
+ timestep = self.timestep_shift * \
168
+ (timestep / 1000) / \
169
+ (1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
170
+ timestep = timestep.clamp(self.min_step, self.max_step)
171
+
172
+ noise = torch.randn_like(image_or_video)
173
+ noisy_latent = self.scheduler.add_noise(
174
+ image_or_video.flatten(0, 1),
175
+ noise.flatten(0, 1),
176
+ timestep.flatten(0, 1)
177
+ ).detach().unflatten(0, (batch_size, num_frame))
178
+
179
+ # Step 2: Compute the KL grad
180
+ grad, dmd_log_dict = self._compute_kl_grad(
181
+ noisy_image_or_video=noisy_latent,
182
+ estimated_clean_image_or_video=original_latent,
183
+ timestep=timestep,
184
+ conditional_dict=conditional_dict,
185
+ unconditional_dict=unconditional_dict
186
+ )
187
+
188
+ if gradient_mask is not None:
189
+ dmd_loss = 0.5 * F.mse_loss(original_latent.double(
190
+ )[gradient_mask], (original_latent.double() - grad.double()).detach()[gradient_mask], reduction="mean")
191
+ else:
192
+ dmd_loss = 0.5 * F.mse_loss(original_latent.double(
193
+ ), (original_latent.double() - grad.double()).detach(), reduction="mean")
194
+ return dmd_loss, dmd_log_dict
195
+
196
+ def generator_loss(
197
+ self,
198
+ image_or_video_shape,
199
+ conditional_dict: dict,
200
+ unconditional_dict: dict,
201
+ clean_latent: torch.Tensor,
202
+ initial_latent: torch.Tensor = None
203
+ ) -> Tuple[torch.Tensor, dict]:
204
+ """
205
+ Generate image/videos from noise and compute the DMD loss.
206
+ The noisy input to the generator is backward simulated.
207
+ This removes the need of any datasets during distillation.
208
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
209
+ Input:
210
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
211
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
212
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
213
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. It needs to be passed when no backward simulation is used.
214
+ Output:
215
+ - loss: a scalar tensor representing the generator loss.
216
+ - generator_log_dict: a dictionary containing the intermediate tensors for logging.
217
+ """
218
+ # Step 1: Unroll generator to obtain fake videos
219
+ pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
220
+ image_or_video_shape=image_or_video_shape,
221
+ conditional_dict=conditional_dict,
222
+ initial_latent=initial_latent
223
+ )
224
+
225
+ # Step 2: Compute the DMD loss
226
+ dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
227
+ image_or_video=pred_image,
228
+ conditional_dict=conditional_dict,
229
+ unconditional_dict=unconditional_dict,
230
+ gradient_mask=gradient_mask,
231
+ denoised_timestep_from=denoised_timestep_from,
232
+ denoised_timestep_to=denoised_timestep_to
233
+ )
234
+
235
+ return dmd_loss, dmd_log_dict
236
+
237
+ def critic_loss(
238
+ self,
239
+ image_or_video_shape,
240
+ conditional_dict: dict,
241
+ unconditional_dict: dict,
242
+ clean_latent: torch.Tensor,
243
+ initial_latent: torch.Tensor = None
244
+ ) -> Tuple[torch.Tensor, dict]:
245
+ """
246
+ Generate image/videos from noise and train the critic with generated samples.
247
+ The noisy input to the generator is backward simulated.
248
+ This removes the need of any datasets during distillation.
249
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
250
+ Input:
251
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
252
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
253
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
254
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. It needs to be passed when no backward simulation is used.
255
+ Output:
256
+ - loss: a scalar tensor representing the critic's denoising loss.
257
+ - critic_log_dict: a dictionary containing the intermediate tensors for logging.
258
+ """
259
+
260
+ # Step 1: Run generator on backward simulated noisy input
261
+ with torch.no_grad():
262
+ generated_image, _, denoised_timestep_from, denoised_timestep_to = self._run_generator(
263
+ image_or_video_shape=image_or_video_shape,
264
+ conditional_dict=conditional_dict,
265
+ initial_latent=initial_latent
266
+ )
267
+
268
+ # Step 2: Compute the fake prediction
269
+ min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
270
+ max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
271
+ critic_timestep = self._get_timestep(
272
+ min_timestep,
273
+ max_timestep,
274
+ image_or_video_shape[0],
275
+ image_or_video_shape[1],
276
+ self.num_frame_per_block,
277
+ uniform_timestep=True
278
+ )
279
+
280
+ if self.timestep_shift > 1:
281
+ critic_timestep = self.timestep_shift * \
282
+ (critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
283
+
284
+ critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
285
+
286
+ critic_noise = torch.randn_like(generated_image)
287
+ noisy_generated_image = self.scheduler.add_noise(
288
+ generated_image.flatten(0, 1),
289
+ critic_noise.flatten(0, 1),
290
+ critic_timestep.flatten(0, 1)
291
+ ).unflatten(0, image_or_video_shape[:2])
292
+
293
+ _, pred_fake_image = self.fake_score(
294
+ noisy_image_or_video=noisy_generated_image,
295
+ conditional_dict=conditional_dict,
296
+ timestep=critic_timestep
297
+ )
298
+
299
+ # Step 3: Compute the denoising loss for the fake critic
300
+ if self.args.denoising_loss_type == "flow":
301
+ from utils.wan_wrapper import WanDiffusionWrapper
302
+ flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
303
+ scheduler=self.scheduler,
304
+ x0_pred=pred_fake_image.flatten(0, 1),
305
+ xt=noisy_generated_image.flatten(0, 1),
306
+ timestep=critic_timestep.flatten(0, 1)
307
+ )
308
+ pred_fake_noise = None
309
+ else:
310
+ flow_pred = None
311
+ pred_fake_noise = self.scheduler.convert_x0_to_noise(
312
+ x0=pred_fake_image.flatten(0, 1),
313
+ xt=noisy_generated_image.flatten(0, 1),
314
+ timestep=critic_timestep.flatten(0, 1)
315
+ ).unflatten(0, image_or_video_shape[:2])
316
+
317
+ denoising_loss = self.denoising_loss_func(
318
+ x=generated_image.flatten(0, 1),
319
+ x_pred=pred_fake_image.flatten(0, 1),
320
+ noise=critic_noise.flatten(0, 1),
321
+ noise_pred=pred_fake_noise,
322
+ alphas_cumprod=self.scheduler.alphas_cumprod,
323
+ timestep=critic_timestep.flatten(0, 1),
324
+ flow_pred=flow_pred
325
+ )
326
+
327
+ # Step 5: Debugging Log
328
+ critic_log_dict = {
329
+ "critic_timestep": critic_timestep.detach()
330
+ }
331
+
332
+ return denoising_loss, critic_log_dict
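The DMD objective in `compute_distribution_matching_loss` never backpropagates through the score networks: it rewrites the update as `0.5 * ||x - stopgrad(x - grad)||^2`, whose gradient with respect to `x` is exactly `grad`. A small self-contained check of that identity (using `reduction="sum"`; the file uses `"mean"`, which only rescales by the element count):

```python
# Verify that 0.5 * ||x - stopgrad(x - grad)||^2 delivers `grad` to x on backward.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, requires_grad=True)
grad = torch.randn(2, 3)  # stand-in for the normalized (fake_score - real_score) difference

loss = 0.5 * F.mse_loss(x, (x - grad).detach(), reduction="sum")
loss.backward()
print(torch.allclose(x.grad, grad))  # True
```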
model/gan.py ADDED
@@ -0,0 +1,295 @@
1
+ import copy
2
+ from pipeline import RollingForcingTrainingPipeline
3
+ import torch.nn.functional as F
4
+ from typing import Tuple
5
+ import torch
6
+
7
+ from model.base import RollingForcingModel
8
+
9
+
10
+ class GAN(RollingForcingModel):
11
+ def __init__(self, args, device):
12
+ """
13
+ Initialize the GAN module.
14
+ This class is self-contained and computes generator and fake score losses
15
+ in the forward pass.
16
+ """
17
+ super().__init__(args, device)
18
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
19
+ self.same_step_across_blocks = getattr(args, "same_step_across_blocks", True)
20
+ self.concat_time_embeddings = getattr(args, "concat_time_embeddings", False)
21
+ self.num_class = args.num_class
22
+ self.relativistic_discriminator = getattr(args, "relativistic_discriminator", False)
23
+
24
+ if self.num_frame_per_block > 1:
25
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
26
+
27
+ self.fake_score.adding_cls_branch(
28
+ atten_dim=1536, num_class=args.num_class, time_embed_dim=1536 if self.concat_time_embeddings else 0)
29
+ self.fake_score.model.requires_grad_(True)
30
+
31
+ self.independent_first_frame = getattr(args, "independent_first_frame", False)
32
+ if self.independent_first_frame:
33
+ self.generator.model.independent_first_frame = True
34
+ if args.gradient_checkpointing:
35
+ self.generator.enable_gradient_checkpointing()
36
+ self.fake_score.enable_gradient_checkpointing()
37
+
38
+ # this will be init later with fsdp-wrapped modules
39
+ self.inference_pipeline: RollingForcingTrainingPipeline = None
40
+
41
+ # Step 2: Initialize all dmd hyperparameters
42
+ self.num_train_timestep = args.num_train_timestep
43
+ self.min_step = int(0.02 * self.num_train_timestep)
44
+ self.max_step = int(0.98 * self.num_train_timestep)
45
+ if hasattr(args, "real_guidance_scale"):
46
+ self.real_guidance_scale = args.real_guidance_scale
47
+ self.fake_guidance_scale = args.fake_guidance_scale
48
+ else:
49
+ self.real_guidance_scale = args.guidance_scale
50
+ self.fake_guidance_scale = 0.0
51
+ self.timestep_shift = getattr(args, "timestep_shift", 1.0)
52
+ self.critic_timestep_shift = getattr(args, "critic_timestep_shift", self.timestep_shift)
53
+ self.ts_schedule = getattr(args, "ts_schedule", True)
54
+ self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
55
+ self.min_score_timestep = getattr(args, "min_score_timestep", 0)
56
+
57
+ self.gan_g_weight = getattr(args, "gan_g_weight", 1e-2)
58
+ self.gan_d_weight = getattr(args, "gan_d_weight", 1e-2)
59
+ self.r1_weight = getattr(args, "r1_weight", 0.0)
60
+ self.r2_weight = getattr(args, "r2_weight", 0.0)
61
+ self.r1_sigma = getattr(args, "r1_sigma", 0.01)
62
+ self.r2_sigma = getattr(args, "r2_sigma", 0.01)
63
+
64
+ if getattr(self.scheduler, "alphas_cumprod", None) is not None:
65
+ self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
66
+ else:
67
+ self.scheduler.alphas_cumprod = None
68
+
69
+ def _run_cls_pred_branch(self,
70
+ noisy_image_or_video: torch.Tensor,
71
+ conditional_dict: dict,
72
+ timestep: torch.Tensor) -> torch.Tensor:
73
+ """
74
+ Run the classifier prediction branch on the generated image or video.
75
+ Input:
76
+ - image_or_video: a tensor with shape [B, F, C, H, W].
77
+ Output:
78
+ - cls_pred: a tensor with shape [B, 1, 1, 1, 1] representing the feature map for classification.
79
+ """
80
+ _, _, noisy_logit = self.fake_score(
81
+ noisy_image_or_video=noisy_image_or_video,
82
+ conditional_dict=conditional_dict,
83
+ timestep=timestep,
84
+ classify_mode=True,
85
+ concat_time_embeddings=self.concat_time_embeddings
86
+ )
87
+
88
+ return noisy_logit
89
+
90
+ def generator_loss(
91
+ self,
92
+ image_or_video_shape,
93
+ conditional_dict: dict,
94
+ unconditional_dict: dict,
95
+ clean_latent: torch.Tensor,
96
+ initial_latent: torch.Tensor = None
97
+ ) -> Tuple[torch.Tensor, dict]:
98
+ """
99
+ Generate image/videos from noise and compute the GAN generator loss.
100
+ The noisy input to the generator is backward simulated.
101
+ This removes the need of any datasets during distillation.
102
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
103
+ Input:
104
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
105
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
106
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
107
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. It needs to be passed when no backward simulation is used.
108
+ Output:
109
+ - loss: a scalar tensor representing the generator loss.
110
+ - generator_log_dict: a dictionary containing the intermediate tensors for logging.
111
+ """
112
+ # Step 1: Unroll generator to obtain fake videos
113
+ pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
114
+ image_or_video_shape=image_or_video_shape,
115
+ conditional_dict=conditional_dict,
116
+ initial_latent=initial_latent
117
+ )
118
+
119
+ # Step 2: Get timestep and add noise to generated/real latents
120
+ min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
121
+ max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
122
+ critic_timestep = self._get_timestep(
123
+ min_timestep,
124
+ max_timestep,
125
+ image_or_video_shape[0],
126
+ image_or_video_shape[1],
127
+ self.num_frame_per_block,
128
+ uniform_timestep=True
129
+ )
130
+
131
+ if self.critic_timestep_shift > 1:
132
+ critic_timestep = self.critic_timestep_shift * \
133
+ (critic_timestep / 1000) / (1 + (self.critic_timestep_shift - 1) * (critic_timestep / 1000)) * 1000
134
+
135
+ critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
136
+
137
+ critic_noise = torch.randn_like(pred_image)
138
+ noisy_fake_latent = self.scheduler.add_noise(
139
+ pred_image.flatten(0, 1),
140
+ critic_noise.flatten(0, 1),
141
+ critic_timestep.flatten(0, 1)
142
+ ).unflatten(0, image_or_video_shape[:2])
143
+
144
+ # Step 4: Compute the real GAN discriminator loss
145
+ real_image_or_video = clean_latent.clone()
146
+ critic_noise = torch.randn_like(real_image_or_video)
147
+ noisy_real_latent = self.scheduler.add_noise(
148
+ real_image_or_video.flatten(0, 1),
149
+ critic_noise.flatten(0, 1),
150
+ critic_timestep.flatten(0, 1)
151
+ ).unflatten(0, image_or_video_shape[:2])
152
+
153
+ conditional_dict["prompt_embeds"] = torch.concatenate(
154
+ (conditional_dict["prompt_embeds"], conditional_dict["prompt_embeds"]), dim=0)
155
+ critic_timestep = torch.concatenate((critic_timestep, critic_timestep), dim=0)
156
+ noisy_latent = torch.concatenate((noisy_fake_latent, noisy_real_latent), dim=0)
157
+ _, _, noisy_logit = self.fake_score(
158
+ noisy_image_or_video=noisy_latent,
159
+ conditional_dict=conditional_dict,
160
+ timestep=critic_timestep,
161
+ classify_mode=True,
162
+ concat_time_embeddings=self.concat_time_embeddings
163
+ )
164
+ noisy_fake_logit, noisy_real_logit = noisy_logit.chunk(2, dim=0)
165
+
166
+ if not self.relativistic_discriminator:
167
+ gan_G_loss = F.softplus(-noisy_fake_logit.float()).mean() * self.gan_g_weight
168
+ else:
169
+ relative_fake_logit = noisy_fake_logit - noisy_real_logit
170
+ gan_G_loss = F.softplus(-relative_fake_logit.float()).mean() * self.gan_g_weight
171
+
172
+ return gan_G_loss
173
+
174
+ def critic_loss(
175
+ self,
176
+ image_or_video_shape,
177
+ conditional_dict: dict,
178
+ unconditional_dict: dict,
179
+ clean_latent: torch.Tensor,
180
+ real_image_or_video: torch.Tensor,
181
+ initial_latent: torch.Tensor = None
182
+ ) -> Tuple[torch.Tensor, dict]:
183
+ """
184
+ Generate image/videos from noise and train the critic with generated samples.
185
+ The noisy input to the generator is backward simulated.
186
+ This removes the need of any datasets during distillation.
187
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
188
+ Input:
189
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
190
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
191
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
192
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. It needs to be passed when no backward simulation is used.
193
+ Output:
194
+ - loss: a tuple of scalar tensors (GAN discriminator loss, R1 penalty, R2 penalty).
195
+ - critic_log_dict: a dictionary containing the intermediate tensors for logging.
196
+ """
197
+
198
+ # Step 1: Run generator on backward simulated noisy input
199
+ with torch.no_grad():
200
+ generated_image, _, denoised_timestep_from, denoised_timestep_to, num_sim_steps = self._run_generator(
201
+ image_or_video_shape=image_or_video_shape,
202
+ conditional_dict=conditional_dict,
203
+ initial_latent=initial_latent
204
+ )
205
+
206
+ # Step 2: Get timestep and add noise to generated/real latents
207
+ min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
208
+ max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
209
+ critic_timestep = self._get_timestep(
210
+ min_timestep,
211
+ max_timestep,
212
+ image_or_video_shape[0],
213
+ image_or_video_shape[1],
214
+ self.num_frame_per_block,
215
+ uniform_timestep=True
216
+ )
217
+
218
+ if self.critic_timestep_shift > 1:
219
+ critic_timestep = self.critic_timestep_shift * \
220
+ (critic_timestep / 1000) / (1 + (self.critic_timestep_shift - 1) * (critic_timestep / 1000)) * 1000
221
+
222
+ critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
223
+
224
+ critic_noise = torch.randn_like(generated_image)
225
+ noisy_fake_latent = self.scheduler.add_noise(
226
+ generated_image.flatten(0, 1),
227
+ critic_noise.flatten(0, 1),
228
+ critic_timestep.flatten(0, 1)
229
+ ).unflatten(0, image_or_video_shape[:2])
230
+
231
+ # Step 4: Compute the real GAN discriminator loss
232
+ noisy_real_latent = self.scheduler.add_noise(
233
+ real_image_or_video.flatten(0, 1),
234
+ critic_noise.flatten(0, 1),
235
+ critic_timestep.flatten(0, 1)
236
+ ).unflatten(0, image_or_video_shape[:2])
237
+
238
+ conditional_dict_cloned = copy.deepcopy(conditional_dict)
239
+ conditional_dict_cloned["prompt_embeds"] = torch.concatenate(
240
+ (conditional_dict_cloned["prompt_embeds"], conditional_dict_cloned["prompt_embeds"]), dim=0)
241
+ _, _, noisy_logit = self.fake_score(
242
+ noisy_image_or_video=torch.concatenate((noisy_fake_latent, noisy_real_latent), dim=0),
243
+ conditional_dict=conditional_dict_cloned,
244
+ timestep=torch.concatenate((critic_timestep, critic_timestep), dim=0),
245
+ classify_mode=True,
246
+ concat_time_embeddings=self.concat_time_embeddings
247
+ )
248
+ noisy_fake_logit, noisy_real_logit = noisy_logit.chunk(2, dim=0)
249
+
250
+ if not self.relativistic_discriminator:
251
+ gan_D_loss = F.softplus(-noisy_real_logit.float()).mean() + F.softplus(noisy_fake_logit.float()).mean()
252
+ else:
253
+ relative_real_logit = noisy_real_logit - noisy_fake_logit
254
+ gan_D_loss = F.softplus(-relative_real_logit.float()).mean()
255
+ gan_D_loss = gan_D_loss * self.gan_d_weight
256
+
257
+ # R1 regularization
258
+ if self.r1_weight > 0.:
259
+ noisy_real_latent_perturbed = noisy_real_latent.clone()
260
+ epsilon_real = self.r1_sigma * torch.randn_like(noisy_real_latent_perturbed)
261
+ noisy_real_latent_perturbed = noisy_real_latent_perturbed + epsilon_real
262
+ noisy_real_logit_perturbed = self._run_cls_pred_branch(
263
+ noisy_image_or_video=noisy_real_latent_perturbed,
264
+ conditional_dict=conditional_dict,
265
+ timestep=critic_timestep
266
+ )
267
+
268
+ r1_grad = (noisy_real_logit_perturbed - noisy_real_logit) / self.r1_sigma
269
+ r1_loss = self.r1_weight * torch.mean((r1_grad)**2)
270
+ else:
271
+ r1_loss = torch.zeros_like(gan_D_loss)
272
+
273
+ # R2 regularization
274
+ if self.r2_weight > 0.:
275
+ noisy_fake_latent_perturbed = noisy_fake_latent.clone()
276
+ epsilon_generated = self.r2_sigma * torch.randn_like(noisy_fake_latent_perturbed)
277
+ noisy_fake_latent_perturbed = noisy_fake_latent_perturbed + epsilon_generated
278
+ noisy_fake_logit_perturbed = self._run_cls_pred_branch(
279
+ noisy_image_or_video=noisy_fake_latent_perturbed,
280
+ conditional_dict=conditional_dict,
281
+ timestep=critic_timestep
282
+ )
283
+
284
+ r2_grad = (noisy_fake_logit_perturbed - noisy_fake_logit) / self.r2_sigma
285
+ r2_loss = self.r2_weight * torch.mean((r2_grad)**2)
286
+ else:
287
+ r2_loss = torch.zeros_like(gan_D_loss)
288
+
289
+ critic_log_dict = {
290
+ "critic_timestep": critic_timestep.detach(),
291
+ 'noisy_real_logit': noisy_real_logit.detach(),
292
+ 'noisy_fake_logit': noisy_fake_logit.detach(),
293
+ }
294
+
295
+ return (gan_D_loss, r1_loss, r2_loss), critic_log_dict
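For reference, the discriminator above uses non-saturating softplus losses, with an optional relativistic pairing of real and fake logits; the `gan_g_weight` / `gan_d_weight` scaling and the finite-difference R1/R2 penalties are omitted in this illustrative sketch.

```python
# Illustrative softplus GAN losses matching the branches in GAN.generator_loss / critic_loss.
import torch
import torch.nn.functional as F

def gan_losses(fake_logit, real_logit, relativistic=False):
    if relativistic:
        g_loss = F.softplus(-(fake_logit - real_logit)).mean()
        d_loss = F.softplus(-(real_logit - fake_logit)).mean()
    else:
        g_loss = F.softplus(-fake_logit).mean()
        d_loss = F.softplus(-real_logit).mean() + F.softplus(fake_logit).mean()
    return g_loss, d_loss

fake, real = torch.randn(4, 1), torch.randn(4, 1)
print(gan_losses(fake, real))
print(gan_losses(fake, real, relativistic=True))
```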
model/ode_regression.py ADDED
@@ -0,0 +1,138 @@
1
+ import torch.nn.functional as F
2
+ from typing import Tuple
3
+ import torch
4
+
5
+ from model.base import BaseModel
6
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
7
+
8
+
9
+ class ODERegression(BaseModel):
10
+ def __init__(self, args, device):
11
+ """
12
+ Initialize the ODERegression module.
13
+ This class is self-contained and computes generator losses
14
+ in the forward pass given precomputed ode solution pairs.
15
+ This class supports the ode regression loss for both causal and bidirectional models.
16
+ See Sec 4.3 of CausVid https://arxiv.org/abs/2412.07772 for details
17
+ """
18
+ super().__init__(args, device)
19
+
20
+ # Step 1: Initialize all models
21
+
22
+ self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
23
+ self.generator.model.requires_grad_(True)
24
+ if getattr(args, "generator_ckpt", False):
25
+ print(f"Loading pretrained generator from {args.generator_ckpt}")
26
+ state_dict = torch.load(args.generator_ckpt, map_location="cpu")[
27
+ 'generator']
28
+ self.generator.load_state_dict(
29
+ state_dict, strict=True
30
+ )
31
+
32
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
33
+
34
+ if self.num_frame_per_block > 1:
35
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
36
+
37
+ self.independent_first_frame = getattr(args, "independent_first_frame", False)
38
+ if self.independent_first_frame:
39
+ self.generator.model.independent_first_frame = True
40
+ if args.gradient_checkpointing:
41
+ self.generator.enable_gradient_checkpointing()
42
+
43
+ # Step 2: Initialize all hyperparameters
44
+ self.timestep_shift = getattr(args, "timestep_shift", 1.0)
45
+
46
+ def _initialize_models(self, args):
47
+ self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
48
+ self.generator.model.requires_grad_(True)
49
+
50
+ self.text_encoder = WanTextEncoder()
51
+ self.text_encoder.requires_grad_(False)
52
+
53
+ self.vae = WanVAEWrapper()
54
+ self.vae.requires_grad_(False)
55
+
56
+ @torch.no_grad()
57
+ def _prepare_generator_input(self, ode_latent: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
58
+ """
59
+ Given a tensor containing the whole ODE sampling trajectories,
60
+ randomly choose an intermediate timestep and return the latent as well as the corresponding timestep.
61
+ Input:
62
+ - ode_latent: a tensor containing the whole ODE sampling trajectories [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
63
+ Output:
64
+ - noisy_input: a tensor containing the selected latent [batch_size, num_frames, num_channels, height, width].
65
+ - timestep: a tensor containing the corresponding timestep [batch_size].
66
+ """
67
+ batch_size, num_denoising_steps, num_frames, num_channels, height, width = ode_latent.shape
68
+
69
+ # Step 1: Randomly choose a timestep for each frame
70
+ index = self._get_timestep(
71
+ 0,
72
+ len(self.denoising_step_list),
73
+ batch_size,
74
+ num_frames,
75
+ self.num_frame_per_block,
76
+ uniform_timestep=False
77
+ )
78
+ if self.args.i2v:
79
+ index[:, 0] = len(self.denoising_step_list) - 1
80
+
81
+ noisy_input = torch.gather(
82
+ ode_latent, dim=1,
83
+ index=index.reshape(batch_size, 1, num_frames, 1, 1, 1).expand(
84
+ -1, -1, -1, num_channels, height, width).to(self.device)
85
+ ).squeeze(1)
86
+
87
+ timestep = self.denoising_step_list[index].to(self.device)
88
+
89
+ # if self.extra_noise_step > 0:
90
+ # random_timestep = torch.randint(0, self.extra_noise_step, [
91
+ # batch_size, num_frames], device=self.device, dtype=torch.long)
92
+ # perturbed_noisy_input = self.scheduler.add_noise(
93
+ # noisy_input.flatten(0, 1),
94
+ # torch.randn_like(noisy_input.flatten(0, 1)),
95
+ # random_timestep.flatten(0, 1)
96
+ # ).detach().unflatten(0, (batch_size, num_frames)).type_as(noisy_input)
97
+
98
+ # noisy_input[timestep == 0] = perturbed_noisy_input[timestep == 0]
99
+
100
+ return noisy_input, timestep
101
+
102
+ def generator_loss(self, ode_latent: torch.Tensor, conditional_dict: dict) -> Tuple[torch.Tensor, dict]:
103
+ """
104
+ Generate image/videos from noisy latents and compute the ODE regression loss.
105
+ Input:
106
+ - ode_latent: a tensor containing the ODE latents [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
107
+ They are ordered from most noisy to clean latents.
108
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
109
+ Output:
110
+ - loss: a scalar tensor representing the generator loss.
111
+ - log_dict: a dictionary containing additional information for loss timestep breakdown.
112
+ """
113
+ # Step 1: Run generator on noisy latents
114
+ target_latent = ode_latent[:, -1]
115
+
116
+ noisy_input, timestep = self._prepare_generator_input(
117
+ ode_latent=ode_latent)
118
+
119
+ _, pred_image_or_video = self.generator(
120
+ noisy_image_or_video=noisy_input,
121
+ conditional_dict=conditional_dict,
122
+ timestep=timestep
123
+ )
124
+
125
+ # Step 2: Compute the regression loss
126
+ mask = timestep != 0
127
+
128
+ loss = F.mse_loss(
129
+ pred_image_or_video[mask], target_latent[mask], reduction="mean")
130
+
131
+ log_dict = {
132
+ "unnormalized_loss": F.mse_loss(pred_image_or_video, target_latent, reduction='none').mean(dim=[1, 2, 3, 4]).detach(),
133
+ "timestep": timestep.float().mean(dim=1).detach(),
134
+ "input": noisy_input.detach(),
135
+ "output": pred_image_or_video.detach(),
136
+ }
137
+
138
+ return loss, log_dict
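`_prepare_generator_input` selects, for every frame, one latent out of the stored ODE trajectory via `torch.gather`. A minimal shape-level sketch of that gather pattern (sizes are illustrative):

```python
# Pick one denoising step per frame from an ODE trajectory [B, S, F, C, H, W].
import torch

B, S, NF, C, H, W = 2, 4, 3, 2, 4, 4
ode_latent = torch.randn(B, S, NF, C, H, W)   # stored trajectory, noisy -> clean along S
index = torch.randint(0, S, (B, NF))          # per-frame denoising step index

noisy_input = torch.gather(
    ode_latent, dim=1,
    index=index.reshape(B, 1, NF, 1, 1, 1).expand(-1, -1, -1, C, H, W)
).squeeze(1)                                  # -> [B, NF, C, H, W]
print(noisy_input.shape)
```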
model/sid.py ADDED
@@ -0,0 +1,283 @@
1
+ from pipeline import RollingForcingTrainingPipeline
2
+ from typing import Optional, Tuple
3
+ import torch
4
+
5
+ from model.base import RollingForcingModel
6
+
7
+
8
+ class SiD(RollingForcingModel):
9
+ def __init__(self, args, device):
10
+ """
11
+ Initialize the SiD (Score identity Distillation) module.
12
+ This class is self-contained and computes generator and fake score losses
13
+ in the forward pass.
14
+ """
15
+ super().__init__(args, device)
16
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
17
+
18
+ if self.num_frame_per_block > 1:
19
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
20
+
21
+ if args.gradient_checkpointing:
22
+ self.generator.enable_gradient_checkpointing()
23
+ self.fake_score.enable_gradient_checkpointing()
24
+ self.real_score.enable_gradient_checkpointing()
25
+
26
+ # this will be init later with fsdp-wrapped modules
27
+ self.inference_pipeline: RollingForcingTrainingPipeline = None
28
+
29
+ # Step 2: Initialize all dmd hyperparameters
30
+ self.num_train_timestep = args.num_train_timestep
31
+ self.min_step = int(0.02 * self.num_train_timestep)
32
+ self.max_step = int(0.98 * self.num_train_timestep)
33
+ if hasattr(args, "real_guidance_scale"):
34
+ self.real_guidance_scale = args.real_guidance_scale
35
+ else:
36
+ self.real_guidance_scale = args.guidance_scale
37
+ self.timestep_shift = getattr(args, "timestep_shift", 1.0)
38
+ self.sid_alpha = getattr(args, "sid_alpha", 1.0)
39
+ self.ts_schedule = getattr(args, "ts_schedule", True)
40
+ self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
41
+
42
+ if getattr(self.scheduler, "alphas_cumprod", None) is not None:
43
+ self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
44
+ else:
45
+ self.scheduler.alphas_cumprod = None
46
+
47
+ def compute_distribution_matching_loss(
48
+ self,
49
+ image_or_video: torch.Tensor,
50
+ conditional_dict: dict,
51
+ unconditional_dict: dict,
52
+ gradient_mask: Optional[torch.Tensor] = None,
53
+ denoised_timestep_from: int = 0,
54
+ denoised_timestep_to: int = 0
55
+ ) -> Tuple[torch.Tensor, dict]:
56
+ """
57
+ Compute the SiD variant of the distribution matching loss (cf. eq 7 in https://arxiv.org/abs/2311.18828).
58
+ Input:
59
+ - image_or_video: a tensor with shape [B, F, C, H, W] where the number of frames is 1 for images.
60
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
61
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
62
+ - gradient_mask: a boolean tensor with the same shape as image_or_video indicating which pixels to compute the loss on.
63
+ Output:
64
+ - dmd_loss: a scalar tensor representing the DMD loss.
65
+ - dmd_log_dict: a dictionary containing the intermediate tensors for logging.
66
+ """
67
+ original_latent = image_or_video
68
+
69
+ batch_size, num_frame = image_or_video.shape[:2]
70
+
71
+ # Step 1: Randomly sample timestep based on the given schedule and corresponding noise
72
+ min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
73
+ max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
74
+ timestep = self._get_timestep(
75
+ min_timestep,
76
+ max_timestep,
77
+ batch_size,
78
+ num_frame,
79
+ self.num_frame_per_block,
80
+ uniform_timestep=True
81
+ )
82
+
83
+ if self.timestep_shift > 1:
84
+ timestep = self.timestep_shift * \
85
+ (timestep / 1000) / \
86
+ (1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
87
+ timestep = timestep.clamp(self.min_step, self.max_step)
88
+
89
+ noise = torch.randn_like(image_or_video)
90
+ noisy_latent = self.scheduler.add_noise(
91
+ image_or_video.flatten(0, 1),
92
+ noise.flatten(0, 1),
93
+ timestep.flatten(0, 1)
94
+ ).unflatten(0, (batch_size, num_frame))
95
+
96
+ # Step 2: SiD (May be wrap it?)
97
+ noisy_image_or_video = noisy_latent
98
+ # Step 2.1: Compute the fake score
99
+ _, pred_fake_image = self.fake_score(
100
+ noisy_image_or_video=noisy_image_or_video,
101
+ conditional_dict=conditional_dict,
102
+ timestep=timestep
103
+ )
104
+ # Step 2.2: Compute the real score
105
+ # We compute the conditional and unconditional prediction
106
+ # and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
107
+ # NOTE: This step may cause OOM issue, which can be addressed by the CFG-free technique
108
+
109
+ _, pred_real_image_cond = self.real_score(
110
+ noisy_image_or_video=noisy_image_or_video,
111
+ conditional_dict=conditional_dict,
112
+ timestep=timestep
113
+ )
114
+
115
+ _, pred_real_image_uncond = self.real_score(
116
+ noisy_image_or_video=noisy_image_or_video,
117
+ conditional_dict=unconditional_dict,
118
+ timestep=timestep
119
+ )
120
+
121
+ pred_real_image = pred_real_image_cond + (
122
+ pred_real_image_cond - pred_real_image_uncond
123
+ ) * self.real_guidance_scale
124
+
125
+ # Step 2.3: SiD Loss
126
+ # TODO: Add alpha
127
+ # TODO: Double?
128
+ sid_loss = (pred_real_image.double() - pred_fake_image.double()) * ((pred_real_image.double() - original_latent.double()) - self.sid_alpha * (pred_real_image.double() - pred_fake_image.double()))
129
+
130
+ # Step 2.4: Loss normalizer
131
+ with torch.no_grad():
132
+ p_real = (original_latent - pred_real_image)
133
+ normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
134
+ sid_loss = sid_loss / normalizer
135
+
136
+ sid_loss = torch.nan_to_num(sid_loss)
137
+ num_frame = sid_loss.shape[1]
138
+ sid_loss = sid_loss.mean()
139
+
140
+ sid_log_dict = {
141
+ "dmdtrain_gradient_norm": torch.zeros_like(sid_loss),
142
+ "timestep": timestep.detach()
143
+ }
144
+
145
+ return sid_loss, sid_log_dict
146
+
147
+ def generator_loss(
148
+ self,
149
+ image_or_video_shape,
150
+ conditional_dict: dict,
151
+ unconditional_dict: dict,
152
+ clean_latent: torch.Tensor,
153
+ initial_latent: torch.Tensor = None
154
+ ) -> Tuple[torch.Tensor, dict]:
155
+ """
156
+ Generate image/videos from noise and compute the DMD loss.
157
+ The noisy input to the generator is backward simulated.
158
+ This removes the need of any datasets during distillation.
159
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
160
+ Input:
161
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
162
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
163
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
164
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. It needs to be passed when no backward simulation is used.
165
+ Output:
166
+ - loss: a scalar tensor representing the generator loss.
167
+ - generator_log_dict: a dictionary containing the intermediate tensors for logging.
168
+ """
169
+ # Step 1: Unroll generator to obtain fake videos
170
+ pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
171
+ image_or_video_shape=image_or_video_shape,
172
+ conditional_dict=conditional_dict,
173
+ initial_latent=initial_latent
174
+ )
175
+
176
+ # Step 2: Compute the DMD loss
177
+ dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
178
+ image_or_video=pred_image,
179
+ conditional_dict=conditional_dict,
180
+ unconditional_dict=unconditional_dict,
181
+ gradient_mask=gradient_mask,
182
+ denoised_timestep_from=denoised_timestep_from,
183
+ denoised_timestep_to=denoised_timestep_to
184
+ )
185
+
186
+ return dmd_loss, dmd_log_dict
187
+
188
+ def critic_loss(
189
+ self,
190
+ image_or_video_shape,
191
+ conditional_dict: dict,
192
+ unconditional_dict: dict,
193
+ clean_latent: torch.Tensor,
194
+ initial_latent: torch.Tensor = None
195
+ ) -> Tuple[torch.Tensor, dict]:
196
+ """
197
+ Generate image/videos from noise and train the critic with generated samples.
198
+ The noisy input to the generator is backward simulated.
199
+ This removes the need of any datasets during distillation.
200
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
201
+ Input:
202
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
203
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
204
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
205
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. It needs to be passed when no backward simulation is used.
206
+ Output:
207
+ - loss: a scalar tensor representing the critic's denoising loss.
208
+ - critic_log_dict: a dictionary containing the intermediate tensors for logging.
209
+ """
210
+
211
+ # Step 1: Run generator on backward simulated noisy input
212
+ with torch.no_grad():
213
+ generated_image, _, denoised_timestep_from, denoised_timestep_to = self._run_generator(
214
+ image_or_video_shape=image_or_video_shape,
215
+ conditional_dict=conditional_dict,
216
+ initial_latent=initial_latent
217
+ )
218
+
219
+ # Step 2: Compute the fake prediction
220
+ min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
221
+ max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
222
+ critic_timestep = self._get_timestep(
223
+ min_timestep,
224
+ max_timestep,
225
+ image_or_video_shape[0],
226
+ image_or_video_shape[1],
227
+ self.num_frame_per_block,
228
+ uniform_timestep=True
229
+ )
230
+
231
+ if self.timestep_shift > 1:
232
+ critic_timestep = self.timestep_shift * \
233
+ (critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
234
+
235
+ critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
236
+
237
+ critic_noise = torch.randn_like(generated_image)
238
+ noisy_generated_image = self.scheduler.add_noise(
239
+ generated_image.flatten(0, 1),
240
+ critic_noise.flatten(0, 1),
241
+ critic_timestep.flatten(0, 1)
242
+ ).unflatten(0, image_or_video_shape[:2])
243
+
244
+ _, pred_fake_image = self.fake_score(
245
+ noisy_image_or_video=noisy_generated_image,
246
+ conditional_dict=conditional_dict,
247
+ timestep=critic_timestep
248
+ )
249
+
250
+ # Step 3: Compute the denoising loss for the fake critic
251
+ if self.args.denoising_loss_type == "flow":
252
+ from utils.wan_wrapper import WanDiffusionWrapper
253
+ flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
254
+ scheduler=self.scheduler,
255
+ x0_pred=pred_fake_image.flatten(0, 1),
256
+ xt=noisy_generated_image.flatten(0, 1),
257
+ timestep=critic_timestep.flatten(0, 1)
258
+ )
259
+ pred_fake_noise = None
260
+ else:
261
+ flow_pred = None
262
+ pred_fake_noise = self.scheduler.convert_x0_to_noise(
263
+ x0=pred_fake_image.flatten(0, 1),
264
+ xt=noisy_generated_image.flatten(0, 1),
265
+ timestep=critic_timestep.flatten(0, 1)
266
+ ).unflatten(0, image_or_video_shape[:2])
267
+
268
+ denoising_loss = self.denoising_loss_func(
269
+ x=generated_image.flatten(0, 1),
270
+ x_pred=pred_fake_image.flatten(0, 1),
271
+ noise=critic_noise.flatten(0, 1),
272
+ noise_pred=pred_fake_noise,
273
+ alphas_cumprod=self.scheduler.alphas_cumprod,
274
+ timestep=critic_timestep.flatten(0, 1),
275
+ flow_pred=flow_pred
276
+ )
277
+
278
+ # Step 5: Debugging Log
279
+ critic_log_dict = {
280
+ "critic_timestep": critic_timestep.detach()
281
+ }
282
+
283
+ return denoising_loss, critic_log_dict
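The SiD objective assembled in `compute_distribution_matching_loss` above can be summarized as `(f_real - f_fake) * ((f_real - x) - alpha * (f_real - f_fake))`, normalized by the mean absolute residual `|x - f_real|`. A compact sketch with random stand-in tensors:

```python
# Hedged sketch of the SiD loss; x is the generator output, f_real / f_fake are the
# real / fake score networks' x0-predictions at the sampled timestep.
import torch

def sid_loss(x, f_real, f_fake, alpha=1.0):
    diff = f_real.double() - f_fake.double()
    loss = diff * ((f_real.double() - x.double()) - alpha * diff)
    normalizer = (x - f_real).abs().mean(dim=[1, 2, 3, 4], keepdim=True)
    return torch.nan_to_num(loss / normalizer).mean()

x, f_real, f_fake = (torch.randn(1, 3, 4, 8, 8) for _ in range(3))
print(float(sid_loss(x, f_real, f_fake)))
```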
pipeline/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .bidirectional_diffusion_inference import BidirectionalDiffusionInferencePipeline
2
+ from .bidirectional_inference import BidirectionalInferencePipeline
3
+ from .causal_diffusion_inference import CausalDiffusionInferencePipeline
4
+ from .rolling_forcing_inference import CausalInferencePipeline
5
+ from .rolling_forcing_training import RollingForcingTrainingPipeline
6
+
7
+ __all__ = [
8
+ "BidirectionalDiffusionInferencePipeline",
9
+ "BidirectionalInferencePipeline",
10
+ "CausalDiffusionInferencePipeline",
11
+ "CausalInferencePipeline",
12
+ "RollingForcingTrainingPipeline"
13
+ ]
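A hypothetical usage sketch of the exports above; `args` is assumed to be a config namespace carrying the fields the pipelines read later in this diff (e.g. `num_train_timestep`, `guidance_scale`, `negative_prompt`).

```python
# Hypothetical wiring only; model weights and the full config are not part of this diff.
import torch
from pipeline import BidirectionalDiffusionInferencePipeline

def build_inference_pipeline(args, device: str = "cuda"):
    pipe = BidirectionalDiffusionInferencePipeline(args, device=device)
    return pipe.to(device)
```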
pipeline/bidirectional_diffusion_inference.py ADDED
@@ -0,0 +1,110 @@
1
+ from tqdm import tqdm
2
+ from typing import List
3
+ import torch
4
+
5
+ from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
6
+ from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
7
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
8
+
9
+
10
+ class BidirectionalDiffusionInferencePipeline(torch.nn.Module):
11
+ def __init__(
12
+ self,
13
+ args,
14
+ device,
15
+ generator=None,
16
+ text_encoder=None,
17
+ vae=None
18
+ ):
19
+ super().__init__()
20
+ # Step 1: Initialize all models
21
+ self.generator = WanDiffusionWrapper(
22
+ **getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
23
+ self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
24
+ self.vae = WanVAEWrapper() if vae is None else vae
25
+
26
+ # Step 2: Initialize scheduler
27
+ self.num_train_timesteps = args.num_train_timestep
28
+ self.sampling_steps = 50
29
+ self.sample_solver = 'unipc'
30
+ self.shift = 8.0
31
+
32
+ self.args = args
33
+
34
+ def inference(
35
+ self,
36
+ noise: torch.Tensor,
37
+ text_prompts: List[str],
38
+ return_latents=False
39
+ ) -> torch.Tensor:
40
+ """
41
+ Perform inference on the given noise and text prompts.
42
+ Inputs:
43
+ noise (torch.Tensor): The input noise tensor of shape
44
+ (batch_size, num_frames, num_channels, height, width).
45
+ text_prompts (List[str]): The list of text prompts.
46
+ Outputs:
47
+ video (torch.Tensor): The generated video tensor of shape
48
+ (batch_size, num_frames, num_channels, height, width). It is normalized to be in the range [0, 1].
49
+ """
50
+
51
+ conditional_dict = self.text_encoder(
52
+ text_prompts=text_prompts
53
+ )
54
+ unconditional_dict = self.text_encoder(
55
+ text_prompts=[self.args.negative_prompt] * len(text_prompts)
56
+ )
57
+
58
+ latents = noise
59
+
60
+ sample_scheduler = self._initialize_sample_scheduler(noise)
61
+ for _, t in enumerate(tqdm(sample_scheduler.timesteps)):
62
+ latent_model_input = latents
63
+ timestep = t * torch.ones([latents.shape[0], 21], device=noise.device, dtype=torch.float32)
64
+
65
+ flow_pred_cond, _ = self.generator(latent_model_input, conditional_dict, timestep)
66
+ flow_pred_uncond, _ = self.generator(latent_model_input, unconditional_dict, timestep)
67
+
68
+ flow_pred = flow_pred_uncond + self.args.guidance_scale * (
69
+ flow_pred_cond - flow_pred_uncond)
70
+
71
+ temp_x0 = sample_scheduler.step(
72
+ flow_pred.unsqueeze(0),
73
+ t,
74
+ latents.unsqueeze(0),
75
+ return_dict=False)[0]
76
+ latents = temp_x0.squeeze(0)
77
+
78
+ x0 = latents
79
+ video = self.vae.decode_to_pixel(x0)
80
+ video = (video * 0.5 + 0.5).clamp(0, 1)
81
+
82
+ del sample_scheduler
83
+
84
+ if return_latents:
85
+ return video, latents
86
+ else:
87
+ return video
88
+
89
+ def _initialize_sample_scheduler(self, noise):
90
+ if self.sample_solver == 'unipc':
91
+ sample_scheduler = FlowUniPCMultistepScheduler(
92
+ num_train_timesteps=self.num_train_timesteps,
93
+ shift=1,
94
+ use_dynamic_shifting=False)
95
+ sample_scheduler.set_timesteps(
96
+ self.sampling_steps, device=noise.device, shift=self.shift)
97
+ self.timesteps = sample_scheduler.timesteps
98
+ elif self.sample_solver == 'dpm++':
99
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
100
+ num_train_timesteps=self.num_train_timesteps,
101
+ shift=1,
102
+ use_dynamic_shifting=False)
103
+ sampling_sigmas = get_sampling_sigmas(self.sampling_steps, self.shift)
104
+ self.timesteps, _ = retrieve_timesteps(
105
+ sample_scheduler,
106
+ device=noise.device,
107
+ sigmas=sampling_sigmas)
108
+ else:
109
+ raise NotImplementedError("Unsupported solver.")
110
+ return sample_scheduler
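
For reference, the classifier-free guidance combination used in the sampling loop above reduces to a single linear extrapolation from the unconditional prediction toward the conditional one. A minimal sketch (not part of the commit); the tensor shape and guidance scale are illustrative assumptions.

import torch

def apply_cfg(flow_pred_cond, flow_pred_uncond, guidance_scale):
    # Extrapolate from the unconditional prediction toward the conditional one.
    return flow_pred_uncond + guidance_scale * (flow_pred_cond - flow_pred_uncond)

cond = torch.randn(1, 21, 16, 8, 8)      # (batch, frames, channels, height, width), toy latent shape
uncond = torch.randn_like(cond)
print(apply_cfg(cond, uncond, guidance_scale=5.0).shape)
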
pipeline/bidirectional_inference.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import torch
3
+
4
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
5
+
6
+
7
+ class BidirectionalInferencePipeline(torch.nn.Module):
8
+ def __init__(
9
+ self,
10
+ args,
11
+ device,
12
+ generator=None,
13
+ text_encoder=None,
14
+ vae=None
15
+ ):
16
+ super().__init__()
17
+ # Step 1: Initialize all models
18
+ self.generator = WanDiffusionWrapper(
19
+ **getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
20
+ self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
21
+ self.vae = WanVAEWrapper() if vae is None else vae
22
+
23
+ # Step 2: Initialize all bidirectional Wan hyperparameters
24
+ self.scheduler = self.generator.get_scheduler()
25
+ self.denoising_step_list = torch.tensor(
26
+ args.denoising_step_list, dtype=torch.long, device=device)
27
+ if self.denoising_step_list[-1] == 0:
28
+ self.denoising_step_list = self.denoising_step_list[:-1] # remove the zero timestep for inference
29
+ if args.warp_denoising_step:
30
+ timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
31
+ self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
32
+
33
+ def inference(self, noise: torch.Tensor, text_prompts: List[str]) -> torch.Tensor:
34
+ """
35
+ Perform inference on the given noise and text prompts.
36
+ Inputs:
37
+ noise (torch.Tensor): The input noise tensor of shape
38
+ (batch_size, num_frames, num_channels, height, width).
39
+ text_prompts (List[str]): The list of text prompts.
40
+ Outputs:
41
+ video (torch.Tensor): The generated video tensor of shape
42
+ (batch_size, num_frames, num_channels, height, width). It is normalized to be in the range [0, 1].
43
+ """
44
+ conditional_dict = self.text_encoder(
45
+ text_prompts=text_prompts
46
+ )
47
+
48
+ # initial point
49
+ noisy_image_or_video = noise
50
+
51
+ # use the last n-1 timesteps to simulate the generator's input
52
+ for index, current_timestep in enumerate(self.denoising_step_list[:-1]):
53
+ _, pred_image_or_video = self.generator(
54
+ noisy_image_or_video=noisy_image_or_video,
55
+ conditional_dict=conditional_dict,
56
+ timestep=torch.ones(
57
+ noise.shape[:2], dtype=torch.long, device=noise.device) * current_timestep
58
+ ) # [B, F, C, H, W]
59
+
60
+ next_timestep = self.denoising_step_list[index + 1] * torch.ones(
61
+ noise.shape[:2], dtype=torch.long, device=noise.device)
62
+
63
+ noisy_image_or_video = self.scheduler.add_noise(
64
+ pred_image_or_video.flatten(0, 1),
65
+ torch.randn_like(pred_image_or_video.flatten(0, 1)),
66
+ next_timestep.flatten(0, 1)
67
+ ).unflatten(0, noise.shape[:2])
68
+
69
+ video = self.vae.decode_to_pixel(pred_image_or_video)
70
+ video = (video * 0.5 + 0.5).clamp(0, 1)
71
+ return video
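
The loop above follows the few-step distillation pattern: predict a clean latent at the current timestep, then re-noise it to the next (smaller) timestep before the next pass. A minimal sketch (not part of the commit); predict_x0 and add_noise are placeholder callables standing in for the generator and scheduler, not the project's API.

import torch

def few_step_sample(noise, denoising_step_list, predict_x0, add_noise):
    x, x0 = noise, None
    for i, t in enumerate(denoising_step_list[:-1]):
        x0 = predict_x0(x, t)                              # one-step clean-latent prediction at timestep t
        next_t = denoising_step_list[i + 1]
        x = add_noise(x0, torch.randn_like(x0), next_t)    # re-noise to the next, smaller timestep
    return x0

# toy usage with dummy stand-ins
dummy_predict = lambda x, t: x * 0.0
dummy_add_noise = lambda x0, eps, t: x0 + (t / 1000.0) * eps
out = few_step_sample(torch.randn(1, 21, 16, 8, 8), [1000, 750, 500, 250], dummy_predict, dummy_add_noise)
print(out.shape)
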
pipeline/causal_diffusion_inference.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ from typing import List, Optional
3
+ import torch
4
+
5
+ from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
6
+ from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
7
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
8
+
9
+
10
+ class CausalDiffusionInferencePipeline(torch.nn.Module):
11
+ def __init__(
12
+ self,
13
+ args,
14
+ device,
15
+ generator=None,
16
+ text_encoder=None,
17
+ vae=None
18
+ ):
19
+ super().__init__()
20
+ # Step 1: Initialize all models
21
+ self.generator = WanDiffusionWrapper(
22
+ **getattr(args, "model_kwargs", {}), is_causal=True) if generator is None else generator
23
+ self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
24
+ self.vae = WanVAEWrapper() if vae is None else vae
25
+
26
+ # Step 2: Initialize scheduler
27
+ self.num_train_timesteps = args.num_train_timestep
28
+ self.sampling_steps = 50
29
+ self.sample_solver = 'unipc'
30
+ self.shift = args.timestep_shift
31
+
32
+ self.num_transformer_blocks = 30
33
+ self.frame_seq_length = 1560
34
+
35
+ self.kv_cache_pos = None
36
+ self.kv_cache_neg = None
37
+ self.crossattn_cache_pos = None
38
+ self.crossattn_cache_neg = None
39
+ self.args = args
40
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
41
+ self.independent_first_frame = args.independent_first_frame
42
+ self.local_attn_size = self.generator.model.local_attn_size
43
+
44
+ print(f"KV inference with {self.num_frame_per_block} frames per block")
45
+
46
+ if self.num_frame_per_block > 1:
47
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
48
+
49
+ def inference(
50
+ self,
51
+ noise: torch.Tensor,
52
+ text_prompts: List[str],
53
+ initial_latent: Optional[torch.Tensor] = None,
54
+ return_latents: bool = False,
55
+ start_frame_index: Optional[int] = 0
56
+ ) -> torch.Tensor:
57
+ """
58
+ Perform inference on the given noise and text prompts.
59
+ Inputs:
60
+ noise (torch.Tensor): The input noise tensor of shape
61
+ (batch_size, num_output_frames, num_channels, height, width).
62
+ text_prompts (List[str]): The list of text prompts.
63
+ initial_latent (torch.Tensor): The initial latent tensor of shape
64
+ (batch_size, num_input_frames, num_channels, height, width).
65
+ If num_input_frames is 1, perform image to video.
66
+ If num_input_frames is greater than 1, perform video extension.
67
+ return_latents (bool): Whether to return the latents.
68
+ start_frame_index (int): In long video generation, where does the current window start?
69
+ Outputs:
70
+ video (torch.Tensor): The generated video tensor of shape
71
+ (batch_size, num_frames, num_channels, height, width). It is normalized to be in the range [0, 1].
72
+ """
73
+ batch_size, num_frames, num_channels, height, width = noise.shape
74
+ if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
75
+ # If the first frame is independent and the first frame is provided, then the number of frames in the
76
+ # noise should still be a multiple of num_frame_per_block
77
+ assert num_frames % self.num_frame_per_block == 0
78
+ num_blocks = num_frames // self.num_frame_per_block
79
+ elif self.independent_first_frame and initial_latent is None:
80
+ # Using a [1, 4, 4, 4, 4, 4] model to generate a video without image conditioning
81
+ assert (num_frames - 1) % self.num_frame_per_block == 0
82
+ num_blocks = (num_frames - 1) // self.num_frame_per_block
83
+ num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
84
+ num_output_frames = num_frames + num_input_frames # add the initial latent frames
85
+ conditional_dict = self.text_encoder(
86
+ text_prompts=text_prompts
87
+ )
88
+ unconditional_dict = self.text_encoder(
89
+ text_prompts=[self.args.negative_prompt] * len(text_prompts)
90
+ )
91
+
92
+ output = torch.zeros(
93
+ [batch_size, num_output_frames, num_channels, height, width],
94
+ device=noise.device,
95
+ dtype=noise.dtype
96
+ )
97
+
98
+ # Step 1: Initialize KV cache to all zeros
99
+ if self.kv_cache_pos is None:
100
+ self._initialize_kv_cache(
101
+ batch_size=batch_size,
102
+ dtype=noise.dtype,
103
+ device=noise.device
104
+ )
105
+ self._initialize_crossattn_cache(
106
+ batch_size=batch_size,
107
+ dtype=noise.dtype,
108
+ device=noise.device
109
+ )
110
+ else:
111
+ # reset cross attn cache
112
+ for block_index in range(self.num_transformer_blocks):
113
+ self.crossattn_cache_pos[block_index]["is_init"] = False
114
+ self.crossattn_cache_neg[block_index]["is_init"] = False
115
+ # reset kv cache
116
+ for block_index in range(len(self.kv_cache_pos)):
117
+ self.kv_cache_pos[block_index]["global_end_index"] = torch.tensor(
118
+ [0], dtype=torch.long, device=noise.device)
119
+ self.kv_cache_pos[block_index]["local_end_index"] = torch.tensor(
120
+ [0], dtype=torch.long, device=noise.device)
121
+ self.kv_cache_neg[block_index]["global_end_index"] = torch.tensor(
122
+ [0], dtype=torch.long, device=noise.device)
123
+ self.kv_cache_neg[block_index]["local_end_index"] = torch.tensor(
124
+ [0], dtype=torch.long, device=noise.device)
125
+
126
+ # Step 2: Cache context feature
127
+ current_start_frame = start_frame_index
128
+ cache_start_frame = 0
129
+ if initial_latent is not None:
130
+ timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
131
+ if self.independent_first_frame:
132
+ # Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
133
+ assert (num_input_frames - 1) % self.num_frame_per_block == 0
134
+ num_input_blocks = (num_input_frames - 1) // self.num_frame_per_block
135
+ output[:, :1] = initial_latent[:, :1]
136
+ self.generator(
137
+ noisy_image_or_video=initial_latent[:, :1],
138
+ conditional_dict=conditional_dict,
139
+ timestep=timestep * 0,
140
+ kv_cache=self.kv_cache_pos,
141
+ crossattn_cache=self.crossattn_cache_pos,
142
+ current_start=current_start_frame * self.frame_seq_length,
143
+ cache_start=cache_start_frame * self.frame_seq_length
144
+ )
145
+ self.generator(
146
+ noisy_image_or_video=initial_latent[:, :1],
147
+ conditional_dict=unconditional_dict,
148
+ timestep=timestep * 0,
149
+ kv_cache=self.kv_cache_neg,
150
+ crossattn_cache=self.crossattn_cache_neg,
151
+ current_start=current_start_frame * self.frame_seq_length,
152
+ cache_start=cache_start_frame * self.frame_seq_length
153
+ )
154
+ current_start_frame += 1
155
+ cache_start_frame += 1
156
+ else:
157
+ # Assume num_input_frames is self.num_frame_per_block * num_input_blocks
158
+ assert num_input_frames % self.num_frame_per_block == 0
159
+ num_input_blocks = num_input_frames // self.num_frame_per_block
160
+
161
+ for block_index in range(num_input_blocks):
162
+ current_ref_latents = \
163
+ initial_latent[:, cache_start_frame:cache_start_frame + self.num_frame_per_block]
164
+ output[:, cache_start_frame:cache_start_frame + self.num_frame_per_block] = current_ref_latents
165
+ self.generator(
166
+ noisy_image_or_video=current_ref_latents,
167
+ conditional_dict=conditional_dict,
168
+ timestep=timestep * 0,
169
+ kv_cache=self.kv_cache_pos,
170
+ crossattn_cache=self.crossattn_cache_pos,
171
+ current_start=current_start_frame * self.frame_seq_length,
172
+ cache_start=cache_start_frame * self.frame_seq_length
173
+ )
174
+ self.generator(
175
+ noisy_image_or_video=current_ref_latents,
176
+ conditional_dict=unconditional_dict,
177
+ timestep=timestep * 0,
178
+ kv_cache=self.kv_cache_neg,
179
+ crossattn_cache=self.crossattn_cache_neg,
180
+ current_start=current_start_frame * self.frame_seq_length,
181
+ cache_start=cache_start_frame * self.frame_seq_length
182
+ )
183
+ current_start_frame += self.num_frame_per_block
184
+ cache_start_frame += self.num_frame_per_block
185
+
186
+ # Step 3: Temporal denoising loop
187
+ all_num_frames = [self.num_frame_per_block] * num_blocks
188
+ if self.independent_first_frame and initial_latent is None:
189
+ all_num_frames = [1] + all_num_frames
190
+ for current_num_frames in all_num_frames:
191
+ noisy_input = noise[
192
+ :, cache_start_frame - num_input_frames:cache_start_frame + current_num_frames - num_input_frames]
193
+ latents = noisy_input
194
+
195
+ # Step 3.1: Spatial denoising loop
196
+ sample_scheduler = self._initialize_sample_scheduler(noise)
197
+ for _, t in enumerate(tqdm(sample_scheduler.timesteps)):
198
+ latent_model_input = latents
199
+ timestep = t * torch.ones(
200
+ [batch_size, current_num_frames], device=noise.device, dtype=torch.float32
201
+ )
202
+
203
+ flow_pred_cond, _ = self.generator(
204
+ noisy_image_or_video=latent_model_input,
205
+ conditional_dict=conditional_dict,
206
+ timestep=timestep,
207
+ kv_cache=self.kv_cache_pos,
208
+ crossattn_cache=self.crossattn_cache_pos,
209
+ current_start=current_start_frame * self.frame_seq_length,
210
+ cache_start=cache_start_frame * self.frame_seq_length
211
+ )
212
+ flow_pred_uncond, _ = self.generator(
213
+ noisy_image_or_video=latent_model_input,
214
+ conditional_dict=unconditional_dict,
215
+ timestep=timestep,
216
+ kv_cache=self.kv_cache_neg,
217
+ crossattn_cache=self.crossattn_cache_neg,
218
+ current_start=current_start_frame * self.frame_seq_length,
219
+ cache_start=cache_start_frame * self.frame_seq_length
220
+ )
221
+
222
+ flow_pred = flow_pred_uncond + self.args.guidance_scale * (
223
+ flow_pred_cond - flow_pred_uncond)
224
+
225
+ temp_x0 = sample_scheduler.step(
226
+ flow_pred,
227
+ t,
228
+ latents,
229
+ return_dict=False)[0]
230
+ latents = temp_x0
231
+ print(f"kv_cache['local_end_index']: {self.kv_cache_pos[0]['local_end_index']}")
232
+ print(f"kv_cache['global_end_index']: {self.kv_cache_pos[0]['global_end_index']}")
233
+
234
+ # Step 3.2: record the model's output
235
+ output[:, cache_start_frame:cache_start_frame + current_num_frames] = latents
236
+
237
+ # Step 3.3: rerun with timestep zero to update KV cache using clean context
238
+ self.generator(
239
+ noisy_image_or_video=latents,
240
+ conditional_dict=conditional_dict,
241
+ timestep=timestep * 0,
242
+ kv_cache=self.kv_cache_pos,
243
+ crossattn_cache=self.crossattn_cache_pos,
244
+ current_start=current_start_frame * self.frame_seq_length,
245
+ cache_start=cache_start_frame * self.frame_seq_length
246
+ )
247
+ self.generator(
248
+ noisy_image_or_video=latents,
249
+ conditional_dict=unconditional_dict,
250
+ timestep=timestep * 0,
251
+ kv_cache=self.kv_cache_neg,
252
+ crossattn_cache=self.crossattn_cache_neg,
253
+ current_start=current_start_frame * self.frame_seq_length,
254
+ cache_start=cache_start_frame * self.frame_seq_length
255
+ )
256
+
257
+ # Step 3.4: update the start and end frame indices
258
+ current_start_frame += current_num_frames
259
+ cache_start_frame += current_num_frames
260
+
261
+ # Step 4: Decode the output
262
+ video = self.vae.decode_to_pixel(output)
263
+ video = (video * 0.5 + 0.5).clamp(0, 1)
264
+
265
+ if return_latents:
266
+ return video, output
267
+ else:
268
+ return video
269
+
270
+ def _initialize_kv_cache(self, batch_size, dtype, device):
271
+ """
272
+ Initialize a Per-GPU KV cache for the Wan model.
273
+ """
274
+ kv_cache_pos = []
275
+ kv_cache_neg = []
276
+ if self.local_attn_size != -1:
277
+ # Use the local attention size to compute the KV cache size
278
+ kv_cache_size = self.local_attn_size * self.frame_seq_length
279
+ else:
280
+ # Use the default KV cache size
281
+ kv_cache_size = 32760
282
+
283
+ for _ in range(self.num_transformer_blocks):
284
+ kv_cache_pos.append({
285
+ "k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
286
+ "v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
287
+ "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
288
+ "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
289
+ })
290
+ kv_cache_neg.append({
291
+ "k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
292
+ "v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
293
+ "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
294
+ "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
295
+ })
296
+
297
+ self.kv_cache_pos = kv_cache_pos # always store the clean cache
298
+ self.kv_cache_neg = kv_cache_neg # always store the clean cache
299
+
300
+ def _initialize_crossattn_cache(self, batch_size, dtype, device):
301
+ """
302
+ Initialize a Per-GPU cross-attention cache for the Wan model.
303
+ """
304
+ crossattn_cache_pos = []
305
+ crossattn_cache_neg = []
306
+ for _ in range(self.num_transformer_blocks):
307
+ crossattn_cache_pos.append({
308
+ "k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
309
+ "v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
310
+ "is_init": False
311
+ })
312
+ crossattn_cache_neg.append({
313
+ "k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
314
+ "v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
315
+ "is_init": False
316
+ })
317
+
318
+ self.crossattn_cache_pos = crossattn_cache_pos # always store the clean cache
319
+ self.crossattn_cache_neg = crossattn_cache_neg # always store the clean cache
320
+
321
+ def _initialize_sample_scheduler(self, noise):
322
+ if self.sample_solver == 'unipc':
323
+ sample_scheduler = FlowUniPCMultistepScheduler(
324
+ num_train_timesteps=self.num_train_timesteps,
325
+ shift=1,
326
+ use_dynamic_shifting=False)
327
+ sample_scheduler.set_timesteps(
328
+ self.sampling_steps, device=noise.device, shift=self.shift)
329
+ self.timesteps = sample_scheduler.timesteps
330
+ elif self.sample_solver == 'dpm++':
331
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
332
+ num_train_timesteps=self.num_train_timesteps,
333
+ shift=1,
334
+ use_dynamic_shifting=False)
335
+ sampling_sigmas = get_sampling_sigmas(self.sampling_steps, self.shift)
336
+ self.timesteps, _ = retrieve_timesteps(
337
+ sample_scheduler,
338
+ device=noise.device,
339
+ sigmas=sampling_sigmas)
340
+ else:
341
+ raise NotImplementedError("Unsupported solver.")
342
+ return sample_scheduler
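
The per-layer KV cache above is a fixed-size buffer indexed in units of frames, with 1560 tokens per latent frame, 12 heads of dimension 128, and 30 transformer layers; the default size of 32760 tokens corresponds to 21 cached frames. A minimal sketch (not part of the commit) of the sizing and the token-offset arithmetic; make_kv_cache and block_token_start are hypothetical helper names, and the toy sizes below are kept deliberately small.

import torch

FRAME_SEQ_LENGTH = 1560                  # tokens per latent frame (as hard-coded above)
NUM_HEADS, HEAD_DIM = 12, 128

def make_kv_cache(batch_size, num_layers, num_cached_frames, device="cpu", dtype=torch.float32):
    cache_size = num_cached_frames * FRAME_SEQ_LENGTH
    return [{
        "k": torch.zeros(batch_size, cache_size, NUM_HEADS, HEAD_DIM, dtype=dtype, device=device),
        "v": torch.zeros(batch_size, cache_size, NUM_HEADS, HEAD_DIM, dtype=dtype, device=device),
        "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
        "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
    } for _ in range(num_layers)]

def block_token_start(current_start_frame):
    # token offset passed as current_start / cache_start in the generator calls above
    return current_start_frame * FRAME_SEQ_LENGTH

cache = make_kv_cache(batch_size=1, num_layers=2, num_cached_frames=2)
print(len(cache), cache[0]["k"].shape, block_token_start(3))   # 2 torch.Size([1, 3120, 12, 128]) 4680
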
pipeline/rolling_forcing_inference.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ import torch
3
+
4
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
5
+
6
+
7
+ class CausalInferencePipeline(torch.nn.Module):
8
+ def __init__(
9
+ self,
10
+ args,
11
+ device,
12
+ generator=None,
13
+ text_encoder=None,
14
+ vae=None
15
+ ):
16
+ super().__init__()
17
+ # Step 1: Initialize all models
18
+ self.generator = WanDiffusionWrapper(
19
+ **getattr(args, "model_kwargs", {}), is_causal=True) if generator is None else generator
20
+ self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
21
+ self.vae = WanVAEWrapper() if vae is None else vae
22
+
23
+ # Step 2: Initialize all causal hyperparameters
24
+ self.scheduler = self.generator.get_scheduler()
25
+ self.denoising_step_list = torch.tensor(
26
+ args.denoising_step_list, dtype=torch.long)
27
+ if args.warp_denoising_step:
28
+ timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
29
+ self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
30
+
31
+ self.num_transformer_blocks = 30
32
+ self.frame_seq_length = 1560
33
+
34
+ self.kv_cache_clean = None
35
+ self.args = args
36
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
37
+ self.independent_first_frame = args.independent_first_frame
38
+ self.local_attn_size = self.generator.model.local_attn_size
39
+
40
+ print(f"KV inference with {self.num_frame_per_block} frames per block")
41
+
42
+ if self.num_frame_per_block > 1:
43
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
44
+
45
+ def inference_rolling_forcing(
46
+ self,
47
+ noise: torch.Tensor,
48
+ text_prompts: List[str],
49
+ initial_latent: Optional[torch.Tensor] = None,
50
+ return_latents: bool = False,
51
+ profile: bool = False
52
+ ) -> torch.Tensor:
53
+ """
54
+ Perform inference on the given noise and text prompts.
55
+ Inputs:
56
+ noise (torch.Tensor): The input noise tensor of shape
57
+ (batch_size, num_output_frames, num_channels, height, width).
58
+ text_prompts (List[str]): The list of text prompts.
59
+ initial_latent (torch.Tensor): The initial latent tensor of shape
60
+ (batch_size, num_input_frames, num_channels, height, width).
61
+ If num_input_frames is 1, perform image to video.
62
+ If num_input_frames is greater than 1, perform video extension.
63
+ return_latents (bool): Whether to return the latents.
64
+ Outputs:
65
+ video (torch.Tensor): The generated video tensor of shape
66
+ (batch_size, num_output_frames, num_channels, height, width).
67
+ It is normalized to be in the range [0, 1].
68
+ """
69
+ batch_size, num_frames, num_channels, height, width = noise.shape
70
+ if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
71
+ # If the first frame is independent and the first frame is provided, then the number of frames in the
72
+ # noise should still be a multiple of num_frame_per_block
73
+ assert num_frames % self.num_frame_per_block == 0
74
+ num_blocks = num_frames // self.num_frame_per_block
75
+ else:
76
+ # Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
77
+ assert (num_frames - 1) % self.num_frame_per_block == 0
78
+ num_blocks = (num_frames - 1) // self.num_frame_per_block
79
+ num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
80
+ num_output_frames = num_frames + num_input_frames # add the initial latent frames
81
+ conditional_dict = self.text_encoder(
82
+ text_prompts=text_prompts
83
+ )
84
+
85
+ output = torch.zeros(
86
+ [batch_size, num_output_frames, num_channels, height, width],
87
+ device=noise.device,
88
+ dtype=noise.dtype
89
+ )
90
+
91
+ # Set up profiling if requested
92
+ if profile:
93
+ init_start = torch.cuda.Event(enable_timing=True)
94
+ init_end = torch.cuda.Event(enable_timing=True)
95
+ diffusion_start = torch.cuda.Event(enable_timing=True)
96
+ diffusion_end = torch.cuda.Event(enable_timing=True)
97
+ vae_start = torch.cuda.Event(enable_timing=True)
98
+ vae_end = torch.cuda.Event(enable_timing=True)
99
+ block_times = []
100
+ block_start = torch.cuda.Event(enable_timing=True)
101
+ block_end = torch.cuda.Event(enable_timing=True)
102
+ init_start.record()
103
+
104
+ # Step 1: Initialize KV cache to all zeros
105
+ if self.kv_cache_clean is None:
106
+ self._initialize_kv_cache(
107
+ batch_size=batch_size,
108
+ dtype=noise.dtype,
109
+ device=noise.device
110
+ )
111
+ self._initialize_crossattn_cache(
112
+ batch_size=batch_size,
113
+ dtype=noise.dtype,
114
+ device=noise.device
115
+ )
116
+ else:
117
+ # reset cross attn cache
118
+ for block_index in range(self.num_transformer_blocks):
119
+ self.crossattn_cache[block_index]["is_init"] = False
120
+ # reset kv cache
121
+ for block_index in range(len(self.kv_cache_clean)):
122
+ self.kv_cache_clean[block_index]["global_end_index"] = torch.tensor(
123
+ [0], dtype=torch.long, device=noise.device)
124
+ self.kv_cache_clean[block_index]["local_end_index"] = torch.tensor(
125
+ [0], dtype=torch.long, device=noise.device)
126
+
127
+ # Step 2: Cache context feature
+ current_start_frame = 0  # frame offset into the KV cache; must exist before the context-caching branch below
128
+ if initial_latent is not None:
129
+ timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
130
+ if self.independent_first_frame:
131
+ # Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
132
+ assert (num_input_frames - 1) % self.num_frame_per_block == 0
133
+ num_input_blocks = (num_input_frames - 1) // self.num_frame_per_block
134
+ output[:, :1] = initial_latent[:, :1]
135
+ self.generator(
136
+ noisy_image_or_video=initial_latent[:, :1],
137
+ conditional_dict=conditional_dict,
138
+ timestep=timestep * 0,
139
+ kv_cache=self.kv_cache_clean,
140
+ crossattn_cache=self.crossattn_cache,
141
+ current_start=current_start_frame * self.frame_seq_length,
142
+ )
143
+ current_start_frame += 1
144
+ else:
145
+ # Assume num_input_frames is self.num_frame_per_block * num_input_blocks
146
+ assert num_input_frames % self.num_frame_per_block == 0
147
+ num_input_blocks = num_input_frames // self.num_frame_per_block
148
+
149
+ for _ in range(num_input_blocks):
150
+ current_ref_latents = \
151
+ initial_latent[:, current_start_frame:current_start_frame + self.num_frame_per_block]
152
+ output[:, current_start_frame:current_start_frame + self.num_frame_per_block] = current_ref_latents
153
+ self.generator(
154
+ noisy_image_or_video=current_ref_latents,
155
+ conditional_dict=conditional_dict,
156
+ timestep=timestep * 0,
157
+ kv_cache=self.kv_cache_clean,
158
+ crossattn_cache=self.crossattn_cache,
159
+ current_start=current_start_frame * self.frame_seq_length,
160
+ )
161
+ current_start_frame += self.num_frame_per_block
162
+
163
+ if profile:
164
+ init_end.record()
165
+ torch.cuda.synchronize()
166
+ diffusion_start.record()
167
+
168
+ # implementing rolling forcing
169
+ # construct the rolling forcing windows
170
+ num_denoising_steps = len(self.denoising_step_list)
171
+ rolling_window_length_blocks = num_denoising_steps
172
+ window_start_blocks = []
173
+ window_end_blocks = []
174
+ window_num = num_blocks + rolling_window_length_blocks - 1
175
+
176
+ for window_index in range(window_num):
177
+ start_block = max(0, window_index - rolling_window_length_blocks + 1)
178
+ end_block = min(num_blocks - 1, window_index)
179
+ window_start_blocks.append(start_block)
180
+ window_end_blocks.append(end_block)
181
+
182
+ # init noisy cache
183
+ noisy_cache = torch.zeros(
184
+ [batch_size, num_output_frames, num_channels, height, width],
185
+ device=noise.device,
186
+ dtype=noise.dtype
187
+ )
188
+
189
+ # initialize the denoising timesteps, shared across windows
190
+ shared_timestep = torch.ones(
191
+ [batch_size, rolling_window_length_blocks * self.num_frame_per_block],
192
+ device=noise.device,
193
+ dtype=torch.float32)
194
+
195
+ for index, current_timestep in enumerate(reversed(self.denoising_step_list)): # from clean to noisy
196
+ shared_timestep[:, index * self.num_frame_per_block:(index + 1) * self.num_frame_per_block] *= current_timestep
197
+
198
+
199
+ # Denoising loop with rolling forcing
200
+ for window_index in range(window_num):
201
+
202
+ if profile:
203
+ block_start.record()
204
+
205
+ print('window_index:', window_index)
206
+ start_block = window_start_blocks[window_index]
207
+ end_block = window_end_blocks[window_index] # include
208
+ print(f"start_block: {start_block}, end_block: {end_block}")
209
+
210
+ current_start_frame = start_block * self.num_frame_per_block
211
+ current_end_frame = (end_block + 1) * self.num_frame_per_block # not include
212
+ current_num_frames = current_end_frame - current_start_frame
213
+
214
+ # noisy_input: previously denoised frames re-noised to their current level, plus fresh noise; only the last block is pure noise
215
+ if current_num_frames == rolling_window_length_blocks * self.num_frame_per_block or current_start_frame == 0:
216
+ noisy_input = torch.cat([
217
+ noisy_cache[:, current_start_frame : current_end_frame - self.num_frame_per_block],
218
+ noise[:, current_end_frame - self.num_frame_per_block : current_end_frame ]
219
+ ], dim=1)
220
+ else: # at the end of the video
221
+ noisy_input = noisy_cache[:, current_start_frame:current_end_frame]
222
+
223
+ # select the denoising timesteps for this window
224
+ if current_num_frames == rolling_window_length_blocks * self.num_frame_per_block:
225
+ current_timestep = shared_timestep
226
+ elif current_start_frame == 0:
227
+ current_timestep = shared_timestep[:,-current_num_frames:]
228
+ elif current_end_frame == num_frames:
229
+ current_timestep = shared_timestep[:,:current_num_frames]
230
+ else:
231
+ raise ValueError("Each window must span rolling_window_length_blocks * num_frame_per_block frames unless it is the first or last window of the video.")
232
+
233
+
234
+ # calling DiT
235
+ _, denoised_pred = self.generator(
236
+ noisy_image_or_video=noisy_input,
237
+ conditional_dict=conditional_dict,
238
+ timestep=current_timestep,
239
+ kv_cache=self.kv_cache_clean,
240
+ crossattn_cache=self.crossattn_cache,
241
+ current_start=current_start_frame * self.frame_seq_length
242
+ )
243
+
244
+ output[:, current_start_frame:current_end_frame] = denoised_pred
245
+
246
+
247
+ # update noisy_cache, which is detached from the computation graph
248
+ with torch.no_grad():
249
+ for block_idx in range(start_block, end_block + 1):
250
+
251
+ block_time_step = current_timestep[:,
252
+ (block_idx - start_block)*self.num_frame_per_block :
253
+ (block_idx - start_block+1)*self.num_frame_per_block].mean().item()
254
+ matches = torch.abs(self.denoising_step_list - block_time_step) < 1e-4
255
+ block_timestep_index = torch.nonzero(matches, as_tuple=True)[0]
256
+
257
+ if block_timestep_index == len(self.denoising_step_list) - 1:
258
+ continue
259
+
260
+ next_timestep = self.denoising_step_list[block_timestep_index + 1].to(noise.device)
261
+
262
+ noisy_cache[:, block_idx * self.num_frame_per_block:
263
+ (block_idx+1) * self.num_frame_per_block] = \
264
+ self.scheduler.add_noise(
265
+ denoised_pred.flatten(0, 1),
266
+ torch.randn_like(denoised_pred.flatten(0, 1)),
267
+ next_timestep * torch.ones(
268
+ [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
269
+ ).unflatten(0, denoised_pred.shape[:2])[:, (block_idx - start_block)*self.num_frame_per_block:
270
+ (block_idx - start_block+1)*self.num_frame_per_block]
271
+
272
+
273
+ # rerun with timestep zero to update the clean cache, which is also detached from the computation graph
274
+ with torch.no_grad():
275
+ context_timestep = torch.ones_like(current_timestep) * self.args.context_noise
276
+ # # add context noise
277
+ # denoised_pred = self.scheduler.add_noise(
278
+ # denoised_pred.flatten(0, 1),
279
+ # torch.randn_like(denoised_pred.flatten(0, 1)),
280
+ # context_timestep * torch.ones(
281
+ # [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
282
+ # ).unflatten(0, denoised_pred.shape[:2])
283
+
284
+ # only cache the first block
285
+ denoised_pred = denoised_pred[:,:self.num_frame_per_block]
286
+ context_timestep = context_timestep[:,:self.num_frame_per_block]
287
+ self.generator(
288
+ noisy_image_or_video=denoised_pred,
289
+ conditional_dict=conditional_dict,
290
+ timestep=context_timestep,
291
+ kv_cache=self.kv_cache_clean,
292
+ crossattn_cache=self.crossattn_cache,
293
+ current_start=current_start_frame * self.frame_seq_length,
294
+ updating_cache=True,
295
+ )
296
+
297
+ if profile:
298
+ block_end.record()
299
+ torch.cuda.synchronize()
300
+ block_time = block_start.elapsed_time(block_end)
301
+ block_times.append(block_time)
302
+
303
+
304
+ if profile:
305
+ # End diffusion timing and synchronize CUDA
306
+ diffusion_end.record()
307
+ torch.cuda.synchronize()
308
+ diffusion_time = diffusion_start.elapsed_time(diffusion_end)
309
+ init_time = init_start.elapsed_time(init_end)
310
+ vae_start.record()
311
+
312
+ # Step 4: Decode the output
313
+ video = self.vae.decode_to_pixel(output, use_cache=False)
314
+ video = (video * 0.5 + 0.5).clamp(0, 1)
315
+
316
+ if profile:
317
+ # End VAE timing and synchronize CUDA
318
+ vae_end.record()
319
+ torch.cuda.synchronize()
320
+ vae_time = vae_start.elapsed_time(vae_end)
321
+ total_time = init_time + diffusion_time + vae_time
322
+
323
+ print("Profiling results:")
324
+ print(f" - Initialization/caching time: {init_time:.2f} ms ({100 * init_time / total_time:.2f}%)")
325
+ print(f" - Diffusion generation time: {diffusion_time:.2f} ms ({100 * diffusion_time / total_time:.2f}%)")
326
+ for i, block_time in enumerate(block_times):
327
+ print(f" - Block {i} generation time: {block_time:.2f} ms ({100 * block_time / diffusion_time:.2f}% of diffusion)")
328
+ print(f" - VAE decoding time: {vae_time:.2f} ms ({100 * vae_time / total_time:.2f}%)")
329
+ print(f" - Total time: {total_time:.2f} ms")
330
+
331
+ if return_latents:
332
+ return video, output
333
+ else:
334
+ return video
335
+
336
+
337
+
338
+ def _initialize_kv_cache(self, batch_size, dtype, device):
339
+ """
340
+ Initialize a Per-GPU KV cache for the Wan model.
341
+ """
342
+ kv_cache_clean = []
343
+ # if self.local_attn_size != -1:
344
+ # # Use the local attention size to compute the KV cache size
345
+ # kv_cache_size = self.local_attn_size * self.frame_seq_length
346
+ # else:
347
+ # # Use the default KV cache size
348
+ kv_cache_size = 1560 * 24
349
+
350
+ for _ in range(self.num_transformer_blocks):
351
+ kv_cache_clean.append({
352
+ "k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
353
+ "v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
354
+ "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
355
+ "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
356
+ })
357
+
358
+ self.kv_cache_clean = kv_cache_clean # always store the clean cache
359
+
360
+ def _initialize_crossattn_cache(self, batch_size, dtype, device):
361
+ """
362
+ Initialize a Per-GPU cross-attention cache for the Wan model.
363
+ """
364
+ crossattn_cache = []
365
+
366
+ for _ in range(self.num_transformer_blocks):
367
+ crossattn_cache.append({
368
+ "k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
369
+ "v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
370
+ "is_init": False
371
+ })
372
+ self.crossattn_cache = crossattn_cache
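
The window bookkeeping above can be summarized in one pure function: with W = len(denoising_step_list) noise levels, window i covers blocks [max(0, i - W + 1), min(num_blocks - 1, i)] inclusive, so interior windows hold exactly W blocks, each sitting at a different noise level, while the first and last W - 1 windows are truncated. A minimal sketch (not part of the commit) mirroring that loop; build_rolling_windows is a hypothetical name.

def build_rolling_windows(num_blocks, num_denoising_steps):
    window_len = num_denoising_steps
    windows = []
    for window_index in range(num_blocks + window_len - 1):
        start_block = max(0, window_index - window_len + 1)
        end_block = min(num_blocks - 1, window_index)   # inclusive
        windows.append((start_block, end_block))
    return windows

print(build_rolling_windows(num_blocks=7, num_denoising_steps=4))
# [(0, 0), (0, 1), (0, 2), (0, 3), (1, 4), (2, 5), (3, 6), (4, 6), (5, 6), (6, 6)]
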
pipeline/rolling_forcing_training.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.wan_wrapper import WanDiffusionWrapper
2
+ from utils.scheduler import SchedulerInterface
3
+ from typing import List, Optional
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+
8
+ class RollingForcingTrainingPipeline:
9
+ def __init__(self,
10
+ denoising_step_list: List[int],
11
+ scheduler: SchedulerInterface,
12
+ generator: WanDiffusionWrapper,
13
+ num_frame_per_block=3,
14
+ independent_first_frame: bool = False,
15
+ same_step_across_blocks: bool = False,
16
+ last_step_only: bool = False,
17
+ num_max_frames: int = 21,
18
+ context_noise: int = 0,
19
+ **kwargs):
20
+ super().__init__()
21
+ self.scheduler = scheduler
22
+ self.generator = generator
23
+ self.denoising_step_list = denoising_step_list
24
+ if self.denoising_step_list[-1] == 0:
25
+ self.denoising_step_list = self.denoising_step_list[:-1] # remove the zero timestep for inference
26
+
27
+ # Wan specific hyperparameters
28
+ self.num_transformer_blocks = 30
29
+ self.frame_seq_length = 1560
30
+ self.num_frame_per_block = num_frame_per_block
31
+ self.context_noise = context_noise
32
+ self.i2v = False
33
+
34
+ self.kv_cache_clean = None
35
+ self.kv_cache2 = None
36
+ self.independent_first_frame = independent_first_frame
37
+ self.same_step_across_blocks = same_step_across_blocks
38
+ self.last_step_only = last_step_only
39
+ self.kv_cache_size = num_max_frames * self.frame_seq_length
40
+
41
+ def generate_and_sync_list(self, num_blocks, num_denoising_steps, device):
42
+ rank = dist.get_rank() if dist.is_initialized() else 0
43
+
44
+ if rank == 0:
45
+ # Generate random indices
46
+ indices = torch.randint(
47
+ low=0,
48
+ high=num_denoising_steps,
49
+ size=(num_blocks,),
50
+ device=device
51
+ )
52
+ if self.last_step_only:
53
+ indices = torch.ones_like(indices) * (num_denoising_steps - 1)
54
+ else:
55
+ indices = torch.empty(num_blocks, dtype=torch.long, device=device)
56
+
57
+ dist.broadcast(indices, src=0) # Broadcast the random indices to all ranks
58
+ return indices.tolist()
59
+
60
+ def generate_list(self, num_blocks, num_denoising_steps, device):
61
+
62
+ # Generate random indices
63
+ indices = torch.randint(
64
+ low=0,
65
+ high=num_denoising_steps,
66
+ size=(num_blocks,),
67
+ device=device
68
+ )
69
+ if self.last_step_only:
70
+ indices = torch.ones_like(indices) * (num_denoising_steps - 1)
71
+
72
+ return indices.tolist()
73
+
74
+
75
+ def inference_with_rolling_forcing(
76
+ self,
77
+ noise: torch.Tensor,
78
+ initial_latent: Optional[torch.Tensor] = None,
79
+ return_sim_step: bool = False,
80
+ **conditional_dict
81
+ ) -> torch.Tensor:
82
+ batch_size, num_frames, num_channels, height, width = noise.shape
83
+ if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
84
+ # If the first frame is independent and the first frame is provided, then the number of frames in the
85
+ # noise should still be a multiple of num_frame_per_block
86
+ assert num_frames % self.num_frame_per_block == 0
87
+ num_blocks = num_frames // self.num_frame_per_block
88
+ else:
89
+ # Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
90
+ assert (num_frames - 1) % self.num_frame_per_block == 0
91
+ num_blocks = (num_frames - 1) // self.num_frame_per_block
92
+ num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
93
+ num_output_frames = num_frames + num_input_frames # add the initial latent frames
94
+ output = torch.zeros(
95
+ [batch_size, num_output_frames, num_channels, height, width],
96
+ device=noise.device,
97
+ dtype=noise.dtype
98
+ )
99
+
100
+ # Step 1: Initialize KV cache to all zeros
101
+ self._initialize_kv_cache(
102
+ batch_size=batch_size, dtype=noise.dtype, device=noise.device
103
+ )
104
+ self._initialize_crossattn_cache(
105
+ batch_size=batch_size, dtype=noise.dtype, device=noise.device
106
+ )
107
+
108
+ # implementing rolling forcing
109
+ # construct the rolling forcing windows
110
+ num_denoising_steps = len(self.denoising_step_list)
111
+ rolling_window_length_blocks = num_denoising_steps
112
+ window_start_blocks = []
113
+ window_end_blocks = []
114
+ window_num = num_blocks + rolling_window_length_blocks - 1
115
+
116
+ for window_index in range(window_num):
117
+ start_block = max(0, window_index - rolling_window_length_blocks + 1)
118
+ end_block = min(num_blocks - 1, window_index)
119
+ window_start_blocks.append(start_block)
120
+ window_end_blocks.append(end_block)
121
+
122
+ # exit_flag indicates the window at which the model will backpropagate gradients.
123
+ exit_flag = torch.randint(high=rolling_window_length_blocks, device=noise.device, size=())
124
+ start_gradient_frame_index = num_output_frames - 21
125
+
126
+ # init noisy cache
127
+ noisy_cache = torch.zeros(
128
+ [batch_size, num_output_frames, num_channels, height, width],
129
+ device=noise.device,
130
+ dtype=noise.dtype
131
+ )
132
+
133
+ # initialize the denoising timesteps, shared across windows
134
+ shared_timestep = torch.ones(
135
+ [batch_size, rolling_window_length_blocks * self.num_frame_per_block],
136
+ device=noise.device,
137
+ dtype=torch.float32)
138
+
139
+ for index, current_timestep in enumerate(reversed(self.denoising_step_list)): # from clean to noisy
140
+ shared_timestep[:, index * self.num_frame_per_block:(index + 1) * self.num_frame_per_block] *= current_timestep
141
+
142
+
143
+ # Denoising loop with rolling forcing
144
+ for window_index in range(window_num):
145
+ start_block = window_start_blocks[window_index]
146
+ end_block = window_end_blocks[window_index] # include
147
+
148
+ current_start_frame = start_block * self.num_frame_per_block
149
+ current_end_frame = (end_block + 1) * self.num_frame_per_block # not include
150
+ current_num_frames = current_end_frame - current_start_frame
151
+
152
+ # noisy_input: previously denoised frames re-noised to their current level, plus fresh noise; only the last block is pure noise
153
+ if current_num_frames == rolling_window_length_blocks * self.num_frame_per_block or current_start_frame == 0:
154
+ noisy_input = torch.cat([
155
+ noisy_cache[:, current_start_frame : current_end_frame - self.num_frame_per_block],
156
+ noise[:, current_end_frame - self.num_frame_per_block : current_end_frame ]
157
+ ], dim=1)
158
+ else: # at the end of the video
159
+ noisy_input = noisy_cache[:, current_start_frame:current_end_frame].clone()
160
+
161
+ # select the denoising timesteps for this window
162
+ if current_num_frames == rolling_window_length_blocks * self.num_frame_per_block:
163
+ current_timestep = shared_timestep
164
+ elif current_start_frame == 0:
165
+ current_timestep = shared_timestep[:,-current_num_frames:]
166
+ elif current_end_frame == num_frames:
167
+ current_timestep = shared_timestep[:,:current_num_frames]
168
+ else:
169
+ raise ValueError("Each window must span rolling_window_length_blocks * num_frame_per_block frames unless it is the first or last window of the video.")
170
+
171
+ require_grad = window_index % rolling_window_length_blocks == exit_flag
172
+ if current_end_frame <= start_gradient_frame_index:
173
+ require_grad = False
174
+
175
+ # calling DiT
176
+ if not require_grad:
177
+ with torch.no_grad():
178
+ _, denoised_pred = self.generator(
179
+ noisy_image_or_video=noisy_input,
180
+ conditional_dict=conditional_dict,
181
+ timestep=current_timestep,
182
+ kv_cache=self.kv_cache_clean,
183
+ crossattn_cache=self.crossattn_cache,
184
+ current_start=current_start_frame * self.frame_seq_length
185
+ )
186
+ else:
187
+ _, denoised_pred = self.generator(
188
+ noisy_image_or_video=noisy_input,
189
+ conditional_dict=conditional_dict,
190
+ timestep=current_timestep,
191
+ kv_cache=self.kv_cache_clean,
192
+ crossattn_cache=self.crossattn_cache,
193
+ current_start=current_start_frame * self.frame_seq_length
194
+ )
195
+ output[:, current_start_frame:current_end_frame] = denoised_pred
196
+
197
+
198
+ # update noisy_cache, which is detached from the computation graph
199
+ with torch.no_grad():
200
+ for block_idx in range(start_block, end_block + 1):
201
+
202
+ block_time_step = current_timestep[:,
203
+ (block_idx - start_block)*self.num_frame_per_block :
204
+ (block_idx - start_block+1)*self.num_frame_per_block].mean().item()
205
+ matches = torch.abs(self.denoising_step_list - block_time_step) < 1e-4
206
+ block_timestep_index = torch.nonzero(matches, as_tuple=True)[0]
207
+
208
+ if block_timestep_index == len(self.denoising_step_list) - 1:
209
+ continue
210
+
211
+ next_timestep = self.denoising_step_list[block_timestep_index + 1].to(noise.device)
212
+
213
+ noisy_cache[:, block_idx * self.num_frame_per_block:
214
+ (block_idx+1) * self.num_frame_per_block] = \
215
+ self.scheduler.add_noise(
216
+ denoised_pred.flatten(0, 1),
217
+ torch.randn_like(denoised_pred.flatten(0, 1)),
218
+ next_timestep * torch.ones(
219
+ [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
220
+ ).unflatten(0, denoised_pred.shape[:2])[:, (block_idx - start_block)*self.num_frame_per_block:
221
+ (block_idx - start_block+1)*self.num_frame_per_block]
222
+
223
+
224
+ # rerun with timestep zero to update the clean cache, which is also detached from the computation graph
225
+ with torch.no_grad():
226
+ context_timestep = torch.ones_like(current_timestep) * self.context_noise
227
+ # # add context noise
228
+ # denoised_pred = self.scheduler.add_noise(
229
+ # denoised_pred.flatten(0, 1),
230
+ # torch.randn_like(denoised_pred.flatten(0, 1)),
231
+ # context_timestep * torch.ones(
232
+ # [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
233
+ # ).unflatten(0, denoised_pred.shape[:2])
234
+
235
+ # only cache the first block
236
+ denoised_pred = denoised_pred[:,:self.num_frame_per_block]
237
+ context_timestep = context_timestep[:,:self.num_frame_per_block]
238
+ self.generator(
239
+ noisy_image_or_video=denoised_pred,
240
+ conditional_dict=conditional_dict,
241
+ timestep=context_timestep,
242
+ kv_cache=self.kv_cache_clean,
243
+ crossattn_cache=self.crossattn_cache,
244
+ current_start=current_start_frame * self.frame_seq_length,
245
+ updating_cache=True,
246
+ )
247
+
248
+ # Step 3.5: Return the denoised timestep
249
+ # can ignore since not used
250
+ denoised_timestep_from, denoised_timestep_to = None, None
251
+
252
+ return output, denoised_timestep_from, denoised_timestep_to
253
+
254
+
255
+
256
+ def inference_with_self_forcing(
257
+ self,
258
+ noise: torch.Tensor,
259
+ initial_latent: Optional[torch.Tensor] = None,
260
+ return_sim_step: bool = False,
261
+ **conditional_dict
262
+ ) -> torch.Tensor:
263
+ batch_size, num_frames, num_channels, height, width = noise.shape
264
+ if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
265
+ # If the first frame is independent and the first frame is provided, then the number of frames in the
266
+ # noise should still be a multiple of num_frame_per_block
267
+ assert num_frames % self.num_frame_per_block == 0
268
+ num_blocks = num_frames // self.num_frame_per_block
269
+ else:
270
+ # Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
271
+ assert (num_frames - 1) % self.num_frame_per_block == 0
272
+ num_blocks = (num_frames - 1) // self.num_frame_per_block
273
+ num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
274
+ num_output_frames = num_frames + num_input_frames # add the initial latent frames
275
+ output = torch.zeros(
276
+ [batch_size, num_output_frames, num_channels, height, width],
277
+ device=noise.device,
278
+ dtype=noise.dtype
279
+ )
280
+
281
+ # Step 1: Initialize KV cache to all zeros
282
+ self._initialize_kv_cache(
283
+ batch_size=batch_size, dtype=noise.dtype, device=noise.device
284
+ )
285
+ self._initialize_crossattn_cache(
286
+ batch_size=batch_size, dtype=noise.dtype, device=noise.device
287
+ )
288
+ # if self.kv_cache_clean is None:
289
+ # self._initialize_kv_cache(
290
+ # batch_size=batch_size,
291
+ # dtype=noise.dtype,
292
+ # device=noise.device,
293
+ # )
294
+ # self._initialize_crossattn_cache(
295
+ # batch_size=batch_size,
296
+ # dtype=noise.dtype,
297
+ # device=noise.device
298
+ # )
299
+ # else:
300
+ # # reset cross attn cache
301
+ # for block_index in range(self.num_transformer_blocks):
302
+ # self.crossattn_cache[block_index]["is_init"] = False
303
+ # # reset kv cache
304
+ # for block_index in range(len(self.kv_cache_clean)):
305
+ # self.kv_cache_clean[block_index]["global_end_index"] = torch.tensor(
306
+ # [0], dtype=torch.long, device=noise.device)
307
+ # self.kv_cache_clean[block_index]["local_end_index"] = torch.tensor(
308
+ # [0], dtype=torch.long, device=noise.device)
309
+
310
+ # Step 2: Cache context feature
311
+ current_start_frame = 0
312
+ if initial_latent is not None:
313
+ timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
314
+ # Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
315
+ output[:, :1] = initial_latent
316
+ with torch.no_grad():
317
+ self.generator(
318
+ noisy_image_or_video=initial_latent,
319
+ conditional_dict=conditional_dict,
320
+ timestep=timestep * 0,
321
+ kv_cache=self.kv_cache_clean,
322
+ crossattn_cache=self.crossattn_cache,
323
+ current_start=current_start_frame * self.frame_seq_length
324
+ )
325
+ current_start_frame += 1
326
+
327
+ # Step 3: Temporal denoising loop
328
+ all_num_frames = [self.num_frame_per_block] * num_blocks
329
+ if self.independent_first_frame and initial_latent is None:
330
+ all_num_frames = [1] + all_num_frames
331
+ num_denoising_steps = len(self.denoising_step_list)
332
+ exit_flags = self.generate_and_sync_list(len(all_num_frames), num_denoising_steps, device=noise.device)
333
+ start_gradient_frame_index = num_output_frames - 21
334
+
335
+ # for block_index in range(num_blocks):
336
+ for block_index, current_num_frames in enumerate(all_num_frames):
337
+ noisy_input = noise[
338
+ :, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]
339
+
340
+ # Step 3.1: Spatial denoising loop
341
+ for index, current_timestep in enumerate(self.denoising_step_list):
342
+ if self.same_step_across_blocks:
343
+ exit_flag = (index == exit_flags[0])
344
+ else:
345
+ exit_flag = (index == exit_flags[block_index]) # Only backprop at the randomly selected timestep (consistent across all ranks)
346
+ timestep = torch.ones(
347
+ [batch_size, current_num_frames],
348
+ device=noise.device,
349
+ dtype=torch.int64) * current_timestep
350
+
351
+ if not exit_flag:
352
+ with torch.no_grad():
353
+ _, denoised_pred = self.generator(
354
+ noisy_image_or_video=noisy_input,
355
+ conditional_dict=conditional_dict,
356
+ timestep=timestep,
357
+ kv_cache=self.kv_cache_clean,
358
+ crossattn_cache=self.crossattn_cache,
359
+ current_start=current_start_frame * self.frame_seq_length
360
+ )
361
+ next_timestep = self.denoising_step_list[index + 1]
362
+ noisy_input = self.scheduler.add_noise(
363
+ denoised_pred.flatten(0, 1),
364
+ torch.randn_like(denoised_pred.flatten(0, 1)),
365
+ next_timestep * torch.ones(
366
+ [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
367
+ ).unflatten(0, denoised_pred.shape[:2])
368
+ else:
369
+ # for getting real output
370
+ # with torch.set_grad_enabled(current_start_frame >= start_gradient_frame_index):
371
+ if current_start_frame < start_gradient_frame_index:
372
+ with torch.no_grad():
373
+ _, denoised_pred = self.generator(
374
+ noisy_image_or_video=noisy_input,
375
+ conditional_dict=conditional_dict,
376
+ timestep=timestep,
377
+ kv_cache=self.kv_cache_clean,
378
+ crossattn_cache=self.crossattn_cache,
379
+ current_start=current_start_frame * self.frame_seq_length
380
+ )
381
+ else:
382
+ _, denoised_pred = self.generator(
383
+ noisy_image_or_video=noisy_input,
384
+ conditional_dict=conditional_dict,
385
+ timestep=timestep,
386
+ kv_cache=self.kv_cache_clean,
387
+ crossattn_cache=self.crossattn_cache,
388
+ current_start=current_start_frame * self.frame_seq_length
389
+ )
390
+ break
391
+
392
+ # Step 3.2: record the model's output
393
+ output[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
394
+
395
+ # Step 3.3: rerun with timestep zero to update the cache
396
+ context_timestep = torch.ones_like(timestep) * self.context_noise
397
+ # add context noise
398
+ denoised_pred = self.scheduler.add_noise(
399
+ denoised_pred.flatten(0, 1),
400
+ torch.randn_like(denoised_pred.flatten(0, 1)),
401
+ context_timestep * torch.ones(
402
+ [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
403
+ ).unflatten(0, denoised_pred.shape[:2])
404
+ with torch.no_grad():
405
+ self.generator(
406
+ noisy_image_or_video=denoised_pred,
407
+ conditional_dict=conditional_dict,
408
+ timestep=context_timestep,
409
+ kv_cache=self.kv_cache_clean,
410
+ crossattn_cache=self.crossattn_cache,
411
+ current_start=current_start_frame * self.frame_seq_length,
412
+ updating_cache=True,
413
+ )
414
+
415
+ # Step 3.4: update the start and end frame indices
416
+ current_start_frame += current_num_frames
417
+
418
+ # Step 3.5: Return the denoised timestep
419
+ if not self.same_step_across_blocks:
420
+ denoised_timestep_from, denoised_timestep_to = None, None
421
+ elif exit_flags[0] == len(self.denoising_step_list) - 1:
422
+ denoised_timestep_to = 0
423
+ denoised_timestep_from = 1000 - torch.argmin(
424
+ (self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()
425
+ else:
426
+ denoised_timestep_to = 1000 - torch.argmin(
427
+ (self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0] + 1].cuda()).abs(), dim=0).item()
428
+ denoised_timestep_from = 1000 - torch.argmin(
429
+ (self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()
430
+
431
+ if return_sim_step:
432
+ return output, denoised_timestep_from, denoised_timestep_to, exit_flags[0] + 1
433
+
434
+ return output, denoised_timestep_from, denoised_timestep_to
435
+
436
+ def _initialize_kv_cache(self, batch_size, dtype, device):
437
+ """
438
+ Initialize a Per-GPU KV cache for the Wan model.
439
+ """
440
+ kv_cache_clean = []
441
+
442
+ for _ in range(self.num_transformer_blocks):
443
+ kv_cache_clean.append({
444
+ "k": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
445
+ "v": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
446
+ "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
447
+ "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
448
+ })
449
+
450
+ self.kv_cache_clean = kv_cache_clean # always store the clean cache
451
+
452
+ def _initialize_crossattn_cache(self, batch_size, dtype, device):
453
+ """
454
+ Initialize a Per-GPU cross-attention cache for the Wan model.
455
+ """
456
+ crossattn_cache = []
457
+
458
+ for _ in range(self.num_transformer_blocks):
459
+ crossattn_cache.append({
460
+ "k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
461
+ "v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
462
+ "is_init": False
463
+ })
464
+ self.crossattn_cache = crossattn_cache
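The block-wise denoising loop above walks each block through self.denoising_step_list: at every non-final step the current prediction is pushed to the next noise level with scheduler.add_noise, flattening the batch and frame dimensions so one timestep is applied per frame, then restoring the original shape. A minimal sketch of that flatten / re-noise / unflatten pattern is shown below; the linear toy_add_noise is an assumption for illustration and may not match the scheduler actually used by the repository.

    import torch

    def toy_add_noise(x0, noise, timesteps, num_train_timesteps=1000):
        # Assumed flow-matching style interpolation; the repository's scheduler may differ.
        sigma = (timesteps.float() / num_train_timesteps).view(-1, 1, 1, 1)
        return (1.0 - sigma) * x0 + sigma * noise

    batch, frames, ch, h, w = 2, 3, 16, 60, 104
    denoised_pred = torch.randn(batch, frames, ch, h, w)
    next_timestep = 400  # hypothetical entry of denoising_step_list

    flat = denoised_pred.flatten(0, 1)                               # [batch * frames, ch, h, w]
    t = torch.full([batch * frames], next_timestep, dtype=torch.long)
    noisy_input = toy_add_noise(flat, torch.randn_like(flat), t)
    noisy_input = noisy_input.unflatten(0, denoised_pred.shape[:2])  # back to [batch, frames, ch, h, w]
    print(noisy_input.shape)  # torch.Size([2, 3, 16, 60, 104])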
prompts/example_prompts.txt ADDED
@@ -0,0 +1,16 @@
1
+ A cinematic scene from a classic western movie, featuring a rugged man riding a powerful horse through the vast Gobi Desert at sunset. The man, dressed in a dusty cowboy hat and a worn leather jacket, reins tightly on the horse's neck as he gallops across the golden sands. The sun sets dramatically behind them, casting long shadows and warm hues across the landscape. The background is filled with rolling dunes and sparse, rocky outcrops, emphasizing the harsh beauty of the desert. A dynamic wide shot from a low angle, capturing both the man and the expansive desert vista.
2
+ A classic black-and-white photograph style image of an older man playing the piano. The man, with a weathered face and kind eyes, sits at an antique piano with his fingers gracefully moving over the keys. The lighting comes from the side, casting dramatic shadows on his face and emphasizing the texture of his hands. His posture is upright and focused, conveying a sense of deep concentration and passion for music. The background is blurred, revealing only hints of a cozy room with wooden floors and old furniture. A close-up shot from a slightly elevated angle, capturing both the man and the piano in detail.
3
+ A dramatic post-apocalyptic scene in the style of a horror film, featuring a skeleton wearing a colorful flower hat and oversized sunglasses dancing wildly in a sunlit meadow at sunset. The skeleton has a weathered and somewhat decayed appearance, with bones visible through tattered remnants of clothing. The dance is energetic and almost comical, with exaggerated movements. The background is a vivid blend of warm oranges and pinks, with tall grasses and wildflowers swaying in the breeze. The sky is painted with rich hues of orange and pink, casting long shadows across the landscape. A dynamic medium shot from a low angle, capturing the skeleton's animated dance.
4
+ A dynamic action scene in a modern gym, featuring a kangaroo wearing boxing gloves, engaged in an intense sparring session with a punching bag. The kangaroo has a muscular build and is positioned mid-punch, its front legs wrapped in red boxing gloves, eyes focused intently on the target. The background showcases a cluttered gym with heavy equipment and mats, creating a vivid and realistic setting. The kangaroo's movements are fluid and powerful, conveying both agility and strength. The scene captures a split-second moment of mid-action, with the kangaroo's tail swaying behind it. A high-angle shot emphasizing the kangaroo's dynamic pose and the surrounding gym environment.
5
+ A dynamic action shot in the style of a high-energy sports magazine spread, featuring a golden retriever sprinting with all its might after a red sports car speeding down the road. The dog's fur glistens in the sunlight, and its eyes are filled with determination and excitement. It leaps forward, its tail wagging wildly, while the car speeds away in the background, leaving a trail of dust. The background shows a busy city street with blurred cars and pedestrians, adding to the sense of urgency. The photo has a crisp, vibrant color palette and a high-resolution quality. A medium-long shot capturing the dog's full run.
6
+ A dynamic action shot in the style of a professional skateboard magazine, featuring a young male longboarder accelerating downhill. He is fully focused, his expression intense and determined, carving through tight turns with precision. His longboard glides smoothly over the pavement, creating a blur of motion. He wears a black longboard shirt, blue jeans, and white sneakers, with a backpack slung over one shoulder. His hair flows behind him as he moves, and he grips the board tightly with both hands. The background shows a scenic urban street with blurred buildings and trees, hinting at a lively cityscape. The photo captures the moment just after he exits a turn, with a slight bounce in the board and a sense of speed and agility. A medium shot with a slightly elevated camera angle.
7
+ A dynamic hip-hop dance scene in a vibrant urban style, featuring an Asian girl in a bright yellow T-shirt and white pants. She is mid-dance move, arms stretched out and feet rhythmically stepping, exuding energy and confidence. Her hair is tied up in a ponytail, and she has a mischievous smile on her face. The background shows a bustling city street with blurred reflections of tall buildings and passing cars. The scene captures the lively and energetic atmosphere of a hip-hop performance, with a slightly grainy texture. A medium shot from a low-angle perspective.
8
+ A dynamic tracking shot following a skateboarder performing a series of fluid tricks down a bustling city street. The skateboarder, wearing a black helmet and a colorful shirt, moves with grace and confidence, executing flips, grinds, and spins. The camera captures the skateboarder's fluid movements, capturing the essence of each trick with precision. The background showcases the urban environment, with tall buildings, busy traffic, and passersby in the distance. The lighting highlights the skateboarder's movements, creating a sense of speed and energy. The overall style is reminiscent of a skateboarding documentary, emphasizing the natural and dynamic nature of the tricks.
9
+ A handheld camera captures a dog running through a park with a joyful exploration, the camera following the dog closely and bouncing and tilting with its movements. The dog bounds through the grass, tail wagging excitedly, sniffing at flowers and chasing after butterflies. Its fur glistens in the sunlight, and its eyes sparkle with enthusiasm. The park is filled with trees and colorful blooms, and the background shows a blurred path leading into the distance. The camera angle changes dynamically, providing a sense of the dog's lively energy and the vibrant environment around it.
10
+ A handheld shot following a young child running through a field of tall grass, capturing the spontaneity and playfulness of their movements. The child has curly brown hair and a mischievous smile, arms swinging freely as they sprint across the green expanse. Their small feet kick up bits of grass and dirt, creating a trail behind them. The background features a blurred landscape with rolling hills and scattered wildflowers, bathed in warm sunlight. The photo has a natural, documentary-style quality, emphasizing the dynamic motion and joy of the moment. A dynamic handheld shot from a slightly elevated angle, following the child's energetic run.
11
+ A high-speed action shot of a cheetah in its natural habitat, sprinting at full speed while chasing its prey across the savanna. The cheetah's golden fur glistens under the bright African sun, and its muscular body is stretched out in a powerful run. Its sharp eyes focus intently on the fleeing antelope, and its distinctive black tear marks streak down its face. The background is a blurred landscape with tall grass swaying in the wind, and distant acacia trees. The cheetah's tail is raised high, and its paws leave deep prints in the soft earth. A dynamic mid-shot capturing the intense moment of pursuit.
12
+ A photograph in a soft, warm lighting style, capturing a young woman with a bright smile and a playful wink. She has long curly brown hair and warm hazel eyes, with a slightly flushed cheeks from laughter. She is dressed in a casual yet stylish outfit: a floral printed sundress with a flowy skirt and a fitted top. Her hands are on her hips, giving a casual pose. The background features a blurred outdoor garden setting with blooming flowers and greenery. A medium shot from a slightly above-the-shoulder angle, emphasizing her joyful expression and the natural movement of her face.
13
+ A poignant moment captured in a realistic photographic style, showing a middle-aged man with a rugged face and slightly tousled hair, his chin quivering with emotion as he says a heartfelt goodbye to a loved one. He wears a simple grey sweater and jeans, standing on a dewy grassy field under a clear blue sky, with fluffy white clouds in the background. The camera angle is slightly from below, emphasizing his sorrowful expression and the depth of his feelings. A medium shot with a soft focus on the man's face and a blurred background.
14
+ A realistic photo of a llama wearing colorful pajamas dancing energetically on a stage under vibrant disco lighting. The llama has large floppy ears and a playful expression, moving its legs in a lively dance. It wears a red and yellow striped pajama top and matching pajama pants, with a fluffy tail swaying behind it. The stage is adorned with glittering disco balls and colorful lights, casting a lively and joyful atmosphere. The background features blurred audience members and a backdrop with disco-themed decorations. A dynamic shot capturing the llama mid-dance from a slightly elevated angle.
15
+ An adorable kangaroo, dressed in a cute green dress with polka dots, is wearing a small sun hat perched on its head. The kangaroo takes a pleasant stroll through the bustling streets of Mumbai during a vibrant and colorful festival. The background is filled with lively festival-goers in traditional Indian attire, adorned with intricate henna designs and bright jewelry. The scene is filled with colorful decorations, vendors selling various items, and people dancing and singing. The kangaroo moves gracefully, hopping along the cobblestone streets, its tail swinging behind it. The camera angle captures the kangaroo from a slight overhead perspective, highlighting its joyful expression and the festive atmosphere. A medium shot with dynamic movement.
16
+ An atmospheric and dramatic arc shot around a lone tree standing in a vast, foggy field at dawn. The early morning light filters through the mist, casting a soft, warm glow on the tree and the surrounding landscape. The tree's branches stretch out against the backdrop of a gradually lightening sky, with the shadows shifting and changing as the sun rises. The field is dotted with tall grasses and scattered wildflowers, their silhouettes softened by the fog. The overall scene has a moody, ethereal quality, emphasizing the natural movement of the fog and the subtle changes in light and shadow. A dynamic arc shot capturing the transition from night to day.
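The prompts above are stored one per line; a minimal loader like the following is enough to turn the file into a list of prompts for inference (the repository's own TextDataset in utils/dataset.py may implement this differently).

    from pathlib import Path

    def load_prompts(path="prompts/example_prompts.txt"):
        # One prompt per line; blank lines are skipped.
        lines = Path(path).read_text(encoding="utf-8").splitlines()
        return [line.strip() for line in lines if line.strip()]

    prompts = load_prompts()
    print(len(prompts), "prompts loaded")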
requirements.txt ADDED
@@ -0,0 +1,45 @@
1
+ torch==2.5.1
2
+ torchvision==0.20.1
3
+ torchaudio==2.5.1
4
+ opencv-python>=4.9.0.80
5
+ diffusers==0.31.0
6
+ transformers>=4.49.0
7
+ tokenizers>=0.20.3
8
+ accelerate>=1.1.1
9
+ tqdm
10
+ imageio
11
+ easydict
12
+ ftfy
13
+ dashscope
14
+ imageio-ffmpeg
15
+ numpy==1.24.4
16
+ wandb
17
+ omegaconf
18
+ einops
19
+ av==13.1.0
20
+ opencv-python
21
+ open_clip_torch
22
+ starlette
23
+ pycocotools
24
+ lmdb
25
+ matplotlib
26
+ sentencepiece
27
+ pydantic==2.10.6
28
+ scikit-image
29
+ huggingface_hub
30
+ dominate
31
+ nvidia-pyindex
32
+ nvidia-tensorrt
33
+ pycuda
34
+ onnx
35
+ onnxruntime
36
+ onnxscript
37
+ onnxconverter_common
38
+ flask
39
+ flask-socketio
40
+ torchao
41
+ tensorboard
42
+ ninja
43
+ packaging
44
+ --no-build-isolation
45
+ flash-attn
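Note that --no-build-isolation appears as a bare line above; some pip versions do not accept that flag inside a requirements file, so flash-attn is commonly installed in a separate "pip install flash-attn --no-build-isolation" step once the pinned torch build is already present. The optional pre-flight check below is not part of the repository; it only reports whether the key pinned packages are installed.

    from importlib.metadata import version, PackageNotFoundError

    # Hypothetical sanity check; distribution names mirror requirements.txt.
    for dist in ["torch", "torchvision", "diffusers", "transformers", "flash-attn"]:
        try:
            print(f"{dist}=={version(dist)}")
        except PackageNotFoundError:
            print(f"{dist} is not installed")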
train.py ADDED
@@ -0,0 +1,45 @@
1
+ import argparse
2
+ import os
3
+ from omegaconf import OmegaConf
4
+
5
+ from trainer import DiffusionTrainer, GANTrainer, ODETrainer, ScoreDistillationTrainer
6
+
7
+
8
+ def main():
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument("--config_path", type=str, required=True)
11
+ parser.add_argument("--no_save", action="store_true")
12
+ parser.add_argument("--no_visualize", action="store_true")
13
+ parser.add_argument("--logdir", type=str, default="", help="Path to the directory to save logs")
14
+ parser.add_argument("--wandb-save-dir", type=str, default="", help="Path to the directory to save wandb logs")
15
+ parser.add_argument("--disable-wandb", default=False, action="store_true")
16
+
17
+ args = parser.parse_args()
18
+
19
+ config = OmegaConf.load(args.config_path)
20
+ default_config = OmegaConf.load("configs/default_config.yaml")
21
+ config = OmegaConf.merge(default_config, config)
22
+ config.no_save = args.no_save
23
+ config.no_visualize = args.no_visualize
24
+
25
+ # get the filename of config_path
26
+ config_name = os.path.basename(args.config_path).split(".")[0]
27
+ config.config_name = config_name
28
+ config.logdir = args.logdir
29
+ config.wandb_save_dir = args.wandb_save_dir
30
+ config.disable_wandb = args.disable_wandb
31
+
32
+ if config.trainer == "diffusion":
33
+ trainer = DiffusionTrainer(config)
34
+ elif config.trainer == "gan":
35
+ trainer = GANTrainer(config)
36
+ elif config.trainer == "ode":
37
+ trainer = ODETrainer(config)
38
+ elif config.trainer == "score_distillation":
39
+ trainer = ScoreDistillationTrainer(config)
+ else:
+ raise ValueError(f"Unknown trainer: {config.trainer}")
40
+ trainer.train()
41
+
42
+
43
+
44
+ if __name__ == "__main__":
45
+ main()
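train.py merges the experiment config on top of configs/default_config.yaml, so any key present in both files takes its value from the experiment config. A short sketch of that precedence with OmegaConf; the config contents here are made up for illustration.

    from omegaconf import OmegaConf

    default_cfg = OmegaConf.create({"trainer": "diffusion", "lr": 1e-5, "batch_size": 1})
    exp_cfg = OmegaConf.create({"trainer": "score_distillation", "lr": 2e-6})

    config = OmegaConf.merge(default_cfg, exp_cfg)  # later configs win on conflicting keys
    print(config.trainer, config.lr, config.batch_size)  # score_distillation 2e-06 1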
trainer/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from .diffusion import Trainer as DiffusionTrainer
2
+ from .gan import Trainer as GANTrainer
3
+ from .ode import Trainer as ODETrainer
4
+ from .distillation import Trainer as ScoreDistillationTrainer
5
+
6
+ __all__ = [
7
+ "DiffusionTrainer",
8
+ "GANTrainer",
9
+ "ODETrainer",
10
+ "ScoreDistillationTrainer"
11
+ ]
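train.py selects one of these four trainers with an if/elif chain on config.trainer. An equivalent, easily extended alternative is a small registry keyed by the same strings; this is only a design sketch, not how the repository currently dispatches.

    from trainer import DiffusionTrainer, GANTrainer, ODETrainer, ScoreDistillationTrainer

    TRAINERS = {
        "diffusion": DiffusionTrainer,
        "gan": GANTrainer,
        "ode": ODETrainer,
        "score_distillation": ScoreDistillationTrainer,
    }

    def build_trainer(config):
        # Look up the trainer class by name and instantiate it with the merged config.
        try:
            return TRAINERS[config.trainer](config)
        except KeyError:
            raise ValueError(f"Unknown trainer: {config.trainer}") from None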
trainer/diffusion.py ADDED
@@ -0,0 +1,265 @@
1
+ import gc
2
+ import logging
3
+
4
+ from model import CausalDiffusion
5
+ from utils.dataset import ShardingLMDBDataset, cycle
6
+ from utils.misc import set_seed
7
+ import torch.distributed as dist
8
+ from omegaconf import OmegaConf
9
+ import torch
10
+ import wandb
11
+ import time
12
+ import os
13
+
14
+ from utils.distributed import EMA_FSDP, barrier, fsdp_wrap, fsdp_state_dict, launch_distributed_job
15
+
16
+
17
+ class Trainer:
18
+ def __init__(self, config):
19
+ self.config = config
20
+ self.step = 0
21
+
22
+ # Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
23
+ torch.backends.cuda.matmul.allow_tf32 = True
24
+ torch.backends.cudnn.allow_tf32 = True
25
+
26
+ launch_distributed_job()
27
+ global_rank = dist.get_rank()
28
+
29
+ self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
30
+ self.device = torch.cuda.current_device()
31
+ self.is_main_process = global_rank == 0
32
+ self.causal = config.causal
33
+ self.disable_wandb = config.disable_wandb
34
+
35
+ # use a random seed for the training
36
+ if config.seed == 0:
37
+ random_seed = torch.randint(0, 10000000, (1,), device=self.device)
38
+ dist.broadcast(random_seed, src=0)
39
+ config.seed = random_seed.item()
40
+
41
+ set_seed(config.seed + global_rank)
42
+
43
+ if self.is_main_process and not self.disable_wandb:
44
+ wandb.login(host=config.wandb_host, key=config.wandb_key)
45
+ wandb.init(
46
+ config=OmegaConf.to_container(config, resolve=True),
47
+ name=config.config_name,
48
+ mode="online",
49
+ entity=config.wandb_entity,
50
+ project=config.wandb_project,
51
+ dir=config.wandb_save_dir
52
+ )
53
+
54
+ self.output_path = config.logdir
55
+
56
+ # Step 2: Initialize the model and optimizer
57
+ self.model = CausalDiffusion(config, device=self.device)
58
+ self.model.generator = fsdp_wrap(
59
+ self.model.generator,
60
+ sharding_strategy=config.sharding_strategy,
61
+ mixed_precision=config.mixed_precision,
62
+ wrap_strategy=config.generator_fsdp_wrap_strategy
63
+ )
64
+
65
+ self.model.text_encoder = fsdp_wrap(
66
+ self.model.text_encoder,
67
+ sharding_strategy=config.sharding_strategy,
68
+ mixed_precision=config.mixed_precision,
69
+ wrap_strategy=config.text_encoder_fsdp_wrap_strategy
70
+ )
71
+
72
+ if not config.no_visualize or config.load_raw_video:
73
+ self.model.vae = self.model.vae.to(
74
+ device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
75
+
76
+ self.generator_optimizer = torch.optim.AdamW(
77
+ [param for param in self.model.generator.parameters()
78
+ if param.requires_grad],
79
+ lr=config.lr,
80
+ betas=(config.beta1, config.beta2),
81
+ weight_decay=config.weight_decay
82
+ )
83
+
84
+ # Step 3: Initialize the dataloader
85
+ dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
86
+ sampler = torch.utils.data.distributed.DistributedSampler(
87
+ dataset, shuffle=True, drop_last=True)
88
+ dataloader = torch.utils.data.DataLoader(
89
+ dataset,
90
+ batch_size=config.batch_size,
91
+ sampler=sampler,
92
+ num_workers=8)
93
+
94
+ if dist.get_rank() == 0:
95
+ print("DATASET SIZE %d" % len(dataset))
96
+ self.dataloader = cycle(dataloader)
97
+
98
+ ##############################################################################################################
99
+ # 6. Set up EMA parameter containers
100
+ rename_param = (
101
+ lambda name: name.replace("_fsdp_wrapped_module.", "")
102
+ .replace("_checkpoint_wrapped_module.", "")
103
+ .replace("_orig_mod.", "")
104
+ )
105
+ self.name_to_trainable_params = {}
106
+ for n, p in self.model.generator.named_parameters():
107
+ if not p.requires_grad:
108
+ continue
109
+
110
+ renamed_n = rename_param(n)
111
+ self.name_to_trainable_params[renamed_n] = p
112
+ ema_weight = config.ema_weight
113
+ self.generator_ema = None
114
+ if (ema_weight is not None) and (ema_weight > 0.0):
115
+ print(f"Setting up EMA with weight {ema_weight}")
116
+ self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
117
+
118
+ ##############################################################################################################
119
+ # 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
120
+ if getattr(config, "generator_ckpt", False):
121
+ print(f"Loading pretrained generator from {config.generator_ckpt}")
122
+ state_dict = torch.load(config.generator_ckpt, map_location="cpu")
123
+ if "generator" in state_dict:
124
+ state_dict = state_dict["generator"]
125
+ elif "model" in state_dict:
126
+ state_dict = state_dict["model"]
127
+ self.model.generator.load_state_dict(
128
+ state_dict, strict=True
129
+ )
130
+
131
+ ##############################################################################################################
132
+
133
+ # Let's delete EMA params for early steps to save some computes at training and inference
134
+ if self.step < config.ema_start_step:
135
+ self.generator_ema = None
136
+
137
+ self.max_grad_norm = 10.0
138
+ self.previous_time = None
139
+
140
+ def save(self):
141
+ print("Start gathering distributed model states...")
142
+ generator_state_dict = fsdp_state_dict(
143
+ self.model.generator)
144
+
145
+ if self.config.ema_start_step < self.step:
146
+ state_dict = {
147
+ "generator": generator_state_dict,
148
+ "generator_ema": self.generator_ema.state_dict(),
149
+ }
150
+ else:
151
+ state_dict = {
152
+ "generator": generator_state_dict,
153
+ }
154
+
155
+ if self.is_main_process:
156
+ os.makedirs(os.path.join(self.output_path,
157
+ f"checkpoint_model_{self.step:06d}"), exist_ok=True)
158
+ torch.save(state_dict, os.path.join(self.output_path,
159
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
160
+ print("Model saved to", os.path.join(self.output_path,
161
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
162
+
163
+ def train_one_step(self, batch):
164
+ self.log_iters = 1
165
+
166
+ if self.step % 20 == 0:
167
+ torch.cuda.empty_cache()
168
+
169
+ # Step 1: Get the next batch of text prompts
170
+ text_prompts = batch["prompts"]
171
+ if not self.config.load_raw_video: # precomputed latent
172
+ clean_latent = batch["ode_latent"][:, -1].to(
173
+ device=self.device, dtype=self.dtype)
174
+ else: # encode raw video to latent
175
+ frames = batch["frames"].to(
176
+ device=self.device, dtype=self.dtype)
177
+ with torch.no_grad():
178
+ clean_latent = self.model.vae.encode_to_latent(
179
+ frames).to(device=self.device, dtype=self.dtype)
180
+ image_latent = clean_latent[:, 0:1, ]
181
+
182
+ batch_size = len(text_prompts)
183
+ image_or_video_shape = list(self.config.image_or_video_shape)
184
+ image_or_video_shape[0] = batch_size
185
+
186
+ # Step 2: Extract the conditional infos
187
+ with torch.no_grad():
188
+ conditional_dict = self.model.text_encoder(
189
+ text_prompts=text_prompts)
190
+
191
+ if not getattr(self, "unconditional_dict", None):
192
+ unconditional_dict = self.model.text_encoder(
193
+ text_prompts=[self.config.negative_prompt] * batch_size)
194
+ unconditional_dict = {k: v.detach()
195
+ for k, v in unconditional_dict.items()}
196
+ self.unconditional_dict = unconditional_dict # cache the unconditional_dict
197
+ else:
198
+ unconditional_dict = self.unconditional_dict
199
+
200
+ # Step 3: Train the generator
201
+ generator_loss, log_dict = self.model.generator_loss(
202
+ image_or_video_shape=image_or_video_shape,
203
+ conditional_dict=conditional_dict,
204
+ unconditional_dict=unconditional_dict,
205
+ clean_latent=clean_latent,
206
+ initial_latent=image_latent
207
+ )
208
+ self.generator_optimizer.zero_grad()
209
+ generator_loss.backward()
210
+ generator_grad_norm = self.model.generator.clip_grad_norm_(
211
+ self.max_grad_norm)
212
+ self.generator_optimizer.step()
213
+
214
+ # Increment the step since we finished gradient update
215
+ self.step += 1
216
+
217
+ wandb_loss_dict = {
218
+ "generator_loss": generator_loss.item(),
219
+ "generator_grad_norm": generator_grad_norm.item(),
220
+ }
221
+
222
+ # Step 4: Logging
223
+ if self.is_main_process:
224
+ if not self.disable_wandb:
225
+ wandb.log(wandb_loss_dict, step=self.step)
226
+
227
+ if self.step % self.config.gc_interval == 0:
228
+ if dist.get_rank() == 0:
229
+ logging.info("DistGarbageCollector: Running GC.")
230
+ gc.collect()
231
+
232
+ # Step 5. Create EMA params
233
+ # TODO: Implement EMA
234
+
235
+ def generate_video(self, pipeline, prompts, image=None):
236
+ batch_size = len(prompts)
237
+ sampled_noise = torch.randn(
238
+ [batch_size, 21, 16, 60, 104], device="cuda", dtype=self.dtype
239
+ )
240
+ video, _ = pipeline.inference(
241
+ noise=sampled_noise,
242
+ text_prompts=prompts,
243
+ return_latents=True
244
+ )
245
+ current_video = video.permute(0, 1, 3, 4, 2).cpu().numpy() * 255.0
246
+ return current_video
247
+
248
+ def train(self):
249
+ while True:
250
+ batch = next(self.dataloader)
251
+ self.train_one_step(batch)
252
+ if (not self.config.no_save) and self.step % self.config.log_iters == 0:
253
+ torch.cuda.empty_cache()
254
+ self.save()
255
+ torch.cuda.empty_cache()
256
+
257
+ barrier()
258
+ if self.is_main_process:
259
+ current_time = time.time()
260
+ if self.previous_time is None:
261
+ self.previous_time = current_time
262
+ else:
263
+ if not self.disable_wandb:
264
+ wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
265
+ self.previous_time = current_time
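The training loop draws batches from cycle(dataloader) so it can run indefinitely without tracking epochs. That helper lives in utils/dataset.py and is not shown in this diff; a typical implementation looks like the sketch below (the repository's version may additionally reset the DistributedSampler epoch when the loader is exhausted).

    def cycle(dataloader):
        # Yield batches forever by restarting the dataloader when it is exhausted.
        while True:
            for batch in dataloader:
                yield batch

    # usage inside the trainer:
    #   self.dataloader = cycle(dataloader)
    #   batch = next(self.dataloader)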
trainer/distillation.py ADDED
@@ -0,0 +1,398 @@
1
+ import gc
2
+ import logging
3
+
4
+ from utils.dataset import ShardingLMDBDataset, cycle
5
+ from utils.dataset import TextDataset
6
+ from utils.distributed import EMA_FSDP, fsdp_wrap, fsdp_state_dict, launch_distributed_job
7
+ from utils.misc import (
8
+ set_seed,
9
+ merge_dict_list
10
+ )
11
+ import torch.distributed as dist
12
+ from omegaconf import OmegaConf
13
+ from model import CausVid, DMD, SiD
14
+ import torch
15
+ from torch.utils.tensorboard import SummaryWriter
16
+ import time
17
+ import os
18
+
19
+
20
+ class Trainer:
21
+ def __init__(self, config):
22
+ self.config = config
23
+ self.step = 0
24
+
25
+ # Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
26
+ torch.backends.cuda.matmul.allow_tf32 = True
27
+ torch.backends.cudnn.allow_tf32 = True
28
+
29
+ launch_distributed_job()
30
+ global_rank = dist.get_rank()
31
+ self.world_size = dist.get_world_size()
32
+
33
+ self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
34
+ self.device = torch.cuda.current_device()
35
+ self.is_main_process = global_rank == 0
36
+ self.causal = config.causal
37
+
38
+ # use a random seed for the training
39
+ if config.seed == 0:
40
+ random_seed = torch.randint(0, 10000000, (1,), device=self.device)
41
+ dist.broadcast(random_seed, src=0)
42
+ config.seed = random_seed.item()
43
+
44
+ set_seed(config.seed + global_rank)
45
+
46
+ if self.is_main_process:
47
+ self.writer = SummaryWriter(
48
+ log_dir=os.path.join(config.logdir, "tensorboard"),
49
+ flush_secs=10
50
+ )
51
+
52
+ self.output_path = config.logdir
53
+
54
+ # Step 2: Initialize the model and optimizer
55
+ if config.distribution_loss == "causvid":
56
+ self.model = CausVid(config, device=self.device)
57
+ elif config.distribution_loss == "dmd":
58
+ self.model = DMD(config, device=self.device)
59
+ elif config.distribution_loss == "sid":
60
+ self.model = SiD(config, device=self.device)
61
+ else:
62
+ raise ValueError("Invalid distribution matching loss")
63
+
64
+ # Save pretrained model state_dicts to CPU
65
+ self.fake_score_state_dict_cpu = self.model.fake_score.state_dict()
66
+
67
+ self.model.generator = fsdp_wrap(
68
+ self.model.generator,
69
+ sharding_strategy=config.sharding_strategy,
70
+ mixed_precision=config.mixed_precision,
71
+ wrap_strategy=config.generator_fsdp_wrap_strategy
72
+ )
73
+
74
+ self.model.real_score = fsdp_wrap(
75
+ self.model.real_score,
76
+ sharding_strategy=config.sharding_strategy,
77
+ mixed_precision=config.mixed_precision,
78
+ wrap_strategy=config.real_score_fsdp_wrap_strategy
79
+ )
80
+
81
+ self.model.fake_score = fsdp_wrap(
82
+ self.model.fake_score,
83
+ sharding_strategy=config.sharding_strategy,
84
+ mixed_precision=config.mixed_precision,
85
+ wrap_strategy=config.fake_score_fsdp_wrap_strategy
86
+ )
87
+
88
+ self.model.text_encoder = fsdp_wrap(
89
+ self.model.text_encoder,
90
+ sharding_strategy=config.sharding_strategy,
91
+ mixed_precision=config.mixed_precision,
92
+ wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
93
+ cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
94
+ )
95
+
96
+ if not config.no_visualize or config.load_raw_video:
97
+ self.model.vae = self.model.vae.to(
98
+ device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
99
+
100
+ self.generator_optimizer = torch.optim.AdamW(
101
+ [param for param in self.model.generator.parameters()
102
+ if param.requires_grad],
103
+ lr=config.lr,
104
+ betas=(config.beta1, config.beta2),
105
+ weight_decay=config.weight_decay
106
+ )
107
+
108
+ self.critic_optimizer = torch.optim.AdamW(
109
+ [param for param in self.model.fake_score.parameters()
110
+ if param.requires_grad],
111
+ lr=config.lr_critic if hasattr(config, "lr_critic") else config.lr,
112
+ betas=(config.beta1_critic, config.beta2_critic),
113
+ weight_decay=config.weight_decay
114
+ )
115
+
116
+ # Step 3: Initialize the dataloader
117
+ if self.config.i2v:
118
+ dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
119
+ else:
120
+ dataset = TextDataset(config.data_path)
121
+ sampler = torch.utils.data.distributed.DistributedSampler(
122
+ dataset, shuffle=True, drop_last=True)
123
+ dataloader = torch.utils.data.DataLoader(
124
+ dataset,
125
+ batch_size=config.batch_size,
126
+ sampler=sampler,
127
+ num_workers=8)
128
+
129
+ if dist.get_rank() == 0:
130
+ print("DATASET SIZE %d" % len(dataset))
131
+ self.dataloader = cycle(dataloader)
132
+
133
+ ##############################################################################################################
134
+ # 6. Set up EMA parameter containers
135
+ rename_param = (
136
+ lambda name: name.replace("_fsdp_wrapped_module.", "")
137
+ .replace("_checkpoint_wrapped_module.", "")
138
+ .replace("_orig_mod.", "")
139
+ )
140
+ self.name_to_trainable_params = {}
141
+ for n, p in self.model.generator.named_parameters():
142
+ if not p.requires_grad:
143
+ continue
144
+
145
+ renamed_n = rename_param(n)
146
+ self.name_to_trainable_params[renamed_n] = p
147
+ ema_weight = config.ema_weight
148
+ self.generator_ema = None
149
+ if (ema_weight is not None) and (ema_weight > 0.0):
150
+ print(f"Setting up EMA with weight {ema_weight}")
151
+ self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
152
+
153
+ ##############################################################################################################
154
+ # 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
155
+ if getattr(config, "generator_ckpt", False):
156
+ print(f"Loading pretrained generator from {config.generator_ckpt}")
157
+ state_dict = torch.load(config.generator_ckpt, map_location="cpu")
158
+ if "generator" in state_dict:
159
+ state_dict = state_dict["generator"]
160
+ elif "model" in state_dict:
161
+ state_dict = state_dict["model"]
162
+ self.model.generator.load_state_dict(
163
+ state_dict, strict=True
164
+ )
165
+
166
+ ##############################################################################################################
167
+
168
+ # Let's delete EMA params for early steps to save some computes at training and inference
169
+ if self.step < config.ema_start_step:
170
+ self.generator_ema = None
171
+
172
+ self.max_grad_norm_generator = getattr(config, "max_grad_norm_generator", 10.0)
173
+ self.max_grad_norm_critic = getattr(config, "max_grad_norm_critic", 10.0)
174
+ self.previous_time = None
175
+
176
+ def save(self):
177
+ print("Start gathering distributed model states...")
178
+ generator_state_dict = fsdp_state_dict(
179
+ self.model.generator)
180
+ critic_state_dict = fsdp_state_dict(
181
+ self.model.fake_score)
182
+
183
+ if self.config.ema_start_step < self.step:
184
+ state_dict = {
185
+ "generator": generator_state_dict,
186
+ "critic": critic_state_dict,
187
+ "generator_ema": self.generator_ema.state_dict(),
188
+ }
189
+ else:
190
+ state_dict = {
191
+ "generator": generator_state_dict,
192
+ "critic": critic_state_dict,
193
+ }
194
+
195
+ if self.is_main_process:
196
+ os.makedirs(os.path.join(self.output_path,
197
+ f"checkpoint_model_{self.step:06d}"), exist_ok=True)
198
+ torch.save(state_dict, os.path.join(self.output_path,
199
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
200
+ print("Model saved to", os.path.join(self.output_path,
201
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
202
+
203
+ def fwdbwd_one_step(self, batch, train_generator):
204
+ self.model.eval() # prevent any randomness (e.g. dropout)
205
+
206
+ if self.step % 20 == 0:
207
+ torch.cuda.empty_cache()
208
+
209
+ # Step 1: Get the next batch of text prompts
210
+ text_prompts = batch["prompts"]
211
+ if self.config.i2v:
212
+ clean_latent = None
213
+ image_latent = batch["ode_latent"][:, -1][:, 0:1, ].to(
214
+ device=self.device, dtype=self.dtype)
215
+ else:
216
+ clean_latent = None
217
+ image_latent = None
218
+
219
+ batch_size = len(text_prompts)
220
+ image_or_video_shape = list(self.config.image_or_video_shape)
221
+ image_or_video_shape[0] = batch_size
222
+
223
+ # Step 2: Extract the conditional infos
224
+ with torch.no_grad():
225
+ conditional_dict = self.model.text_encoder(
226
+ text_prompts=text_prompts)
227
+
228
+ if not getattr(self, "unconditional_dict", None):
229
+ unconditional_dict = self.model.text_encoder(
230
+ text_prompts=[self.config.negative_prompt] * batch_size)
231
+ unconditional_dict = {k: v.detach()
232
+ for k, v in unconditional_dict.items()}
233
+ self.unconditional_dict = unconditional_dict # cache the unconditional_dict
234
+ else:
235
+ unconditional_dict = self.unconditional_dict
236
+
237
+ # Step 3: Store gradients for the generator (if training the generator)
238
+ if train_generator:
239
+ generator_loss, generator_log_dict = self.model.generator_loss(
240
+ image_or_video_shape=image_or_video_shape,
241
+ conditional_dict=conditional_dict,
242
+ unconditional_dict=unconditional_dict,
243
+ clean_latent=clean_latent,
244
+ initial_latent=image_latent if self.config.i2v else None
245
+ )
246
+
247
+ generator_loss.backward()
248
+ generator_grad_norm = self.model.generator.clip_grad_norm_(
249
+ self.max_grad_norm_generator)
250
+
251
+ generator_log_dict.update({"generator_loss": generator_loss,
252
+ "generator_grad_norm": generator_grad_norm})
253
+
254
+ return generator_log_dict
255
+ else:
256
+ generator_log_dict = {}
257
+
258
+ # Step 4: Store gradients for the critic (if training the critic)
259
+ critic_loss, critic_log_dict = self.model.critic_loss(
260
+ image_or_video_shape=image_or_video_shape,
261
+ conditional_dict=conditional_dict,
262
+ unconditional_dict=unconditional_dict,
263
+ clean_latent=clean_latent,
264
+ initial_latent=image_latent if self.config.i2v else None
265
+ )
266
+
267
+ critic_loss.backward()
268
+ critic_grad_norm = self.model.fake_score.clip_grad_norm_(
269
+ self.max_grad_norm_critic)
270
+
271
+ critic_log_dict.update({"critic_loss": critic_loss,
272
+ "critic_grad_norm": critic_grad_norm})
273
+
274
+ return critic_log_dict
275
+
276
+ def generate_video(self, pipeline, prompts, image=None):
277
+ batch_size = len(prompts)
278
+ if image is not None:
279
+ image = image.squeeze(0).unsqueeze(0).unsqueeze(2).to(device="cuda", dtype=torch.bfloat16)
280
+
281
+ # Encode the input image as the first latent
282
+ initial_latent = pipeline.vae.encode_to_latent(image).to(device="cuda", dtype=torch.bfloat16)
283
+ initial_latent = initial_latent.repeat(batch_size, 1, 1, 1, 1)
284
+ sampled_noise = torch.randn(
285
+ [batch_size, self.model.num_training_frames - 1, 16, 60, 104],
286
+ device="cuda",
287
+ dtype=self.dtype
288
+ )
289
+ else:
290
+ initial_latent = None
291
+ sampled_noise = torch.randn(
292
+ [batch_size, self.model.num_training_frames, 16, 60, 104],
293
+ device="cuda",
294
+ dtype=self.dtype
295
+ )
296
+
297
+ video, _ = pipeline.inference(
298
+ noise=sampled_noise,
299
+ text_prompts=prompts,
300
+ return_latents=True,
301
+ initial_latent=initial_latent
302
+ )
303
+ current_video = video.permute(0, 1, 3, 4, 2).cpu().numpy() * 255.0
304
+ return current_video
305
+
306
+ def train(self):
307
+ start_step = self.step
308
+
309
+ while True:
310
+ TRAIN_GENERATOR = self.step % self.config.dfake_gen_update_ratio == 0
311
+
312
+ # Train the generator
313
+ if TRAIN_GENERATOR:
314
+ self.generator_optimizer.zero_grad(set_to_none=True)
315
+ extras_list = []
316
+ batch = next(self.dataloader)
317
+ extra = self.fwdbwd_one_step(batch, True)
318
+ extras_list.append(extra)
319
+ generator_log_dict = merge_dict_list(extras_list)
320
+ self.generator_optimizer.step()
321
+ if self.generator_ema is not None:
322
+ self.generator_ema.update(self.model.generator)
323
+
324
+ # Train the critic
325
+ self.critic_optimizer.zero_grad(set_to_none=True)
326
+ extras_list = []
327
+ batch = next(self.dataloader)
328
+ extra = self.fwdbwd_one_step(batch, False)
329
+ extras_list.append(extra)
330
+ critic_log_dict = merge_dict_list(extras_list)
331
+ self.critic_optimizer.step()
332
+
333
+ # Increment the step since we finished gradient update
334
+ self.step += 1
335
+
336
+ # Create EMA params (if not already created)
337
+ if (self.step >= self.config.ema_start_step) and \
338
+ (self.generator_ema is None) and (self.config.ema_weight > 0):
339
+ self.generator_ema = EMA_FSDP(self.model.generator, decay=self.config.ema_weight)
340
+
341
+ # Save the model
342
+ if (not self.config.no_save) and (self.step - start_step) > 0 and self.step % self.config.log_iters == 0:
343
+ torch.cuda.empty_cache()
344
+ self.save()
345
+ torch.cuda.empty_cache()
346
+
347
+ # Logging
348
+ if self.is_main_process:
349
+
350
+ if TRAIN_GENERATOR:
351
+ self.writer.add_scalar(
352
+ "generator_loss",
353
+ generator_log_dict["generator_loss"].mean().item(),
354
+ self.step
355
+ )
356
+ self.writer.add_scalar(
357
+ "generator_grad_norm",
358
+ generator_log_dict["generator_grad_norm"].mean().item(),
359
+ self.step
360
+ )
361
+ self.writer.add_scalar(
362
+ "dmdtrain_gradient_norm",
363
+ generator_log_dict["dmdtrain_gradient_norm"].mean().item(),
364
+ self.step
365
+ )
366
+
367
+ self.writer.add_scalar(
368
+ "critic_loss",
369
+ critic_log_dict["critic_loss"].mean().item(),
370
+ self.step
371
+ )
372
+ self.writer.add_scalar(
373
+ "critic_grad_norm",
374
+ critic_log_dict["critic_grad_norm"].mean().item(),
375
+ self.step
376
+ )
377
+
378
+ if self.step % self.config.gc_interval == 0:
379
+ if dist.get_rank() == 0:
380
+ logging.info("DistGarbageCollector: Running GC.")
381
+ gc.collect()
382
+ torch.cuda.empty_cache()
383
+
384
+ if self.is_main_process:
385
+ current_time = time.time()
386
+ if self.previous_time is None:
387
+ self.previous_time = current_time
388
+ else:
389
+ self.writer.add_scalar(
390
+ "per iteration time",
391
+ current_time - self.previous_time,
392
+ self.step
393
+ )
394
+ print(
395
+ f"Step {self.step} | "
396
+ f"Iteration time: {current_time - self.previous_time:.2f} seconds | "
397
+ )
398
+ self.previous_time = current_time
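In this trainer, dfake_gen_update_ratio controls how often the generator is updated relative to the critic: the generator step runs only when the current step is a multiple of the ratio, while the critic is updated every iteration. A stripped-down sketch of that schedule, with the actual model and optimizer calls stubbed out:

    def training_schedule(num_steps, dfake_gen_update_ratio=5):
        # Yields which networks would be updated at each step.
        for step in range(num_steps):
            train_generator = step % dfake_gen_update_ratio == 0
            yield step, "generator+critic" if train_generator else "critic"

    print([phase for _, phase in training_schedule(10)])
    # ['generator+critic', 'critic', 'critic', 'critic', 'critic', 'generator+critic', ...]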
trainer/gan.py ADDED
@@ -0,0 +1,464 @@
1
+ import gc
2
+ import logging
3
+
4
+ from utils.dataset import ShardingLMDBDataset, cycle
5
+ from utils.distributed import EMA_FSDP, fsdp_wrap, fsdp_state_dict, launch_distributed_job
6
+ from utils.misc import (
7
+ set_seed,
8
+ merge_dict_list
9
+ )
10
+ import torch.distributed as dist
11
+ from omegaconf import OmegaConf
12
+ from model import GAN
13
+ import torch
14
+ import wandb
15
+ import time
16
+ import os
17
+
18
+
19
+ class Trainer:
20
+ def __init__(self, config):
21
+ self.config = config
22
+ self.step = 0
23
+
24
+ # Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
25
+ torch.backends.cuda.matmul.allow_tf32 = True
26
+ torch.backends.cudnn.allow_tf32 = True
27
+
28
+ launch_distributed_job()
29
+ global_rank = dist.get_rank()
30
+ self.world_size = dist.get_world_size()
31
+
32
+ self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
33
+ self.device = torch.cuda.current_device()
34
+ self.is_main_process = global_rank == 0
35
+ self.causal = config.causal
36
+ self.disable_wandb = config.disable_wandb
37
+
38
+ # Configuration for discriminator warmup
39
+ self.discriminator_warmup_steps = getattr(config, "discriminator_warmup_steps", 0)
40
+ self.in_discriminator_warmup = self.step < self.discriminator_warmup_steps
41
+ if self.in_discriminator_warmup and self.is_main_process:
42
+ print(f"Starting with discriminator warmup for {self.discriminator_warmup_steps} steps")
43
+ self.loss_scale = getattr(config, "loss_scale", 1.0)
44
+
45
+ # use a random seed for the training
46
+ if config.seed == 0:
47
+ random_seed = torch.randint(0, 10000000, (1,), device=self.device)
48
+ dist.broadcast(random_seed, src=0)
49
+ config.seed = random_seed.item()
50
+
51
+ set_seed(config.seed + global_rank)
52
+
53
+ if self.is_main_process and not self.disable_wandb:
54
+ wandb.login(host=config.wandb_host, key=config.wandb_key)
55
+ wandb.init(
56
+ config=OmegaConf.to_container(config, resolve=True),
57
+ name=config.config_name,
58
+ mode="online",
59
+ entity=config.wandb_entity,
60
+ project=config.wandb_project,
61
+ dir=config.wandb_save_dir
62
+ )
63
+
64
+ self.output_path = config.logdir
65
+
66
+ # Step 2: Initialize the model and optimizer
67
+ self.model = GAN(config, device=self.device)
68
+
69
+ self.model.generator = fsdp_wrap(
70
+ self.model.generator,
71
+ sharding_strategy=config.sharding_strategy,
72
+ mixed_precision=config.mixed_precision,
73
+ wrap_strategy=config.generator_fsdp_wrap_strategy
74
+ )
75
+
76
+ self.model.fake_score = fsdp_wrap(
77
+ self.model.fake_score,
78
+ sharding_strategy=config.sharding_strategy,
79
+ mixed_precision=config.mixed_precision,
80
+ wrap_strategy=config.fake_score_fsdp_wrap_strategy
81
+ )
82
+
83
+ self.model.text_encoder = fsdp_wrap(
84
+ self.model.text_encoder,
85
+ sharding_strategy=config.sharding_strategy,
86
+ mixed_precision=config.mixed_precision,
87
+ wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
88
+ cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
89
+ )
90
+
91
+ if not config.no_visualize or config.load_raw_video:
92
+ self.model.vae = self.model.vae.to(
93
+ device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
94
+
95
+ self.generator_optimizer = torch.optim.AdamW(
96
+ [param for param in self.model.generator.parameters()
97
+ if param.requires_grad],
98
+ lr=config.gen_lr,
99
+ betas=(config.beta1, config.beta2)
100
+ )
101
+
102
+ # Create separate parameter groups for the fake_score network
103
+ # One group for parameters with "_cls_pred_branch" or "_gan_ca_blocks" in the name
104
+ # and another group for all other parameters
105
+ fake_score_params = []
106
+ discriminator_params = []
107
+
108
+ for name, param in self.model.fake_score.named_parameters():
109
+ if param.requires_grad:
110
+ if "_cls_pred_branch" in name or "_gan_ca_blocks" in name:
111
+ discriminator_params.append(param)
112
+ else:
113
+ fake_score_params.append(param)
114
+
115
+ # Use the special learning rate for the special parameter group
116
+ # and the default critic learning rate for other parameters
117
+ self.critic_param_groups = [
118
+ {'params': fake_score_params, 'lr': config.critic_lr},
119
+ {'params': discriminator_params, 'lr': config.critic_lr * config.discriminator_lr_multiplier}
120
+ ]
121
+ if self.in_discriminator_warmup:
122
+ self.critic_optimizer = torch.optim.AdamW(
123
+ self.critic_param_groups,
124
+ betas=(0.9, config.beta2_critic)
125
+ )
126
+ else:
127
+ self.critic_optimizer = torch.optim.AdamW(
128
+ self.critic_param_groups,
129
+ betas=(config.beta1_critic, config.beta2_critic)
130
+ )
131
+
132
+ # Step 3: Initialize the dataloader
133
+ self.data_path = config.data_path
134
+ dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
135
+ sampler = torch.utils.data.distributed.DistributedSampler(
136
+ dataset, shuffle=True, drop_last=True)
137
+ dataloader = torch.utils.data.DataLoader(
138
+ dataset,
139
+ batch_size=config.batch_size,
140
+ sampler=sampler,
141
+ num_workers=8)
142
+
143
+ if dist.get_rank() == 0:
144
+ print("DATASET SIZE %d" % len(dataset))
145
+
146
+ self.dataloader = cycle(dataloader)
147
+
148
+ ##############################################################################################################
149
+ # 6. Set up EMA parameter containers
150
+ rename_param = (
151
+ lambda name: name.replace("_fsdp_wrapped_module.", "")
152
+ .replace("_checkpoint_wrapped_module.", "")
153
+ .replace("_orig_mod.", "")
154
+ )
155
+ self.name_to_trainable_params = {}
156
+ for n, p in self.model.generator.named_parameters():
157
+ if not p.requires_grad:
158
+ continue
159
+
160
+ renamed_n = rename_param(n)
161
+ self.name_to_trainable_params[renamed_n] = p
162
+ ema_weight = config.ema_weight
163
+ self.generator_ema = None
164
+ if (ema_weight is not None) and (ema_weight > 0.0):
165
+ print(f"Setting up EMA with weight {ema_weight}")
166
+ self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
167
+
168
+ ##############################################################################################################
169
+ # 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
170
+ if getattr(config, "generator_ckpt", False):
171
+ print(f"Loading pretrained generator from {config.generator_ckpt}")
172
+ state_dict = torch.load(config.generator_ckpt, map_location="cpu")
173
+ if "generator" in state_dict:
174
+ state_dict = state_dict["generator"]
175
+ elif "model" in state_dict:
176
+ state_dict = state_dict["model"]
177
+ self.model.generator.load_state_dict(
178
+ state_dict, strict=True
179
+ )
180
+ if hasattr(config, "load"):
181
+ resume_ckpt_path_critic = os.path.join(config.load, "critic")
182
+ resume_ckpt_path_generator = os.path.join(config.load, "generator")
183
+ else:
184
+ resume_ckpt_path_critic = "none"
185
+ resume_ckpt_path_generator = "none"
186
+
187
+ _, _ = self.checkpointer_critic.try_best_load(
188
+ resume_ckpt_path=resume_ckpt_path_critic,
189
+ )
190
+ self.step, _ = self.checkpointer_generator.try_best_load(
191
+ resume_ckpt_path=resume_ckpt_path_generator,
192
+ force_start_w_ema=config.force_start_w_ema,
193
+ force_reset_zero_step=config.force_reset_zero_step,
194
+ force_reinit_ema=config.force_reinit_ema,
195
+ skip_optimizer_scheduler=config.skip_optimizer_scheduler,
196
+ )
197
+
198
+ ##############################################################################################################
199
+
200
+ # Let's delete EMA params for early steps to save some computes at training and inference
201
+ if self.step < config.ema_start_step:
202
+ self.generator_ema = None
203
+
204
+ self.max_grad_norm_generator = getattr(config, "max_grad_norm_generator", 10.0)
205
+ self.max_grad_norm_critic = getattr(config, "max_grad_norm_critic", 10.0)
206
+ self.previous_time = None
207
+
208
+ def save(self):
209
+ print("Start gathering distributed model states...")
210
+ generator_state_dict = fsdp_state_dict(
211
+ self.model.generator)
212
+ critic_state_dict = fsdp_state_dict(
213
+ self.model.fake_score)
214
+
215
+ if self.config.ema_start_step < self.step:
216
+ state_dict = {
217
+ "generator": generator_state_dict,
218
+ "critic": critic_state_dict,
219
+ "generator_ema": self.generator_ema.state_dict(),
220
+ }
221
+ else:
222
+ state_dict = {
223
+ "generator": generator_state_dict,
224
+ "critic": critic_state_dict,
225
+ }
226
+
227
+ if self.is_main_process:
228
+ os.makedirs(os.path.join(self.output_path,
229
+ f"checkpoint_model_{self.step:06d}"), exist_ok=True)
230
+ torch.save(state_dict, os.path.join(self.output_path,
231
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
232
+ print("Model saved to", os.path.join(self.output_path,
233
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
234
+
235
+ def fwdbwd_one_step(self, batch, train_generator):
236
+ self.model.eval() # prevent any randomness (e.g. dropout)
237
+
238
+ if self.step % 20 == 0:
239
+ torch.cuda.empty_cache()
240
+
241
+ # Step 1: Get the next batch of text prompts
242
+ text_prompts = batch["prompts"]
243
+ if "ode_latent" in batch:
244
+ clean_latent = batch["ode_latent"][:, -1].to(device=self.device, dtype=self.dtype)
245
+ else:
246
+ frames = batch["frames"].to(device=self.device, dtype=self.dtype)
247
+ with torch.no_grad():
248
+ clean_latent = self.model.vae.encode_to_latent(
249
+ frames).to(device=self.device, dtype=self.dtype)
250
+
251
+ image_latent = clean_latent[:, 0:1, ]
252
+
253
+ batch_size = len(text_prompts)
254
+ image_or_video_shape = list(self.config.image_or_video_shape)
255
+ image_or_video_shape[0] = batch_size
256
+
257
+ # Step 2: Extract the conditional infos
258
+ with torch.no_grad():
259
+ conditional_dict = self.model.text_encoder(
260
+ text_prompts=text_prompts)
261
+
262
+ if not getattr(self, "unconditional_dict", None):
263
+ unconditional_dict = self.model.text_encoder(
264
+ text_prompts=[self.config.negative_prompt] * batch_size)
265
+ unconditional_dict = {k: v.detach()
266
+ for k, v in unconditional_dict.items()}
267
+ self.unconditional_dict = unconditional_dict # cache the unconditional_dict
268
+ else:
269
+ unconditional_dict = self.unconditional_dict
270
+
271
+ mini_bs, full_bs = (
272
+ batch["mini_bs"],
273
+ batch["full_bs"],
274
+ )
275
+
276
+ # Step 3: Store gradients for the generator (if training the generator)
277
+ if train_generator:
278
+ gan_G_loss = self.model.generator_loss(
279
+ image_or_video_shape=image_or_video_shape,
280
+ conditional_dict=conditional_dict,
281
+ unconditional_dict=unconditional_dict,
282
+ clean_latent=clean_latent,
283
+ initial_latent=image_latent if self.config.i2v else None
284
+ )
285
+
286
+ loss_ratio = mini_bs * self.world_size / full_bs
287
+ total_loss = gan_G_loss * loss_ratio * self.loss_scale
288
+
289
+ total_loss.backward()
290
+ generator_grad_norm = self.model.generator.clip_grad_norm_(
291
+ self.max_grad_norm_generator)
292
+
293
+ generator_log_dict = {"generator_grad_norm": generator_grad_norm,
294
+ "gan_G_loss": gan_G_loss}
295
+
296
+ return generator_log_dict
297
+ else:
298
+ generator_log_dict = {}
299
+
300
+ # Step 4: Store gradients for the critic (if training the critic)
301
+ (gan_D_loss, r1_loss, r2_loss), critic_log_dict = self.model.critic_loss(
302
+ image_or_video_shape=image_or_video_shape,
303
+ conditional_dict=conditional_dict,
304
+ unconditional_dict=unconditional_dict,
305
+ clean_latent=clean_latent,
306
+ real_image_or_video=clean_latent,
307
+ initial_latent=image_latent if self.config.i2v else None
308
+ )
309
+
310
+ loss_ratio = mini_bs * dist.get_world_size() / full_bs
311
+ total_loss = (gan_D_loss + 0.5 * (r1_loss + r2_loss)) * loss_ratio * self.loss_scale
312
+
313
+ total_loss.backward()
314
+ critic_grad_norm = self.model.fake_score.clip_grad_norm_(
315
+ self.max_grad_norm_critic)
316
+
317
+ critic_log_dict.update({"critic_grad_norm": critic_grad_norm,
318
+ "gan_D_loss": gan_D_loss,
319
+ "r1_loss": r1_loss,
320
+ "r2_loss": r2_loss})
321
+
322
+ return critic_log_dict
323
+
324
+ def generate_video(self, pipeline, prompts, image=None):
325
+ batch_size = len(prompts)
326
+ sampled_noise = torch.randn(
327
+ [batch_size, 21, 16, 60, 104], device="cuda", dtype=self.dtype
328
+ )
329
+ video, _ = pipeline.inference(
330
+ noise=sampled_noise,
331
+ text_prompts=prompts,
332
+ return_latents=True
333
+ )
334
+ current_video = video.permute(0, 1, 3, 4, 2).cpu().numpy() * 255.0
335
+ return current_video
336
+
337
+ def train(self):
338
+ start_step = self.step
339
+
340
+ while True:
341
+ if self.step == self.discriminator_warmup_steps and self.discriminator_warmup_steps != 0:
342
+ print("Resetting critic optimizer")
343
+ del self.critic_optimizer
344
+ torch.cuda.empty_cache()
345
+ # Create new optimizers
346
+ self.critic_optimizer = torch.optim.AdamW(
347
+ self.critic_param_groups,
348
+ betas=(self.config.beta1_critic, self.config.beta2_critic)
349
+ )
350
+ # Update checkpointer references
351
+ self.checkpointer_critic.optimizer = self.critic_optimizer
352
+ # Check if we're in the discriminator warmup phase
353
+ self.in_discriminator_warmup = self.step < self.discriminator_warmup_steps
354
+
355
+ # Only update generator and critic outside the warmup phase
356
+ TRAIN_GENERATOR = not self.in_discriminator_warmup and self.step % self.config.dfake_gen_update_ratio == 0
357
+
358
+ # Train the generator (only outside warmup phase)
359
+ if TRAIN_GENERATOR:
360
+ self.model.fake_score.requires_grad_(False)
361
+ self.model.generator.requires_grad_(True)
362
+ self.generator_optimizer.zero_grad(set_to_none=True)
363
+ extras_list = []
364
+ for ii, mini_batch in enumerate(self.dataloader.next()):
365
+ extra = self.fwdbwd_one_step(mini_batch, True)
366
+ extras_list.append(extra)
367
+ generator_log_dict = merge_dict_list(extras_list)
368
+ self.generator_optimizer.step()
369
+ if self.generator_ema is not None:
370
+ self.generator_ema.update(self.model.generator)
371
+ else:
372
+ generator_log_dict = {}
373
+
374
+ # Train the critic/discriminator
375
+ if self.in_discriminator_warmup:
376
+ # During warmup, only allow gradient for discriminator params
377
+ self.model.generator.requires_grad_(False)
378
+ self.model.fake_score.requires_grad_(False)
379
+
380
+ # Enable gradient only for discriminator params
381
+ for name, param in self.model.fake_score.named_parameters():
382
+ if "_cls_pred_branch" in name or "_gan_ca_blocks" in name:
383
+ param.requires_grad_(True)
384
+ else:
385
+ # Normal training mode
386
+ self.model.generator.requires_grad_(False)
387
+ self.model.fake_score.requires_grad_(True)
388
+
389
+ self.critic_optimizer.zero_grad(set_to_none=True)
390
+ extras_list = []
391
+ batch = next(self.dataloader)
392
+ extra = self.fwdbwd_one_step(batch, False)
393
+ extras_list.append(extra)
394
+ critic_log_dict = merge_dict_list(extras_list)
395
+ self.critic_optimizer.step()
396
+
397
+ # Increment the step since we finished gradient update
398
+ self.step += 1
399
+
400
+ # If we just finished warmup, print a message
401
+ if self.is_main_process and self.step == self.discriminator_warmup_steps:
402
+ print(f"Finished discriminator warmup after {self.discriminator_warmup_steps} steps")
403
+
404
+ # Create EMA params (if not already created)
405
+ if (self.step >= self.config.ema_start_step) and \
406
+ (self.generator_ema is None) and (self.config.ema_weight > 0):
407
+ self.generator_ema = EMA_FSDP(self.model.generator, decay=self.config.ema_weight)
408
+
409
+ # Save the model
410
+ if (not self.config.no_save) and (self.step - start_step) > 0 and self.step % self.config.log_iters == 0:
411
+ torch.cuda.empty_cache()
412
+ self.save()
413
+ torch.cuda.empty_cache()
414
+
415
+ # Logging
416
+ wandb_loss_dict = {
417
+ "generator_grad_norm": generator_log_dict["generator_grad_norm"],
418
+ "critic_grad_norm": critic_log_dict["critic_grad_norm"],
419
+ "real_logit": critic_log_dict["noisy_real_logit"],
420
+ "fake_logit": critic_log_dict["noisy_fake_logit"],
421
+ "r1_loss": critic_log_dict["r1_loss"],
422
+ "r2_loss": critic_log_dict["r2_loss"],
423
+ }
424
+ if TRAIN_GENERATOR:
425
+ wandb_loss_dict.update({
426
+ "generator_grad_norm": generator_log_dict["generator_grad_norm"],
427
+ })
428
+ self.all_gather_dict(wandb_loss_dict)
429
+ wandb_loss_dict["diff_logit"] = wandb_loss_dict["real_logit"] - wandb_loss_dict["fake_logit"]
430
+ wandb_loss_dict["reg_loss"] = 0.5 * (wandb_loss_dict["r1_loss"] + wandb_loss_dict["r2_loss"])
431
+
432
+ if self.is_main_process:
433
+ if self.in_discriminator_warmup:
434
+ warmup_status = f"[WARMUP {self.step}/{self.discriminator_warmup_steps}] Training only discriminator params"
435
+ print(warmup_status)
436
+ if not self.disable_wandb:
437
+ wandb_loss_dict.update({"warmup_status": 1.0})
438
+
439
+ if not self.disable_wandb:
440
+ wandb.log(wandb_loss_dict, step=self.step)
441
+
442
+ if self.step % self.config.gc_interval == 0:
443
+ if dist.get_rank() == 0:
444
+ logging.info("DistGarbageCollector: Running GC.")
445
+ gc.collect()
446
+ torch.cuda.empty_cache()
447
+
448
+ if self.is_main_process:
449
+ current_time = time.time()
450
+ if self.previous_time is None:
451
+ self.previous_time = current_time
452
+ else:
453
+ if not self.disable_wandb:
454
+ wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
455
+ self.previous_time = current_time
456
+
457
+ def all_gather_dict(self, target_dict):
458
+ for key, value in target_dict.items():
459
+ gathered_value = torch.zeros(
460
+ [self.world_size, *value.shape],
461
+ dtype=value.dtype, device=self.device)
462
+ dist.all_gather_into_tensor(gathered_value, value)
463
+ avg_value = gathered_value.mean().item()
464
+ target_dict[key] = avg_value
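
The `all_gather_dict` helper above reduces each logged scalar to a cross-rank average before it reaches wandb. Below is a minimal, self-contained sketch of the same pattern, not part of the commit; it assumes an already-initialized process group, and `average_scalar_across_ranks` is an illustrative name.

import torch
import torch.distributed as dist

def average_scalar_across_ranks(value: torch.Tensor) -> float:
    # value: a 0-dim tensor on the current device, one per rank
    world_size = dist.get_world_size()
    gathered = torch.zeros([world_size, *value.shape], dtype=value.dtype, device=value.device)
    dist.all_gather_into_tensor(gathered, value)   # every rank receives all per-rank values
    return gathered.mean().item()                  # same reduction as all_gather_dict above
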
trainer/ode.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import logging
3
+ from utils.dataset import ODERegressionLMDBDataset, cycle
4
+ from model import ODERegression
5
+ from collections import defaultdict
6
+ from utils.misc import (
7
+ set_seed
8
+ )
9
+ import torch.distributed as dist
10
+ from omegaconf import OmegaConf
11
+ import torch
12
+ import wandb
13
+ import time
14
+ import os
15
+
16
+ from utils.distributed import barrier, fsdp_wrap, fsdp_state_dict, launch_distributed_job
17
+
18
+
19
+ class Trainer:
20
+ def __init__(self, config):
21
+ self.config = config
22
+ self.step = 0
23
+
24
+ # Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
25
+ torch.backends.cuda.matmul.allow_tf32 = True
26
+ torch.backends.cudnn.allow_tf32 = True
27
+
28
+ launch_distributed_job()
29
+ global_rank = dist.get_rank()
30
+ self.world_size = dist.get_world_size()
31
+
32
+ self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
33
+ self.device = torch.cuda.current_device()
34
+ self.is_main_process = global_rank == 0
35
+ self.disable_wandb = config.disable_wandb
36
+
37
+ # use a random seed for the training
38
+ if config.seed == 0:
39
+ random_seed = torch.randint(0, 10000000, (1,), device=self.device)
40
+ dist.broadcast(random_seed, src=0)
41
+ config.seed = random_seed.item()
42
+
43
+ set_seed(config.seed + global_rank)
44
+
45
+ if self.is_main_process and not self.disable_wandb:
46
+ wandb.login(host=config.wandb_host, key=config.wandb_key)
47
+ wandb.init(
48
+ config=OmegaConf.to_container(config, resolve=True),
49
+ name=config.config_name,
50
+ mode="online",
51
+ entity=config.wandb_entity,
52
+ project=config.wandb_project,
53
+ dir=config.wandb_save_dir
54
+ )
55
+
56
+ self.output_path = config.logdir
57
+
58
+ # Step 2: Initialize the model and optimizer
59
+
60
+ assert config.distribution_loss == "ode", "Only ODE loss is supported for ODE training"
61
+ self.model = ODERegression(config, device=self.device)
62
+
63
+ self.model.generator = fsdp_wrap(
64
+ self.model.generator,
65
+ sharding_strategy=config.sharding_strategy,
66
+ mixed_precision=config.mixed_precision,
67
+ wrap_strategy=config.generator_fsdp_wrap_strategy
68
+ )
69
+ self.model.text_encoder = fsdp_wrap(
70
+ self.model.text_encoder,
71
+ sharding_strategy=config.sharding_strategy,
72
+ mixed_precision=config.mixed_precision,
73
+ wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
74
+ cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
75
+ )
76
+
77
+ if not config.no_visualize or config.load_raw_video:
78
+ self.model.vae = self.model.vae.to(
79
+ device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
80
+
81
+ self.generator_optimizer = torch.optim.AdamW(
82
+ [param for param in self.model.generator.parameters()
83
+ if param.requires_grad],
84
+ lr=config.lr,
85
+ betas=(config.beta1, config.beta2),
86
+ weight_decay=config.weight_decay
87
+ )
88
+
89
+ # Step 3: Initialize the dataloader
90
+ dataset = ODERegressionLMDBDataset(
91
+ config.data_path, max_pair=getattr(config, "max_pair", int(1e8)))
92
+ sampler = torch.utils.data.distributed.DistributedSampler(
93
+ dataset, shuffle=True, drop_last=True)
94
+ dataloader = torch.utils.data.DataLoader(
95
+ dataset, batch_size=config.batch_size, sampler=sampler, num_workers=8)
96
+ total_batch_size = getattr(config, "total_batch_size", None)
97
+ if total_batch_size is not None:
98
+ assert total_batch_size == config.batch_size * self.world_size, "Gradient accumulation is not supported for ODE training"
99
+ self.dataloader = cycle(dataloader)
100
+
101
+ self.step = 0
102
+
103
+ ##############################################################################################################
104
+ # Step 4: (If resuming) load the pretrained generator state dict
105
+ if getattr(config, "generator_ckpt", False):
106
+ print(f"Loading pretrained generator from {config.generator_ckpt}")
107
+ state_dict = torch.load(config.generator_ckpt, map_location="cpu")[
108
+ 'generator']
109
+ self.model.generator.load_state_dict(
110
+ state_dict, strict=True
111
+ )
112
+
113
+ ##############################################################################################################
114
+
115
+ self.max_grad_norm = 10.0
116
+ self.previous_time = None
117
+
118
+ def save(self):
119
+ print("Start gathering distributed model states...")
120
+ generator_state_dict = fsdp_state_dict(
121
+ self.model.generator)
122
+ state_dict = {
123
+ "generator": generator_state_dict
124
+ }
125
+
126
+ if self.is_main_process:
127
+ os.makedirs(os.path.join(self.output_path,
128
+ f"checkpoint_model_{self.step:06d}"), exist_ok=True)
129
+ torch.save(state_dict, os.path.join(self.output_path,
130
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
131
+ print("Model saved to", os.path.join(self.output_path,
132
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
133
+
134
+ def train_one_step(self):
135
+ VISUALIZE = self.step % 100 == 0
136
+ self.model.eval() # prevent any randomness (e.g. dropout)
137
+
138
+ # Step 1: Get the next batch of text prompts
139
+ batch = next(self.dataloader)
140
+ text_prompts = batch["prompts"]
141
+ ode_latent = batch["ode_latent"].to(
142
+ device=self.device, dtype=self.dtype)
143
+
144
+ # Step 2: Extract the conditional infos
145
+ with torch.no_grad():
146
+ conditional_dict = self.model.text_encoder(
147
+ text_prompts=text_prompts)
148
+
149
+ # Step 3: Train the generator
150
+ generator_loss, log_dict = self.model.generator_loss(
151
+ ode_latent=ode_latent,
152
+ conditional_dict=conditional_dict
153
+ )
154
+
155
+ unnormalized_loss = log_dict["unnormalized_loss"]
156
+ timestep = log_dict["timestep"]
157
+
158
+ if self.world_size > 1:
159
+ gathered_unnormalized_loss = torch.zeros(
160
+ [self.world_size, *unnormalized_loss.shape],
161
+ dtype=unnormalized_loss.dtype, device=self.device)
162
+ gathered_timestep = torch.zeros(
163
+ [self.world_size, *timestep.shape],
164
+ dtype=timestep.dtype, device=self.device)
165
+
166
+ dist.all_gather_into_tensor(
167
+ gathered_unnormalized_loss, unnormalized_loss)
168
+ dist.all_gather_into_tensor(gathered_timestep, timestep)
169
+ else:
170
+ gathered_unnormalized_loss = unnormalized_loss
171
+ gathered_timestep = timestep
172
+
173
+ loss_breakdown = defaultdict(list)
174
+ stats = {}
175
+
176
+ for index, t in enumerate(timestep):
177
+ loss_breakdown[str(int(t.item()) // 250 * 250)].append(
178
+ unnormalized_loss[index].item())
179
+
180
+ for key_t in loss_breakdown.keys():
181
+ stats["loss_at_time_" + key_t] = sum(loss_breakdown[key_t]) / \
182
+ len(loss_breakdown[key_t])
183
+
184
+ self.generator_optimizer.zero_grad()
185
+ generator_loss.backward()
186
+ generator_grad_norm = self.model.generator.clip_grad_norm_(
187
+ self.max_grad_norm)
188
+ self.generator_optimizer.step()
189
+
190
+ # Step 4: Visualization
191
+ if VISUALIZE and not self.config.no_visualize and not self.config.disable_wandb and self.is_main_process:
192
+ # Visualize the input, output, and ground truth
193
+ input = log_dict["input"]
194
+ output = log_dict["output"]
195
+ ground_truth = ode_latent[:, -1]
196
+
197
+ input_video = self.model.vae.decode_to_pixel(input)
198
+ output_video = self.model.vae.decode_to_pixel(output)
199
+ ground_truth_video = self.model.vae.decode_to_pixel(ground_truth)
200
+ input_video = 255.0 * (input_video.cpu().numpy() * 0.5 + 0.5)
201
+ output_video = 255.0 * (output_video.cpu().numpy() * 0.5 + 0.5)
202
+ ground_truth_video = 255.0 * (ground_truth_video.cpu().numpy() * 0.5 + 0.5)
203
+
204
+ # Visualize the input, output, and ground truth
205
+ wandb.log({
206
+ "input": wandb.Video(input_video, caption="Input", fps=16, format="mp4"),
207
+ "output": wandb.Video(output_video, caption="Output", fps=16, format="mp4"),
208
+ "ground_truth": wandb.Video(ground_truth_video, caption="Ground Truth", fps=16, format="mp4"),
209
+ }, step=self.step)
210
+
211
+ # Step 5: Logging
212
+ if self.is_main_process and not self.disable_wandb:
213
+ wandb_loss_dict = {
214
+ "generator_loss": generator_loss.item(),
215
+ "generator_grad_norm": generator_grad_norm.item(),
216
+ **stats
217
+ }
218
+ wandb.log(wandb_loss_dict, step=self.step)
219
+
220
+ if self.step % self.config.gc_interval == 0:
221
+ if dist.get_rank() == 0:
222
+ logging.info("DistGarbageCollector: Running GC.")
223
+ gc.collect()
224
+
225
+ def train(self):
226
+ while True:
227
+ self.train_one_step()
228
+ if (not self.config.no_save) and self.step % self.config.log_iters == 0:
229
+ self.save()
230
+ torch.cuda.empty_cache()
231
+
232
+ barrier()
233
+ if self.is_main_process:
234
+ current_time = time.time()
235
+ if self.previous_time is None:
236
+ self.previous_time = current_time
237
+ else:
238
+ if not self.disable_wandb:
239
+ wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
240
+ self.previous_time = current_time
241
+
242
+ self.step += 1
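
`train_one_step` above buckets the unnormalized regression loss into 250-wide timestep bins before logging the `loss_at_time_*` statistics. A small sketch of that bookkeeping in isolation (the function name is illustrative, not from the commit):

from collections import defaultdict
import torch

def bucket_losses_by_timestep(unnormalized_loss: torch.Tensor, timestep: torch.Tensor, bin_width: int = 250):
    # both inputs are 1-D tensors of equal length; returns {"loss_at_time_<bin>": mean loss}
    buckets = defaultdict(list)
    for loss_value, t in zip(unnormalized_loss.tolist(), timestep.tolist()):
        buckets[int(t) // bin_width * bin_width].append(loss_value)
    return {f"loss_at_time_{k}": sum(v) / len(v) for k, v in buckets.items()}

stats = bucket_losses_by_timestep(torch.rand(8), torch.randint(0, 1000, (8,)).float())
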
utils/dataset.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.lmdb import get_array_shape_from_lmdb, retrieve_row_from_lmdb
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ import torch
5
+ import lmdb
6
+ import json
7
+ from pathlib import Path
8
+ from PIL import Image
9
+ import os
10
+
11
+
12
+ class TextDataset(Dataset):
13
+ def __init__(self, prompt_path, extended_prompt_path=None):
14
+ with open(prompt_path, encoding="utf-8") as f:
15
+ self.prompt_list = [line.rstrip() for line in f]
16
+
17
+ if extended_prompt_path is not None:
18
+ with open(extended_prompt_path, encoding="utf-8") as f:
19
+ self.extended_prompt_list = [line.rstrip() for line in f]
20
+ assert len(self.extended_prompt_list) == len(self.prompt_list)
21
+ else:
22
+ self.extended_prompt_list = None
23
+
24
+ def __len__(self):
25
+ return len(self.prompt_list)
26
+
27
+ def __getitem__(self, idx):
28
+ batch = {
29
+ "prompts": self.prompt_list[idx],
30
+ "idx": idx,
31
+ }
32
+ if self.extended_prompt_list is not None:
33
+ batch["extended_prompts"] = self.extended_prompt_list[idx]
34
+ return batch
35
+
36
+
37
+ class ODERegressionLMDBDataset(Dataset):
38
+ def __init__(self, data_path: str, max_pair: int = int(1e8)):
39
+ self.env = lmdb.open(data_path, readonly=True,
40
+ lock=False, readahead=False, meminit=False)
41
+
42
+ self.latents_shape = get_array_shape_from_lmdb(self.env, 'latents')
43
+ self.max_pair = max_pair
44
+
45
+ def __len__(self):
46
+ return min(self.latents_shape[0], self.max_pair)
47
+
48
+ def __getitem__(self, idx):
49
+ """
50
+ Outputs:
51
+ - prompts: List of Strings
52
+ - latents: Tensor of shape (num_denoising_steps, num_frames, num_channels, height, width). It is ordered from pure noise to clean image.
53
+ """
54
+ latents = retrieve_row_from_lmdb(
55
+ self.env,
56
+ "latents", np.float16, idx, shape=self.latents_shape[1:]
57
+ )
58
+
59
+ if len(latents.shape) == 4:
60
+ latents = latents[None, ...]
61
+
62
+ prompts = retrieve_row_from_lmdb(
63
+ self.env,
64
+ "prompts", str, idx
65
+ )
66
+ return {
67
+ "prompts": prompts,
68
+ "ode_latent": torch.tensor(latents, dtype=torch.float32)
69
+ }
70
+
71
+
72
+ class ShardingLMDBDataset(Dataset):
73
+ def __init__(self, data_path: str, max_pair: int = int(1e8)):
74
+ self.envs = []
75
+ self.index = []
76
+
77
+ for fname in sorted(os.listdir(data_path)):
78
+ path = os.path.join(data_path, fname)
79
+ env = lmdb.open(path,
80
+ readonly=True,
81
+ lock=False,
82
+ readahead=False,
83
+ meminit=False)
84
+ self.envs.append(env)
85
+
86
+ self.latents_shape = [None] * len(self.envs)
87
+ for shard_id, env in enumerate(self.envs):
88
+ self.latents_shape[shard_id] = get_array_shape_from_lmdb(env, 'latents')
89
+ for local_i in range(self.latents_shape[shard_id][0]):
90
+ self.index.append((shard_id, local_i))
91
+
92
+ # print("shard_id ", shard_id, " local_i ", local_i)
93
+
94
+ self.max_pair = max_pair
95
+
96
+ def __len__(self):
97
+ return len(self.index)
98
+
99
+ def __getitem__(self, idx):
100
+ """
101
+ Outputs:
102
+ - prompts: List of Strings
103
+ - latents: Tensor of shape (num_denoising_steps, num_frames, num_channels, height, width). It is ordered from pure noise to clean image.
104
+ """
105
+ shard_id, local_idx = self.index[idx]
106
+
107
+ latents = retrieve_row_from_lmdb(
108
+ self.envs[shard_id],
109
+ "latents", np.float16, local_idx,
110
+ shape=self.latents_shape[shard_id][1:]
111
+ )
112
+
113
+ if len(latents.shape) == 4:
114
+ latents = latents[None, ...]
115
+
116
+ prompts = retrieve_row_from_lmdb(
117
+ self.envs[shard_id],
118
+ "prompts", str, local_idx
119
+ )
120
+
121
+ return {
122
+ "prompts": prompts,
123
+ "ode_latent": torch.tensor(latents, dtype=torch.float32)
124
+ }
125
+
126
+
127
+ class TextImagePairDataset(Dataset):
128
+ def __init__(
129
+ self,
130
+ data_dir,
131
+ transform=None,
132
+ eval_first_n=-1,
133
+ pad_to_multiple_of=None
134
+ ):
135
+ """
136
+ Args:
137
+ data_dir (str): Path to the directory containing:
138
+ - target_crop_info_*.json (metadata file)
139
+ - */ (subdirectory containing images with matching aspect ratio)
140
+ transform (callable, optional): Optional transform to be applied on the image
141
+ """
142
+ self.transform = transform
143
+ data_dir = Path(data_dir)
144
+
145
+ # Find the metadata JSON file
146
+ metadata_files = list(data_dir.glob('target_crop_info_*.json'))
147
+ if not metadata_files:
148
+ raise FileNotFoundError(f"No metadata file found in {data_dir}")
149
+ if len(metadata_files) > 1:
150
+ raise ValueError(f"Multiple metadata files found in {data_dir}")
151
+
152
+ metadata_path = metadata_files[0]
153
+ # Extract aspect ratio from metadata filename (e.g. target_crop_info_26-15.json -> 26-15)
154
+ aspect_ratio = metadata_path.stem.split('_')[-1]
155
+
156
+ # Use aspect ratio subfolder for images
157
+ self.image_dir = data_dir / aspect_ratio
158
+ if not self.image_dir.exists():
159
+ raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
160
+
161
+ # Load metadata
162
+ with open(metadata_path, 'r') as f:
163
+ self.metadata = json.load(f)
164
+
165
+ eval_first_n = eval_first_n if eval_first_n != -1 else len(self.metadata)
166
+ self.metadata = self.metadata[:eval_first_n]
167
+
168
+ # Verify all images exist
169
+ for item in self.metadata:
170
+ image_path = self.image_dir / item['file_name']
171
+ if not image_path.exists():
172
+ raise FileNotFoundError(f"Image not found: {image_path}")
173
+
174
+ self.dummy_prompt = "DUMMY PROMPT"
175
+ self.pre_pad_len = len(self.metadata)
176
+ if pad_to_multiple_of is not None and len(self.metadata) % pad_to_multiple_of != 0:
177
+ # Duplicate the last entry
178
+ self.metadata += [self.metadata[-1]] * (
179
+ pad_to_multiple_of - len(self.metadata) % pad_to_multiple_of
180
+ )
181
+
182
+ def __len__(self):
183
+ return len(self.metadata)
184
+
185
+ def __getitem__(self, idx):
186
+ """
187
+ Returns:
188
+ dict: A dictionary containing:
189
+ - image: PIL Image
190
+ - caption: str
191
+ - target_bbox: list of int [x1, y1, x2, y2]
192
+ - target_ratio: str
193
+ - type: str
194
+ - origin_size: tuple of int (width, height)
195
+ """
196
+ item = self.metadata[idx]
197
+
198
+ # Load image
199
+ image_path = self.image_dir / item['file_name']
200
+ image = Image.open(image_path).convert('RGB')
201
+
202
+ # Apply transform if specified
203
+ if self.transform:
204
+ image = self.transform(image)
205
+
206
+ return {
207
+ 'image': image,
208
+ 'prompts': item['caption'],
209
+ 'target_bbox': item['target_crop']['target_bbox'],
210
+ 'target_ratio': item['target_crop']['target_ratio'],
211
+ 'type': item['type'],
212
+ 'origin_size': (item['origin_width'], item['origin_height']),
213
+ 'idx': idx
214
+ }
215
+
216
+
217
+ def cycle(dl):
218
+ while True:
219
+ for data in dl:
220
+ yield data
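
The `cycle` generator at the end of utils/dataset.py is what lets the trainers call `next(self.dataloader)` indefinitely. A minimal usage sketch with `TextDataset` (the prompt file path is hypothetical):

from torch.utils.data import DataLoader
from utils.dataset import TextDataset, cycle

dataset = TextDataset("prompts.txt")                 # one prompt per line (assumed to exist)
loader = cycle(DataLoader(dataset, batch_size=4, shuffle=True))
batch = next(loader)                                 # {"prompts": [...4 strings...], "idx": tensor([...])}
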
utils/distributed.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import timedelta
2
+ from functools import partial
3
+ import os
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, StateDictType
7
+ from torch.distributed.fsdp.api import CPUOffload
8
+ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
9
+
10
+
11
+ def fsdp_state_dict(model):
12
+ fsdp_fullstate_save_policy = FullStateDictConfig(
13
+ offload_to_cpu=True, rank0_only=True
14
+ )
15
+ with FSDP.state_dict_type(
16
+ model, StateDictType.FULL_STATE_DICT, fsdp_fullstate_save_policy
17
+ ):
18
+ checkpoint = model.state_dict()
19
+
20
+ return checkpoint
21
+
22
+
23
+ def fsdp_wrap(module, sharding_strategy="full", mixed_precision=False, wrap_strategy="size", min_num_params=int(5e7), transformer_module=None, cpu_offload=False):
24
+ if mixed_precision:
25
+ mixed_precision_policy = MixedPrecision(
26
+ param_dtype=torch.bfloat16,
27
+ reduce_dtype=torch.float32,
28
+ buffer_dtype=torch.float32,
29
+ cast_forward_inputs=False
30
+ )
31
+ else:
32
+ mixed_precision_policy = None
33
+
34
+ if wrap_strategy == "transformer":
35
+ auto_wrap_policy = partial(
36
+ transformer_auto_wrap_policy,
37
+ transformer_layer_cls=transformer_module
38
+ )
39
+ elif wrap_strategy == "size":
40
+ auto_wrap_policy = partial(
41
+ size_based_auto_wrap_policy,
42
+ min_num_params=min_num_params
43
+ )
44
+ else:
45
+ raise ValueError(f"Invalid wrap strategy: {wrap_strategy}")
46
+
47
+ os.environ["NCCL_CROSS_NIC"] = "1"
48
+
49
+ sharding_strategy = {
50
+ "full": ShardingStrategy.FULL_SHARD,
51
+ "hybrid_full": ShardingStrategy.HYBRID_SHARD,
52
+ "hybrid_zero2": ShardingStrategy._HYBRID_SHARD_ZERO2,
53
+ "no_shard": ShardingStrategy.NO_SHARD,
54
+ }[sharding_strategy]
55
+
56
+ module = FSDP(
57
+ module,
58
+ auto_wrap_policy=auto_wrap_policy,
59
+ sharding_strategy=sharding_strategy,
60
+ mixed_precision=mixed_precision_policy,
61
+ device_id=torch.cuda.current_device(),
62
+ limit_all_gathers=True,
63
+ use_orig_params=True,
64
+ cpu_offload=CPUOffload(offload_params=cpu_offload),
65
+ sync_module_states=False # Load ckpt on rank 0 and sync to other ranks
66
+ )
67
+ return module
68
+
69
+
70
+ def barrier():
71
+ if dist.is_initialized():
72
+ dist.barrier()
73
+
74
+
75
+ def launch_distributed_job(backend: str = "nccl"):
76
+ rank = int(os.environ["RANK"])
77
+ local_rank = int(os.environ["LOCAL_RANK"])
78
+ world_size = int(os.environ["WORLD_SIZE"])
79
+ host = os.environ["MASTER_ADDR"]
80
+ port = int(os.environ["MASTER_PORT"])
81
+
82
+ if ":" in host: # IPv6
83
+ init_method = f"tcp://[{host}]:{port}"
84
+ else: # IPv4
85
+ init_method = f"tcp://{host}:{port}"
86
+ dist.init_process_group(rank=rank, world_size=world_size, backend=backend,
87
+ init_method=init_method, timeout=timedelta(minutes=30))
88
+ torch.cuda.set_device(local_rank)
89
+
90
+
91
+ class EMA_FSDP:
92
+ def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
93
+ self.decay = decay
94
+ self.shadow = {}
95
+ self._init_shadow(fsdp_module)
96
+
97
+ @torch.no_grad()
98
+ def _init_shadow(self, fsdp_module):
99
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
100
+ with FSDP.summon_full_params(fsdp_module, writeback=False):
101
+ for n, p in fsdp_module.module.named_parameters():
102
+ self.shadow[n] = p.detach().clone().float().cpu()
103
+
104
+ @torch.no_grad()
105
+ def update(self, fsdp_module):
106
+ d = self.decay
107
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
108
+ with FSDP.summon_full_params(fsdp_module, writeback=False):
109
+ for n, p in fsdp_module.module.named_parameters():
110
+ self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
111
+
112
+ # Optional helpers ---------------------------------------------------
113
+ def state_dict(self):
114
+ return self.shadow # picklable
115
+
116
+ def load_state_dict(self, sd):
117
+ self.shadow = {k: v.clone() for k, v in sd.items()}
118
+
119
+ def copy_to(self, fsdp_module):
120
+ # load EMA weights into an (unwrapped) copy of the generator
121
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
122
+ with FSDP.summon_full_params(fsdp_module, writeback=True):
123
+ for n, p in fsdp_module.module.named_parameters():
124
+ if n in self.shadow:
125
+ p.data.copy_(self.shadow[n].to(device=p.device, dtype=p.dtype))
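
`EMA_FSDP` keeps a CPU-side float32 shadow of the generator and folds fresh weights in with a standard exponential moving average. The update rule, shown on a plain tensor for clarity (this mirrors the `mul_/add_` line in `update` above, nothing more):

import torch

decay = 0.999
shadow = torch.zeros(4)          # stands in for one shadow entry
param = torch.ones(4)            # stands in for the live parameter
shadow.mul_(decay).add_(param, alpha=1.0 - decay)   # shadow <- decay*shadow + (1-decay)*param
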
utils/lmdb.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def get_array_shape_from_lmdb(env, array_name):
5
+ with env.begin() as txn:
6
+ image_shape = txn.get(f"{array_name}_shape".encode()).decode()
7
+ image_shape = tuple(map(int, image_shape.split()))
8
+ return image_shape
9
+
10
+
11
+ def store_arrays_to_lmdb(env, arrays_dict, start_index=0):
12
+ """
13
+ Store rows of multiple numpy arrays in a single LMDB.
14
+ Each row is stored separately with a naming convention.
15
+ """
16
+ with env.begin(write=True) as txn:
17
+ for array_name, array in arrays_dict.items():
18
+ for i, row in enumerate(array):
19
+ # Convert row to bytes
20
+ if isinstance(row, str):
21
+ row_bytes = row.encode()
22
+ else:
23
+ row_bytes = row.tobytes()
24
+
25
+ data_key = f'{array_name}_{start_index + i}_data'.encode()
26
+
27
+ txn.put(data_key, row_bytes)
28
+
29
+
30
+ def process_data_dict(data_dict, seen_prompts):
31
+ output_dict = {}
32
+
33
+ all_videos = []
34
+ all_prompts = []
35
+ for prompt, video in data_dict.items():
36
+ if prompt in seen_prompts:
37
+ continue
38
+ else:
39
+ seen_prompts.add(prompt)
40
+
41
+ video = video.half().numpy()
42
+ all_videos.append(video)
43
+ all_prompts.append(prompt)
44
+
45
+ if len(all_videos) == 0:
46
+ return {"latents": np.array([]), "prompts": np.array([])}
47
+
48
+ all_videos = np.concatenate(all_videos, axis=0)
49
+
50
+ output_dict['latents'] = all_videos
51
+ output_dict['prompts'] = np.array(all_prompts)
52
+
53
+ return output_dict
54
+
55
+
56
+ def retrieve_row_from_lmdb(lmdb_env, array_name, dtype, row_index, shape=None):
57
+ """
58
+ Retrieve a specific row from a specific array in the LMDB.
59
+ """
60
+ data_key = f'{array_name}_{row_index}_data'.encode()
61
+
62
+ with lmdb_env.begin() as txn:
63
+ row_bytes = txn.get(data_key)
64
+
65
+ if dtype == str:
66
+ array = row_bytes.decode()
67
+ else:
68
+ array = np.frombuffer(row_bytes, dtype=dtype)
69
+
70
+ if shape is not None and len(shape) > 0:
71
+ array = array.reshape(shape)
72
+ return array
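
The LMDB helpers rely on a simple key layout: each row is stored under `<name>_<i>_data`, and the reader expects an additional `<name>_shape` entry holding the space-separated array shape. Note that `store_arrays_to_lmdb` does not write the shape entry, so the writer has to add it. A round-trip sketch under those assumptions (the database path is made up):

import lmdb
import numpy as np
from utils.lmdb import store_arrays_to_lmdb, get_array_shape_from_lmdb, retrieve_row_from_lmdb

env = lmdb.open("/tmp/toy_lmdb", map_size=1 << 30)
latents = np.random.randn(2, 4, 16, 8, 8).astype(np.float16)
store_arrays_to_lmdb(env, {"latents": latents})                      # writes latents_0_data, latents_1_data
with env.begin(write=True) as txn:                                   # write the shape entry by hand
    txn.put(b"latents_shape", " ".join(map(str, latents.shape)).encode())

shape = get_array_shape_from_lmdb(env, "latents")                    # (2, 4, 16, 8, 8)
row = retrieve_row_from_lmdb(env, "latents", np.float16, 0, shape=shape[1:])
assert row.shape == shape[1:]
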
utils/loss.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ import torch
3
+
4
+
5
+ class DenoisingLoss(ABC):
6
+ @abstractmethod
7
+ def __call__(
8
+ self, x: torch.Tensor, x_pred: torch.Tensor,
9
+ noise: torch.Tensor, noise_pred: torch.Tensor,
10
+ alphas_cumprod: torch.Tensor,
11
+ timestep: torch.Tensor,
12
+ **kwargs
13
+ ) -> torch.Tensor:
14
+ """
15
+ Base class for denoising loss.
16
+ Input:
17
+ - x: the clean data with shape [B, F, C, H, W]
18
+ - x_pred: the predicted clean data with shape [B, F, C, H, W]
19
+ - noise: the noise with shape [B, F, C, H, W]
20
+ - noise_pred: the predicted noise with shape [B, F, C, H, W]
21
+ - alphas_cumprod: the cumulative product of alphas (defining the noise schedule) with shape [T]
22
+ - timestep: the current timestep with shape [B, F]
23
+ """
24
+ pass
25
+
26
+
27
+ class X0PredLoss(DenoisingLoss):
28
+ def __call__(
29
+ self, x: torch.Tensor, x_pred: torch.Tensor,
30
+ noise: torch.Tensor, noise_pred: torch.Tensor,
31
+ alphas_cumprod: torch.Tensor,
32
+ timestep: torch.Tensor,
33
+ **kwargs
34
+ ) -> torch.Tensor:
35
+ return torch.mean((x - x_pred) ** 2)
36
+
37
+
38
+ class VPredLoss(DenoisingLoss):
39
+ def __call__(
40
+ self, x: torch.Tensor, x_pred: torch.Tensor,
41
+ noise: torch.Tensor, noise_pred: torch.Tensor,
42
+ alphas_cumprod: torch.Tensor,
43
+ timestep: torch.Tensor,
44
+ **kwargs
45
+ ) -> torch.Tensor:
46
+ weights = 1 / (1 - alphas_cumprod[timestep].reshape(*timestep.shape, 1, 1, 1))
47
+ return torch.mean(weights * (x - x_pred) ** 2)
48
+
49
+
50
+ class NoisePredLoss(DenoisingLoss):
51
+ def __call__(
52
+ self, x: torch.Tensor, x_pred: torch.Tensor,
53
+ noise: torch.Tensor, noise_pred: torch.Tensor,
54
+ alphas_cumprod: torch.Tensor,
55
+ timestep: torch.Tensor,
56
+ **kwargs
57
+ ) -> torch.Tensor:
58
+ return torch.mean((noise - noise_pred) ** 2)
59
+
60
+
61
+ class FlowPredLoss(DenoisingLoss):
62
+ def __call__(
63
+ self, x: torch.Tensor, x_pred: torch.Tensor,
64
+ noise: torch.Tensor, noise_pred: torch.Tensor,
65
+ alphas_cumprod: torch.Tensor,
66
+ timestep: torch.Tensor,
67
+ **kwargs
68
+ ) -> torch.Tensor:
69
+ return torch.mean((kwargs["flow_pred"] - (noise - x)) ** 2)
70
+
71
+
72
+ NAME_TO_CLASS = {
73
+ "x0": X0PredLoss,
74
+ "v": VPredLoss,
75
+ "noise": NoisePredLoss,
76
+ "flow": FlowPredLoss
77
+ }
78
+
79
+
80
+ def get_denoising_loss(loss_type: str) -> DenoisingLoss:
81
+ return NAME_TO_CLASS[loss_type]
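
`get_denoising_loss` returns the loss class rather than an instance, so callers instantiate it before use. A short sketch (arguments a given loss ignores can be passed as None):

import torch
from utils.loss import get_denoising_loss

loss_fn = get_denoising_loss("x0")()                 # -> X0PredLoss instance
x = torch.randn(1, 3, 4, 8, 8)
x_pred = x + 0.1 * torch.randn_like(x)
loss = loss_fn(x=x, x_pred=x_pred, noise=None, noise_pred=None,
               alphas_cumprod=None, timestep=None)   # mean squared error on x0
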
utils/misc.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+
5
+
6
+ def set_seed(seed: int, deterministic: bool = False):
7
+ """
8
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
9
+
10
+ Args:
11
+ seed (`int`):
12
+ The seed to set.
13
+ deterministic (`bool`, *optional*, defaults to `False`):
14
+ Whether to use deterministic algorithms where available. Can slow down training.
15
+ """
16
+ random.seed(seed)
17
+ np.random.seed(seed)
18
+ torch.manual_seed(seed)
19
+ torch.cuda.manual_seed_all(seed)
20
+
21
+ if deterministic:
22
+ torch.use_deterministic_algorithms(True)
23
+
24
+
25
+ def merge_dict_list(dict_list):
26
+ if len(dict_list) == 1:
27
+ return dict_list[0]
28
+
29
+ merged_dict = {}
30
+ for k, v in dict_list[0].items():
31
+ if isinstance(v, torch.Tensor):
32
+ if v.ndim == 0:
33
+ merged_dict[k] = torch.stack([d[k] for d in dict_list], dim=0)
34
+ else:
35
+ merged_dict[k] = torch.cat([d[k] for d in dict_list], dim=0)
36
+ else:
37
+ # for non-tensor values, we just copy the value from the first item
38
+ merged_dict[k] = v
39
+ return merged_dict
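
`merge_dict_list` treats scalar (0-dim) tensors and batched tensors differently: scalars are stacked along a new leading dimension, everything else is concatenated along dim 0. A tiny check of that behavior:

import torch
from utils.misc import merge_dict_list

logs = [{"loss": torch.tensor(0.5), "latent": torch.randn(2, 4)},
        {"loss": torch.tensor(0.7), "latent": torch.randn(2, 4)}]
merged = merge_dict_list(logs)
assert merged["loss"].shape == (2,)      # stacked scalars
assert merged["latent"].shape == (4, 4)  # concatenated along dim 0
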
utils/scheduler.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod, ABC
2
+ import torch
3
+
4
+
5
+ class SchedulerInterface(ABC):
6
+ """
7
+ Base class for diffusion noise schedule.
8
+ """
9
+ alphas_cumprod: torch.Tensor # [T], alphas for defining the noise schedule
10
+
11
+ @abstractmethod
12
+ def add_noise(
13
+ self, clean_latent: torch.Tensor,
14
+ noise: torch.Tensor, timestep: torch.Tensor
15
+ ):
16
+ """
17
+ Diffusion forward corruption process.
18
+ Input:
19
+ - clean_latent: the clean latent with shape [B, C, H, W]
20
+ - noise: the noise with shape [B, C, H, W]
21
+ - timestep: the timestep with shape [B]
22
+ Output: the corrupted latent with shape [B, C, H, W]
23
+ """
24
+ pass
25
+
26
+ def convert_x0_to_noise(
27
+ self, x0: torch.Tensor, xt: torch.Tensor,
28
+ timestep: torch.Tensor
29
+ ) -> torch.Tensor:
30
+ """
31
+ Convert the diffusion network's x0 prediction to noise prediction.
32
+ x0: the predicted clean data with shape [B, C, H, W]
33
+ xt: the input noisy data with shape [B, C, H, W]
34
+ timestep: the timestep with shape [B]
35
+
36
+ noise = (xt-sqrt(alpha_t)*x0) / sqrt(beta_t) (eq 11 in https://arxiv.org/abs/2311.18828)
37
+ """
38
+ # use higher precision for calculations
39
+ original_dtype = x0.dtype
40
+ x0, xt, alphas_cumprod = map(
41
+ lambda x: x.double().to(x0.device), [x0, xt,
42
+ self.alphas_cumprod]
43
+ )
44
+
45
+ alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
46
+ beta_prod_t = 1 - alpha_prod_t
47
+
48
+ noise_pred = (xt - alpha_prod_t **
49
+ (0.5) * x0) / beta_prod_t ** (0.5)
50
+ return noise_pred.to(original_dtype)
51
+
52
+ def convert_noise_to_x0(
53
+ self, noise: torch.Tensor, xt: torch.Tensor,
54
+ timestep: torch.Tensor
55
+ ) -> torch.Tensor:
56
+ """
57
+ Convert the diffusion network's noise prediction to x0 prediction.
58
+ noise: the predicted noise with shape [B, C, H, W]
59
+ xt: the input noisy data with shape [B, C, H, W]
60
+ timestep: the timestep with shape [B]
61
+
62
+ x0 = (x_t - sqrt(beta_t) * noise) / sqrt(alpha_t) (eq 11 in https://arxiv.org/abs/2311.18828)
63
+ """
64
+ # use higher precision for calculations
65
+ original_dtype = noise.dtype
66
+ noise, xt, alphas_cumprod = map(
67
+ lambda x: x.double().to(noise.device), [noise, xt,
68
+ self.alphas_cumprod]
69
+ )
70
+ alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
71
+ beta_prod_t = 1 - alpha_prod_t
72
+
73
+ x0_pred = (xt - beta_prod_t **
74
+ (0.5) * noise) / alpha_prod_t ** (0.5)
75
+ return x0_pred.to(original_dtype)
76
+
77
+ def convert_velocity_to_x0(
78
+ self, velocity: torch.Tensor, xt: torch.Tensor,
79
+ timestep: torch.Tensor
80
+ ) -> torch.Tensor:
81
+ """
82
+ Convert the diffusion network's velocity prediction to x0 prediction.
83
+ velocity: the predicted velocity with shape [B, C, H, W]
84
+ xt: the input noisy data with shape [B, C, H, W]
85
+ timestep: the timestep with shape [B]
86
+
87
+ v = sqrt(alpha_t) * noise - sqrt(beta_t) * x0
88
+ noise = (xt-sqrt(alpha_t)*x0) / sqrt(beta_t)
89
+ given v, x_t, we have
90
+ x0 = sqrt(alpha_t) * x_t - sqrt(beta_t) * v
91
+ see derivations https://chatgpt.com/share/679fb6c8-3a30-8008-9b0e-d1ae892dac56
92
+ """
93
+ # use higher precision for calculations
94
+ original_dtype = velocity.dtype
95
+ velocity, xt, alphas_cumprod = map(
96
+ lambda x: x.double().to(velocity.device), [velocity, xt,
97
+ self.alphas_cumprod]
98
+ )
99
+ alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
100
+ beta_prod_t = 1 - alpha_prod_t
101
+
102
+ x0_pred = (alpha_prod_t ** 0.5) * xt - (beta_prod_t ** 0.5) * velocity
103
+ return x0_pred.to(original_dtype)
104
+
105
+
106
+ class FlowMatchScheduler():
107
+
108
+ def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003 / 1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
109
+ self.num_train_timesteps = num_train_timesteps
110
+ self.shift = shift
111
+ self.sigma_max = sigma_max
112
+ self.sigma_min = sigma_min
113
+ self.inverse_timesteps = inverse_timesteps
114
+ self.extra_one_step = extra_one_step
115
+ self.reverse_sigmas = reverse_sigmas
116
+ self.set_timesteps(num_inference_steps)
117
+
118
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
119
+ sigma_start = self.sigma_min + \
120
+ (self.sigma_max - self.sigma_min) * denoising_strength
121
+ if self.extra_one_step:
122
+ self.sigmas = torch.linspace(
123
+ sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
124
+ else:
125
+ self.sigmas = torch.linspace(
126
+ sigma_start, self.sigma_min, num_inference_steps)
127
+ if self.inverse_timesteps:
128
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
129
+ self.sigmas = self.shift * self.sigmas / \
130
+ (1 + (self.shift - 1) * self.sigmas)
131
+ if self.reverse_sigmas:
132
+ self.sigmas = 1 - self.sigmas
133
+ self.timesteps = self.sigmas * self.num_train_timesteps
134
+ if training:
135
+ x = self.timesteps
136
+ y = torch.exp(-2 * ((x - num_inference_steps / 2) /
137
+ num_inference_steps) ** 2)
138
+ y_shifted = y - y.min()
139
+ bsmntw_weighing = y_shifted * \
140
+ (num_inference_steps / y_shifted.sum())
141
+ self.linear_timesteps_weights = bsmntw_weighing
142
+
143
+ def step(self, model_output, timestep, sample, to_final=False):
144
+ if timestep.ndim == 2:
145
+ timestep = timestep.flatten(0, 1)
146
+ self.sigmas = self.sigmas.to(model_output.device)
147
+ self.timesteps = self.timesteps.to(model_output.device)
148
+ timestep_id = torch.argmin(
149
+ (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
150
+ sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
151
+ if to_final or (timestep_id + 1 >= len(self.timesteps)).any():
152
+ sigma_ = 1 if (
153
+ self.inverse_timesteps or self.reverse_sigmas) else 0
154
+ else:
155
+ sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1)
156
+ prev_sample = sample + model_output * (sigma_ - sigma)
157
+ return prev_sample
158
+
159
+ def add_noise(self, original_samples, noise, timestep):
160
+ """
161
+ Diffusion forward corruption process.
162
+ Input:
163
+ - clean_latent: the clean latent with shape [B*T, C, H, W]
164
+ - noise: the noise with shape [B*T, C, H, W]
165
+ - timestep: the timestep with shape [B*T]
166
+ Output: the corrupted latent with shape [B*T, C, H, W]
167
+ """
168
+ if timestep.ndim == 2:
169
+ timestep = timestep.flatten(0, 1)
170
+ self.sigmas = self.sigmas.to(noise.device)
171
+ self.timesteps = self.timesteps.to(noise.device)
172
+ timestep_id = torch.argmin(
173
+ (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
174
+ sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
175
+ sample = (1 - sigma) * original_samples + sigma * noise
176
+ return sample.type_as(noise)
177
+
178
+ def training_target(self, sample, noise, timestep):
179
+ target = noise - sample
180
+ return target
181
+
182
+ def training_weight(self, timestep):
183
+ """
184
+ Input:
185
+ - timestep: the timestep with shape [B*T]
186
+ Output: the corresponding weighting [B*T]
187
+ """
188
+ if timestep.ndim == 2:
189
+ timestep = timestep.flatten(0, 1)
190
+ self.linear_timesteps_weights = self.linear_timesteps_weights.to(timestep.device)
191
+ timestep_id = torch.argmin(
192
+ (self.timesteps.unsqueeze(1) - timestep.unsqueeze(0)).abs(), dim=0)
193
+ weights = self.linear_timesteps_weights[timestep_id]
194
+ return weights
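
`FlowMatchScheduler` implements the flow-matching corruption x_t = (1 - sigma_t) * x0 + sigma_t * noise with training target noise - x0. A minimal sketch of the forward pass (shapes follow the [B*T, C, H, W] convention in the docstrings; the constructor arguments mirror those used by WanDiffusionWrapper in the next file):

import torch
from utils.scheduler import FlowMatchScheduler

scheduler = FlowMatchScheduler(shift=8.0, sigma_min=0.0, extra_one_step=True)
scheduler.set_timesteps(1000, training=True)

x0 = torch.randn(2, 16, 8, 8)                             # clean latents, [B*T, C, H, W]
noise = torch.randn_like(x0)
timestep = scheduler.timesteps[torch.randint(0, 1000, (2,))]

xt = scheduler.add_noise(x0, noise, timestep)             # (1 - sigma_t) * x0 + sigma_t * noise
target = scheduler.training_target(x0, noise, timestep)   # noise - x0
weight = scheduler.training_weight(timestep)              # per-sample loss weighting
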
utils/wan_wrapper.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+ from typing import List, Optional
3
+ import torch
4
+ from torch import nn
5
+
6
+ from utils.scheduler import SchedulerInterface, FlowMatchScheduler
7
+ from wan.modules.tokenizers import HuggingfaceTokenizer
8
+ from wan.modules.model import WanModel, RegisterTokens, GanAttentionBlock
9
+ from wan.modules.vae import _video_vae
10
+ from wan.modules.t5 import umt5_xxl
11
+ from wan.modules.causal_model import CausalWanModel
12
+
13
+
14
+ class WanTextEncoder(torch.nn.Module):
15
+ def __init__(self) -> None:
16
+ super().__init__()
17
+
18
+ self.text_encoder = umt5_xxl(
19
+ encoder_only=True,
20
+ return_tokenizer=False,
21
+ dtype=torch.float32,
22
+ device=torch.device('cpu')
23
+ ).eval().requires_grad_(False)
24
+ self.text_encoder.load_state_dict(
25
+ torch.load("wan_models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
26
+ map_location='cpu', weights_only=False)
27
+ )
28
+
29
+ self.tokenizer = HuggingfaceTokenizer(
30
+ name="wan_models/Wan2.1-T2V-1.3B/google/umt5-xxl/", seq_len=512, clean='whitespace')
31
+
32
+ @property
33
+ def device(self):
34
+ # Assume we are always on GPU
35
+ return torch.cuda.current_device()
36
+
37
+ def forward(self, text_prompts: List[str]) -> dict:
38
+ ids, mask = self.tokenizer(
39
+ text_prompts, return_mask=True, add_special_tokens=True)
40
+ ids = ids.to(self.device)
41
+ mask = mask.to(self.device)
42
+ seq_lens = mask.gt(0).sum(dim=1).long()
43
+ context = self.text_encoder(ids, mask)
44
+
45
+ for u, v in zip(context, seq_lens):
46
+ u[v:] = 0.0 # set padding to 0.0
47
+
48
+ return {
49
+ "prompt_embeds": context
50
+ }
51
+
52
+
53
+ class WanVAEWrapper(torch.nn.Module):
54
+ def __init__(self):
55
+ super().__init__()
56
+ mean = [
57
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
58
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
59
+ ]
60
+ std = [
61
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
62
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
63
+ ]
64
+ self.mean = torch.tensor(mean, dtype=torch.float32)
65
+ self.std = torch.tensor(std, dtype=torch.float32)
66
+
67
+ # init model
68
+ self.model = _video_vae(
69
+ pretrained_path="wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
70
+ z_dim=16,
71
+ ).eval().requires_grad_(False)
72
+
73
+ def encode_to_latent(self, pixel: torch.Tensor) -> torch.Tensor:
74
+ # pixel: [batch_size, num_channels, num_frames, height, width]
75
+ device, dtype = pixel.device, pixel.dtype
76
+ scale = [self.mean.to(device=device, dtype=dtype),
77
+ 1.0 / self.std.to(device=device, dtype=dtype)]
78
+
79
+ output = [
80
+ self.model.encode(u.unsqueeze(0), scale).float().squeeze(0)
81
+ for u in pixel
82
+ ]
83
+ output = torch.stack(output, dim=0)
84
+ # from [batch_size, num_channels, num_frames, height, width]
85
+ # to [batch_size, num_frames, num_channels, height, width]
86
+ output = output.permute(0, 2, 1, 3, 4)
87
+ return output
88
+
89
+ def decode_to_pixel(self, latent: torch.Tensor, use_cache: bool = False) -> torch.Tensor:
90
+ # from [batch_size, num_frames, num_channels, height, width]
91
+ # to [batch_size, num_channels, num_frames, height, width]
92
+ zs = latent.permute(0, 2, 1, 3, 4)
93
+ if use_cache:
94
+ assert latent.shape[0] == 1, "Batch size must be 1 when using cache"
95
+
96
+ device, dtype = latent.device, latent.dtype
97
+ scale = [self.mean.to(device=device, dtype=dtype),
98
+ 1.0 / self.std.to(device=device, dtype=dtype)]
99
+
100
+ if use_cache:
101
+ decode_function = self.model.cached_decode
102
+ else:
103
+ decode_function = self.model.decode
104
+
105
+ output = []
106
+ for u in zs:
107
+ output.append(decode_function(u.unsqueeze(0), scale).float().clamp_(-1, 1).squeeze(0))
108
+ output = torch.stack(output, dim=0)
109
+ # from [batch_size, num_channels, num_frames, height, width]
110
+ # to [batch_size, num_frames, num_channels, height, width]
111
+ output = output.permute(0, 2, 1, 3, 4)
112
+ return output
113
+
114
+
115
+ class WanDiffusionWrapper(torch.nn.Module):
116
+ def __init__(
117
+ self,
118
+ model_name="Wan2.1-T2V-1.3B",
119
+ timestep_shift=8.0,
120
+ is_causal=False,
121
+ local_attn_size=-1,
122
+ sink_size=0
123
+ ):
124
+ super().__init__()
125
+
126
+ if is_causal:
127
+ self.model = CausalWanModel.from_pretrained(
128
+ f"wan_models/{model_name}/", local_attn_size=local_attn_size, sink_size=sink_size)
129
+ else:
130
+ self.model = WanModel.from_pretrained(f"wan_models/{model_name}/")
131
+ self.model.eval()
132
+
133
+ # For non-causal diffusion, all frames share the same timestep
134
+ self.uniform_timestep = not is_causal
135
+
136
+ self.scheduler = FlowMatchScheduler(
137
+ shift=timestep_shift, sigma_min=0.0, extra_one_step=True
138
+ )
139
+ self.scheduler.set_timesteps(1000, training=True)
140
+
141
+ self.seq_len = 32760 # [1, 21, 16, 60, 104]
142
+ self.post_init()
143
+
144
+ def enable_gradient_checkpointing(self) -> None:
145
+ self.model.enable_gradient_checkpointing()
146
+
147
+ def adding_cls_branch(self, atten_dim=1536, num_class=4, time_embed_dim=0) -> None:
148
+ # NOTE: this is hard-coded for the Wan2.1-T2V-1.3B architecture for now.
149
+ self._cls_pred_branch = nn.Sequential(
150
+ # Input: [B, 384, 21, 60, 104]
151
+ nn.LayerNorm(atten_dim * 3 + time_embed_dim),
152
+ nn.Linear(atten_dim * 3 + time_embed_dim, 1536),
153
+ nn.SiLU(),
154
+ nn.Linear(atten_dim, num_class)
155
+ )
156
+ self._cls_pred_branch.requires_grad_(True)
157
+ num_registers = 3
158
+ self._register_tokens = RegisterTokens(num_registers=num_registers, dim=atten_dim)
159
+ self._register_tokens.requires_grad_(True)
160
+
161
+ gan_ca_blocks = []
162
+ for _ in range(num_registers):
163
+ block = GanAttentionBlock()
164
+ gan_ca_blocks.append(block)
165
+ self._gan_ca_blocks = nn.ModuleList(gan_ca_blocks)
166
+ self._gan_ca_blocks.requires_grad_(True)
167
+ # self.has_cls_branch = True
168
+
169
+ def _convert_flow_pred_to_x0(self, flow_pred: torch.Tensor, xt: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
170
+ """
171
+ Convert flow matching's prediction to x0 prediction.
172
+ flow_pred: the prediction with shape [B, C, H, W]
173
+ xt: the input noisy data with shape [B, C, H, W]
174
+ timestep: the timestep with shape [B]
175
+
176
+ pred = noise - x0
177
+ x_t = (1-sigma_t) * x0 + sigma_t * noise
178
+ we have x0 = x_t - sigma_t * pred
179
+ see derivations https://chatgpt.com/share/67bf8589-3d04-8008-bc6e-4cf1a24e2d0e
180
+ """
181
+ # use higher precision for calculations
182
+ original_dtype = flow_pred.dtype
183
+ flow_pred, xt, sigmas, timesteps = map(
184
+ lambda x: x.double().to(flow_pred.device), [flow_pred, xt,
185
+ self.scheduler.sigmas,
186
+ self.scheduler.timesteps]
187
+ )
188
+
189
+ timestep_id = torch.argmin(
190
+ (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
191
+ sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)
192
+ x0_pred = xt - sigma_t * flow_pred
193
+ return x0_pred.to(original_dtype)
194
+
195
+ @staticmethod
196
+ def _convert_x0_to_flow_pred(scheduler, x0_pred: torch.Tensor, xt: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
197
+ """
198
+ Convert x0 prediction to flow matching's prediction.
199
+ x0_pred: the x0 prediction with shape [B, C, H, W]
200
+ xt: the input noisy data with shape [B, C, H, W]
201
+ timestep: the timestep with shape [B]
202
+
203
+ pred = (x_t - x_0) / sigma_t
204
+ """
205
+ # use higher precision for calculations
206
+ original_dtype = x0_pred.dtype
207
+ x0_pred, xt, sigmas, timesteps = map(
208
+ lambda x: x.double().to(x0_pred.device), [x0_pred, xt,
209
+ scheduler.sigmas,
210
+ scheduler.timesteps]
211
+ )
212
+ timestep_id = torch.argmin(
213
+ (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
214
+ sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)
215
+ flow_pred = (xt - x0_pred) / sigma_t
216
+ return flow_pred.to(original_dtype)
217
+
218
+ def forward(
219
+ self,
220
+ noisy_image_or_video: torch.Tensor, conditional_dict: dict,
221
+ timestep: torch.Tensor, kv_cache: Optional[List[dict]] = None,
222
+ crossattn_cache: Optional[List[dict]] = None,
223
+ current_start: Optional[int] = None,
224
+ classify_mode: Optional[bool] = False,
225
+ concat_time_embeddings: Optional[bool] = False,
226
+ clean_x: Optional[torch.Tensor] = None,
227
+ aug_t: Optional[torch.Tensor] = None,
228
+ cache_start: Optional[int] = None,
229
+ updating_cache: Optional[bool] = False
230
+ ) -> torch.Tensor:
231
+ prompt_embeds = conditional_dict["prompt_embeds"]
232
+
233
+ # [B, F] -> [B]
234
+ if self.uniform_timestep:
235
+ input_timestep = timestep[:, 0]
236
+ else:
237
+ input_timestep = timestep
238
+
239
+ logits = None
240
+ # X0 prediction
241
+ if kv_cache is not None:
242
+ flow_pred = self.model(
243
+ noisy_image_or_video.permute(0, 2, 1, 3, 4),
244
+ t=input_timestep, context=prompt_embeds,
245
+ seq_len=self.seq_len,
246
+ kv_cache=kv_cache,
247
+ crossattn_cache=crossattn_cache,
248
+ current_start=current_start,
249
+ cache_start=cache_start,
250
+ updating_cache=updating_cache
251
+ ).permute(0, 2, 1, 3, 4)
252
+ else:
253
+ if clean_x is not None:
254
+ # teacher forcing
255
+ flow_pred = self.model(
256
+ noisy_image_or_video.permute(0, 2, 1, 3, 4),
257
+ t=input_timestep, context=prompt_embeds,
258
+ seq_len=self.seq_len,
259
+ clean_x=clean_x.permute(0, 2, 1, 3, 4),
260
+ aug_t=aug_t,
261
+ ).permute(0, 2, 1, 3, 4)
262
+ else:
263
+ if classify_mode:
264
+ flow_pred, logits = self.model(
265
+ noisy_image_or_video.permute(0, 2, 1, 3, 4),
266
+ t=input_timestep, context=prompt_embeds,
267
+ seq_len=self.seq_len,
268
+ classify_mode=True,
269
+ register_tokens=self._register_tokens,
270
+ cls_pred_branch=self._cls_pred_branch,
271
+ gan_ca_blocks=self._gan_ca_blocks,
272
+ concat_time_embeddings=concat_time_embeddings
273
+ )
274
+ flow_pred = flow_pred.permute(0, 2, 1, 3, 4)
275
+ else:
276
+ flow_pred = self.model(
277
+ noisy_image_or_video.permute(0, 2, 1, 3, 4),
278
+ t=input_timestep, context=prompt_embeds,
279
+ seq_len=self.seq_len
280
+ ).permute(0, 2, 1, 3, 4)
281
+
282
+ pred_x0 = self._convert_flow_pred_to_x0(
283
+ flow_pred=flow_pred.flatten(0, 1),
284
+ xt=noisy_image_or_video.flatten(0, 1),
285
+ timestep=timestep.flatten(0, 1)
286
+ ).unflatten(0, flow_pred.shape[:2])
287
+
288
+ if logits is not None:
289
+ return flow_pred, pred_x0, logits
290
+
291
+ return flow_pred, pred_x0
292
+
293
+ def get_scheduler(self) -> SchedulerInterface:
294
+ """
295
+ Update the current scheduler with the interface's static method
296
+ """
297
+ scheduler = self.scheduler
298
+ scheduler.convert_x0_to_noise = types.MethodType(
299
+ SchedulerInterface.convert_x0_to_noise, scheduler)
300
+ scheduler.convert_noise_to_x0 = types.MethodType(
301
+ SchedulerInterface.convert_noise_to_x0, scheduler)
302
+ scheduler.convert_velocity_to_x0 = types.MethodType(
303
+ SchedulerInterface.convert_velocity_to_x0, scheduler)
304
+ self.scheduler = scheduler
305
+ return scheduler
306
+
307
+ def post_init(self):
308
+ """
309
+ A few custom initialization steps that should be called after the object is created.
310
+ Currently, the only one we have is to bind a few methods to scheduler.
311
+ We can gradually add more methods here if needed.
312
+ """
313
+ self.get_scheduler()
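
The two static conversions in `WanDiffusionWrapper` are exact inverses for a given sigma_t, and under the flow-matching parameterization the true flow equals noise - x0. A quick numeric check of those identities with plain tensors (no model involved):

import torch

sigma_t = torch.tensor(0.37)                 # an arbitrary noise level
x0 = torch.randn(1, 16, 8, 8)
noise = torch.randn_like(x0)
xt = (1 - sigma_t) * x0 + sigma_t * noise    # forward corruption

flow = (xt - x0) / sigma_t                   # as in _convert_x0_to_flow_pred
assert torch.allclose(flow, noise - x0, atol=1e-5)
assert torch.allclose(xt - sigma_t * flow, x0, atol=1e-5)   # as in _convert_flow_pred_to_x0
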
wan/README.md ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Code in this folder is modified from https://github.com/Wan-Video/Wan2.1
2
+ Apache-2.0 License
wan/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import configs, distributed, modules
2
+ from .image2video import WanI2V
3
+ from .text2video import WanT2V
wan/configs/__init__.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .wan_t2v_14B import t2v_14B
3
+ from .wan_t2v_1_3B import t2v_1_3B
4
+ from .wan_i2v_14B import i2v_14B
5
+ import copy
6
+ import os
7
+
8
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
9
+
10
+
11
+ # the config of t2i_14B is the same as t2v_14B
12
+ t2i_14B = copy.deepcopy(t2v_14B)
13
+ t2i_14B.__name__ = 'Config: Wan T2I 14B'
14
+
15
+ WAN_CONFIGS = {
16
+ 't2v-14B': t2v_14B,
17
+ 't2v-1.3B': t2v_1_3B,
18
+ 'i2v-14B': i2v_14B,
19
+ 't2i-14B': t2i_14B,
20
+ }
21
+
22
+ SIZE_CONFIGS = {
23
+ '720*1280': (720, 1280),
24
+ '1280*720': (1280, 720),
25
+ '480*832': (480, 832),
26
+ '832*480': (832, 480),
27
+ '1024*1024': (1024, 1024),
28
+ }
29
+
30
+ MAX_AREA_CONFIGS = {
31
+ '720*1280': 720 * 1280,
32
+ '1280*720': 1280 * 720,
33
+ '480*832': 480 * 832,
34
+ '832*480': 832 * 480,
35
+ }
36
+
37
+ SUPPORTED_SIZES = {
38
+ 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
39
+ 't2v-1.3B': ('480*832', '832*480'),
40
+ 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
41
+ 't2i-14B': tuple(SIZE_CONFIGS.keys()),
42
+ }
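
The config registry maps task names to EasyDict configs and resolution strings to pixel sizes. A short selection sketch (importing through `wan.configs` assumes the package's heavier dependencies are installed, since `wan/__init__.py` pulls in the full pipeline):

from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES

task, size = "t2v-1.3B", "832*480"
assert size in SUPPORTED_SIZES[task], f"{size} is not supported for {task}"
cfg = WAN_CONFIGS[task]                 # EasyDict with dim=1536, num_layers=30, ...
dims = SIZE_CONFIGS[size]               # (832, 480), in the same order as the key
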
wan/configs/shared_config.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ # ------------------------ Wan shared config ------------------------#
6
+ wan_shared_cfg = EasyDict()
7
+
8
+ # t5
9
+ wan_shared_cfg.t5_model = 'umt5_xxl'
10
+ wan_shared_cfg.t5_dtype = torch.bfloat16
11
+ wan_shared_cfg.text_len = 512
12
+
13
+ # transformer
14
+ wan_shared_cfg.param_dtype = torch.bfloat16
15
+
16
+ # inference
17
+ wan_shared_cfg.num_train_timesteps = 1000
18
+ wan_shared_cfg.sample_fps = 16
19
+ wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
wan/configs/wan_i2v_14B.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ from .shared_config import wan_shared_cfg
6
+
7
+ # ------------------------ Wan I2V 14B ------------------------#
8
+
9
+ i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
10
+ i2v_14B.update(wan_shared_cfg)
11
+
12
+ i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ i2v_14B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # clip
16
+ i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
17
+ i2v_14B.clip_dtype = torch.float16
18
+ i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
19
+ i2v_14B.clip_tokenizer = 'xlm-roberta-large'
20
+
21
+ # vae
22
+ i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
23
+ i2v_14B.vae_stride = (4, 8, 8)
24
+
25
+ # transformer
26
+ i2v_14B.patch_size = (1, 2, 2)
27
+ i2v_14B.dim = 5120
28
+ i2v_14B.ffn_dim = 13824
29
+ i2v_14B.freq_dim = 256
30
+ i2v_14B.num_heads = 40
31
+ i2v_14B.num_layers = 40
32
+ i2v_14B.window_size = (-1, -1)
33
+ i2v_14B.qk_norm = True
34
+ i2v_14B.cross_attn_norm = True
35
+ i2v_14B.eps = 1e-6
wan/configs/wan_t2v_14B.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ # ------------------------ Wan T2V 14B ------------------------#
7
+
8
+ t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
9
+ t2v_14B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ t2v_14B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ t2v_14B.vae_stride = (4, 8, 8)
18
+
19
+ # transformer
20
+ t2v_14B.patch_size = (1, 2, 2)
21
+ t2v_14B.dim = 5120
22
+ t2v_14B.ffn_dim = 13824
23
+ t2v_14B.freq_dim = 256
24
+ t2v_14B.num_heads = 40
25
+ t2v_14B.num_layers = 40
26
+ t2v_14B.window_size = (-1, -1)
27
+ t2v_14B.qk_norm = True
28
+ t2v_14B.cross_attn_norm = True
29
+ t2v_14B.eps = 1e-6
wan/configs/wan_t2v_1_3B.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ # ------------------------ Wan T2V 1.3B ------------------------#
7
+
8
+ t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
9
+ t2v_1_3B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ t2v_1_3B.vae_stride = (4, 8, 8)
18
+
19
+ # transformer
20
+ t2v_1_3B.patch_size = (1, 2, 2)
21
+ t2v_1_3B.dim = 1536
22
+ t2v_1_3B.ffn_dim = 8960
23
+ t2v_1_3B.freq_dim = 256
24
+ t2v_1_3B.num_heads = 12
25
+ t2v_1_3B.num_layers = 30
26
+ t2v_1_3B.window_size = (-1, -1)
27
+ t2v_1_3B.qk_norm = True
28
+ t2v_1_3B.cross_attn_norm = True
29
+ t2v_1_3B.eps = 1e-6
wan/distributed/__init__.py ADDED
File without changes
wan/distributed/fsdp.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from functools import partial
3
+
4
+ import torch
5
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
6
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
7
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
8
+
9
+
10
+ def shard_model(
11
+ model,
12
+ device_id,
13
+ param_dtype=torch.bfloat16,
14
+ reduce_dtype=torch.float32,
15
+ buffer_dtype=torch.float32,
16
+ process_group=None,
17
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
18
+ sync_module_states=True,
19
+ ):
20
+ model = FSDP(
21
+ module=model,
22
+ process_group=process_group,
23
+ sharding_strategy=sharding_strategy,
24
+ auto_wrap_policy=partial(
25
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
26
+ mixed_precision=MixedPrecision(
27
+ param_dtype=param_dtype,
28
+ reduce_dtype=reduce_dtype,
29
+ buffer_dtype=buffer_dtype),
30
+ device_id=device_id,
31
+ use_orig_params=True,
32
+ sync_module_states=sync_module_states)
33
+ return model
wan/distributed/xdit_context_parallel.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.cuda.amp as amp
4
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
5
+ get_sequence_parallel_world_size,
6
+ get_sp_group)
7
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
8
+
9
+ from ..modules.model import sinusoidal_embedding_1d
10
+
11
+
12
+ def pad_freqs(original_tensor, target_len):
13
+ seq_len, s1, s2 = original_tensor.shape
14
+ pad_size = target_len - seq_len
15
+ padding_tensor = torch.ones(
16
+ pad_size,
17
+ s1,
18
+ s2,
19
+ dtype=original_tensor.dtype,
20
+ device=original_tensor.device)
21
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
22
+ return padded_tensor
23
+
24
+
25
+ @amp.autocast(enabled=False)
26
+ def rope_apply(x, grid_sizes, freqs):
27
+ """
28
+ x: [B, L, N, C].
29
+ grid_sizes: [B, 3].
30
+ freqs: [M, C // 2].
31
+ """
32
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
33
+ # split freqs
34
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
35
+
36
+ # loop over samples
37
+ output = []
38
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
39
+ seq_len = f * h * w
40
+
41
+ # precompute multipliers
42
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
43
+ s, n, -1, 2))
44
+ freqs_i = torch.cat([
45
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
46
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
47
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
48
+ ],
49
+ dim=-1).reshape(seq_len, 1, -1)
50
+
51
+ # apply rotary embedding
52
+ sp_size = get_sequence_parallel_world_size()
53
+ sp_rank = get_sequence_parallel_rank()
54
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
55
+ s_per_rank = s
56
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
57
+ s_per_rank), :, :]
58
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
59
+ x_i = torch.cat([x_i, x[i, s:]])
60
+
61
+ # append to collection
62
+ output.append(x_i)
63
+ return torch.stack(output).float()
64
+
65
+
66
+ def usp_dit_forward(
67
+ self,
68
+ x,
69
+ t,
70
+ context,
71
+ seq_len,
72
+ clip_fea=None,
73
+ y=None,
74
+ ):
75
+ """
76
+ x: A list of videos each with shape [C, T, H, W].
77
+ t: [B].
78
+ context: A list of text embeddings each with shape [L, C].
79
+ """
80
+ if self.model_type == 'i2v':
81
+ assert clip_fea is not None and y is not None
82
+ # params
83
+ device = self.patch_embedding.weight.device
84
+ if self.freqs.device != device:
85
+ self.freqs = self.freqs.to(device)
86
+
87
+ if y is not None:
88
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
89
+
90
+ # embeddings
91
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
92
+ grid_sizes = torch.stack(
93
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
94
+ x = [u.flatten(2).transpose(1, 2) for u in x]
95
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
96
+ assert seq_lens.max() <= seq_len
97
+ x = torch.cat([
98
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
99
+ for u in x
100
+ ])
101
+
102
+ # time embeddings
103
+ with amp.autocast(dtype=torch.float32):
104
+ e = self.time_embedding(
105
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
106
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
107
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
108
+
109
+ # context
110
+ context_lens = None
111
+ context = self.text_embedding(
112
+ torch.stack([
113
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
114
+ for u in context
115
+ ]))
116
+
117
+ if clip_fea is not None:
118
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
119
+ context = torch.concat([context_clip, context], dim=1)
120
+
121
+ # arguments
122
+ kwargs = dict(
123
+ e=e0,
124
+ seq_lens=seq_lens,
125
+ grid_sizes=grid_sizes,
126
+ freqs=self.freqs,
127
+ context=context,
128
+ context_lens=context_lens)
129
+
130
+ # Context Parallel
131
+ x = torch.chunk(
132
+ x, get_sequence_parallel_world_size(),
133
+ dim=1)[get_sequence_parallel_rank()]
134
+
135
+ for block in self.blocks:
136
+ x = block(x, **kwargs)
137
+
138
+ # head
139
+ x = self.head(x, e)
140
+
141
+ # Context Parallel
142
+ x = get_sp_group().all_gather(x, dim=1)
143
+
144
+ # unpatchify
145
+ x = self.unpatchify(x, grid_sizes)
146
+ return [u.float() for u in x]
147
+
148
+
149
+ def usp_attn_forward(self,
150
+ x,
151
+ seq_lens,
152
+ grid_sizes,
153
+ freqs,
154
+ dtype=torch.bfloat16):
155
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
156
+ half_dtypes = (torch.float16, torch.bfloat16)
157
+
158
+ def half(x):
159
+ return x if x.dtype in half_dtypes else x.to(dtype)
160
+
161
+ # query, key, value function
162
+ def qkv_fn(x):
163
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
164
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
165
+ v = self.v(x).view(b, s, n, d)
166
+ return q, k, v
167
+
168
+ q, k, v = qkv_fn(x)
169
+ q = rope_apply(q, grid_sizes, freqs)
170
+ k = rope_apply(k, grid_sizes, freqs)
171
+
172
+ # TODO: We should use unpadded q, k, v for attention.
173
+ # k_lens = seq_lens // get_sequence_parallel_world_size()
174
+ # if k_lens is not None:
175
+ # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
176
+ # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
177
+ # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
178
+
179
+ x = xFuserLongContextAttention()(
180
+ None,
181
+ query=half(q),
182
+ key=half(k),
183
+ value=half(v),
184
+ window_size=self.window_size)
185
+
186
+ # TODO: padding after attention.
187
+ # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
188
+
189
+ # output
190
+ x = x.flatten(2)
191
+ x = self.o(x)
192
+ return x
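A small worked example of the frequency split used by `rope_apply` above; a head dimension of 128 is assumed, matching the 1.3B config (dim=1536, num_heads=12):

c = 128 // 2                                      # half the per-head dimension
split_sizes = [c - 2 * (c // 3), c // 3, c // 3]  # temporal, height, width groups
assert split_sizes == [22, 21, 21] and sum(split_sizes) == c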
wan/image2video.py ADDED
@@ -0,0 +1,347 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.cuda.amp as amp
15
+ import torch.distributed as dist
16
+ import torchvision.transforms.functional as TF
17
+ from tqdm import tqdm
18
+
19
+ from .distributed.fsdp import shard_model
20
+ from .modules.clip import CLIPModel
21
+ from .modules.model import WanModel
22
+ from .modules.t5 import T5EncoderModel
23
+ from .modules.vae import WanVAE
24
+ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
25
+ get_sampling_sigmas, retrieve_timesteps)
26
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
27
+
28
+
29
+ class WanI2V:
30
+
31
+ def __init__(
32
+ self,
33
+ config,
34
+ checkpoint_dir,
35
+ device_id=0,
36
+ rank=0,
37
+ t5_fsdp=False,
38
+ dit_fsdp=False,
39
+ use_usp=False,
40
+ t5_cpu=False,
41
+ init_on_cpu=True,
42
+ ):
43
+ r"""
44
+ Initializes the image-to-video generation model components.
45
+
46
+ Args:
47
+ config (EasyDict):
48
+ Object containing model parameters initialized from config.py
49
+ checkpoint_dir (`str`):
50
+ Path to directory containing model checkpoints
51
+ device_id (`int`, *optional*, defaults to 0):
52
+ Id of target GPU device
53
+ rank (`int`, *optional*, defaults to 0):
54
+ Process rank for distributed training
55
+ t5_fsdp (`bool`, *optional*, defaults to False):
56
+ Enable FSDP sharding for T5 model
57
+ dit_fsdp (`bool`, *optional*, defaults to False):
58
+ Enable FSDP sharding for DiT model
59
+ use_usp (`bool`, *optional*, defaults to False):
60
+ Enable distribution strategy of USP.
61
+ t5_cpu (`bool`, *optional*, defaults to False):
62
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
63
+ init_on_cpu (`bool`, *optional*, defaults to True):
64
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
65
+ """
66
+ self.device = torch.device(f"cuda:{device_id}")
67
+ self.config = config
68
+ self.rank = rank
69
+ self.use_usp = use_usp
70
+ self.t5_cpu = t5_cpu
71
+
72
+ self.num_train_timesteps = config.num_train_timesteps
73
+ self.param_dtype = config.param_dtype
74
+
75
+ shard_fn = partial(shard_model, device_id=device_id)
76
+ self.text_encoder = T5EncoderModel(
77
+ text_len=config.text_len,
78
+ dtype=config.t5_dtype,
79
+ device=torch.device('cpu'),
80
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
81
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
82
+ shard_fn=shard_fn if t5_fsdp else None,
83
+ )
84
+
85
+ self.vae_stride = config.vae_stride
86
+ self.patch_size = config.patch_size
87
+ self.vae = WanVAE(
88
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
89
+ device=self.device)
90
+
91
+ self.clip = CLIPModel(
92
+ dtype=config.clip_dtype,
93
+ device=self.device,
94
+ checkpoint_path=os.path.join(checkpoint_dir,
95
+ config.clip_checkpoint),
96
+ tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
97
+
98
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
99
+ self.model = WanModel.from_pretrained(checkpoint_dir)
100
+ self.model.eval().requires_grad_(False)
101
+
102
+ if t5_fsdp or dit_fsdp or use_usp:
103
+ init_on_cpu = False
104
+
105
+ if use_usp:
106
+ from xfuser.core.distributed import \
107
+ get_sequence_parallel_world_size
108
+
109
+ from .distributed.xdit_context_parallel import (usp_attn_forward,
110
+ usp_dit_forward)
111
+ for block in self.model.blocks:
112
+ block.self_attn.forward = types.MethodType(
113
+ usp_attn_forward, block.self_attn)
114
+ self.model.forward = types.MethodType(usp_dit_forward, self.model)
115
+ self.sp_size = get_sequence_parallel_world_size()
116
+ else:
117
+ self.sp_size = 1
118
+
119
+ if dist.is_initialized():
120
+ dist.barrier()
121
+ if dit_fsdp:
122
+ self.model = shard_fn(self.model)
123
+ else:
124
+ if not init_on_cpu:
125
+ self.model.to(self.device)
126
+
127
+ self.sample_neg_prompt = config.sample_neg_prompt
128
+
129
+ def generate(self,
130
+ input_prompt,
131
+ img,
132
+ max_area=720 * 1280,
133
+ frame_num=81,
134
+ shift=5.0,
135
+ sample_solver='unipc',
136
+ sampling_steps=40,
137
+ guide_scale=5.0,
138
+ n_prompt="",
139
+ seed=-1,
140
+ offload_model=True):
141
+ r"""
142
+ Generates video frames from input image and text prompt using diffusion process.
143
+
144
+ Args:
145
+ input_prompt (`str`):
146
+ Text prompt for content generation.
147
+ img (PIL.Image.Image):
148
+ Input image tensor. Shape: [3, H, W]
149
+ max_area (`int`, *optional*, defaults to 720*1280):
150
+ Maximum pixel area for latent space calculation. Controls video resolution scaling
151
+ frame_num (`int`, *optional*, defaults to 81):
152
+ How many frames to sample from a video. The number should be 4n+1
153
+ shift (`float`, *optional*, defaults to 5.0):
154
+ Noise schedule shift parameter. Affects temporal dynamics
155
+ [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
156
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
157
+ guide_scale (`float`, *optional*, defaults to 5.0):
158
+ sampling_steps (`int`, *optional*, defaults to 40):
159
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
160
+ guide_scale (`float`, *optional*, defaults 5.0):
161
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity
162
+ n_prompt (`str`, *optional*, defaults to ""):
163
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
164
+ seed (`int`, *optional*, defaults to -1):
165
+ Random seed for noise generation. If -1, use random seed
166
+ offload_model (`bool`, *optional*, defaults to True):
167
+ If True, offloads models to CPU during generation to save VRAM
168
+
169
+ Returns:
170
+ torch.Tensor:
171
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
172
+ - C: Color channels (3 for RGB)
173
+ - N: Number of frames (81)
174
+ - H: Frame height (from max_area)
175
+ - W: Frame width (from max_area)
176
+ """
177
+ img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
178
+
179
+ F = frame_num
180
+ h, w = img.shape[1:]
181
+ aspect_ratio = h / w
182
+ lat_h = round(
183
+ np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
184
+ self.patch_size[1] * self.patch_size[1])
185
+ lat_w = round(
186
+ np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
187
+ self.patch_size[2] * self.patch_size[2])
188
+ h = lat_h * self.vae_stride[1]
189
+ w = lat_w * self.vae_stride[2]
190
+
191
+ max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
192
+ self.patch_size[1] * self.patch_size[2])
193
+ max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
194
+
195
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
196
+ seed_g = torch.Generator(device=self.device)
197
+ seed_g.manual_seed(seed)
198
+ noise = torch.randn(
199
+ 16,
200
+ 21,
201
+ lat_h,
202
+ lat_w,
203
+ dtype=torch.float32,
204
+ generator=seed_g,
205
+ device=self.device)
206
+
207
+ msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
208
+ msk[:, 1:] = 0
209
+ msk = torch.concat([
210
+ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
211
+ ],
212
+ dim=1)
213
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
214
+ msk = msk.transpose(1, 2)[0]
215
+
216
+ if n_prompt == "":
217
+ n_prompt = self.sample_neg_prompt
218
+
219
+ # preprocess
220
+ if not self.t5_cpu:
221
+ self.text_encoder.model.to(self.device)
222
+ context = self.text_encoder([input_prompt], self.device)
223
+ context_null = self.text_encoder([n_prompt], self.device)
224
+ if offload_model:
225
+ self.text_encoder.model.cpu()
226
+ else:
227
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
228
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
229
+ context = [t.to(self.device) for t in context]
230
+ context_null = [t.to(self.device) for t in context_null]
231
+
232
+ self.clip.model.to(self.device)
233
+ clip_context = self.clip.visual([img[:, None, :, :]])
234
+ if offload_model:
235
+ self.clip.model.cpu()
236
+
237
+ y = self.vae.encode([
238
+ torch.concat([
239
+ torch.nn.functional.interpolate(
240
+ img[None].cpu(), size=(h, w), mode='bicubic').transpose(
241
+ 0, 1),
242
+ torch.zeros(3, 80, h, w)
243
+ ],
244
+ dim=1).to(self.device)
245
+ ])[0]
246
+ y = torch.concat([msk, y])
247
+
248
+ @contextmanager
249
+ def noop_no_sync():
250
+ yield
251
+
252
+ no_sync = getattr(self.model, 'no_sync', noop_no_sync)
253
+
254
+ # evaluation mode
255
+ with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
256
+
257
+ if sample_solver == 'unipc':
258
+ sample_scheduler = FlowUniPCMultistepScheduler(
259
+ num_train_timesteps=self.num_train_timesteps,
260
+ shift=1,
261
+ use_dynamic_shifting=False)
262
+ sample_scheduler.set_timesteps(
263
+ sampling_steps, device=self.device, shift=shift)
264
+ timesteps = sample_scheduler.timesteps
265
+ elif sample_solver == 'dpm++':
266
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
267
+ num_train_timesteps=self.num_train_timesteps,
268
+ shift=1,
269
+ use_dynamic_shifting=False)
270
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
271
+ timesteps, _ = retrieve_timesteps(
272
+ sample_scheduler,
273
+ device=self.device,
274
+ sigmas=sampling_sigmas)
275
+ else:
276
+ raise NotImplementedError("Unsupported solver.")
277
+
278
+ # sample videos
279
+ latent = noise
280
+
281
+ arg_c = {
282
+ 'context': [context[0]],
283
+ 'clip_fea': clip_context,
284
+ 'seq_len': max_seq_len,
285
+ 'y': [y],
286
+ }
287
+
288
+ arg_null = {
289
+ 'context': context_null,
290
+ 'clip_fea': clip_context,
291
+ 'seq_len': max_seq_len,
292
+ 'y': [y],
293
+ }
294
+
295
+ if offload_model:
296
+ torch.cuda.empty_cache()
297
+
298
+ self.model.to(self.device)
299
+ for _, t in enumerate(tqdm(timesteps)):
300
+ latent_model_input = [latent.to(self.device)]
301
+ timestep = [t]
302
+
303
+ timestep = torch.stack(timestep).to(self.device)
304
+
305
+ noise_pred_cond = self.model(
306
+ latent_model_input, t=timestep, **arg_c)[0].to(
307
+ torch.device('cpu') if offload_model else self.device)
308
+ if offload_model:
309
+ torch.cuda.empty_cache()
310
+ noise_pred_uncond = self.model(
311
+ latent_model_input, t=timestep, **arg_null)[0].to(
312
+ torch.device('cpu') if offload_model else self.device)
313
+ if offload_model:
314
+ torch.cuda.empty_cache()
315
+ noise_pred = noise_pred_uncond + guide_scale * (
316
+ noise_pred_cond - noise_pred_uncond)
317
+
318
+ latent = latent.to(
319
+ torch.device('cpu') if offload_model else self.device)
320
+
321
+ temp_x0 = sample_scheduler.step(
322
+ noise_pred.unsqueeze(0),
323
+ t,
324
+ latent.unsqueeze(0),
325
+ return_dict=False,
326
+ generator=seed_g)[0]
327
+ latent = temp_x0.squeeze(0)
328
+
329
+ x0 = [latent.to(self.device)]
330
+ del latent_model_input, timestep
331
+
332
+ if offload_model:
333
+ self.model.cpu()
334
+ torch.cuda.empty_cache()
335
+
336
+ if self.rank == 0:
337
+ videos = self.vae.decode(x0)
338
+
339
+ del noise, latent
340
+ del sample_scheduler
341
+ if offload_model:
342
+ gc.collect()
343
+ torch.cuda.synchronize()
344
+ if dist.is_initialized():
345
+ dist.barrier()
346
+
347
+ return videos[0] if self.rank == 0 else None
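A hedged usage sketch for `WanI2V.generate`; here `config` stands in for an i2v EasyDict in the style of the t2v configs above (extended with the clip_* fields referenced in `__init__`), and the checkpoint path is illustrative:

from PIL import Image
from wan.image2video import WanI2V

# config: assumed EasyDict carrying the t5/vae/clip checkpoints and transformer fields
pipe = WanI2V(config=config, checkpoint_dir="checkpoints/Wan2.1-I2V-14B-480P")
video = pipe.generate(
    "a corgi surfing at sunset",
    Image.open("input.jpg"),
    max_area=480 * 832,
    frame_num=81,
    shift=3.0,              # the docstring above recommends 3.0 for 480p
    sampling_steps=40)
# rank 0 receives a (C, N, H, W) tensor; other ranks receive None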
wan/modules/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ from .attention import flash_attention
2
+ from .model import WanModel
3
+ from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
4
+ from .tokenizers import HuggingfaceTokenizer
5
+ from .vae import WanVAE
6
+
7
+ __all__ = [
8
+ 'WanVAE',
9
+ 'WanModel',
10
+ 'T5Model',
11
+ 'T5Encoder',
12
+ 'T5Decoder',
13
+ 'T5EncoderModel',
14
+ 'HuggingfaceTokenizer',
15
+ 'flash_attention',
16
+ ]
wan/modules/attention.py ADDED
@@ -0,0 +1,185 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+
4
+ try:
5
+ import flash_attn_interface
6
+
7
+ def is_hopper_gpu():
8
+ if not torch.cuda.is_available():
9
+ return False
10
+ device_name = torch.cuda.get_device_name(0).lower()
11
+ return "h100" in device_name or "hopper" in device_name
12
+ FLASH_ATTN_3_AVAILABLE = is_hopper_gpu()
13
+ except ModuleNotFoundError:
14
+ FLASH_ATTN_3_AVAILABLE = False
15
+
16
+ try:
17
+ import flash_attn
18
+ FLASH_ATTN_2_AVAILABLE = True
19
+ except ModuleNotFoundError:
20
+ FLASH_ATTN_2_AVAILABLE = False
21
+
22
+ # FLASH_ATTN_3_AVAILABLE = False
23
+
24
+ import warnings
25
+
26
+ __all__ = [
27
+ 'flash_attention',
28
+ 'attention',
29
+ ]
30
+
31
+
32
+ def flash_attention(
33
+ q,
34
+ k,
35
+ v,
36
+ q_lens=None,
37
+ k_lens=None,
38
+ dropout_p=0.,
39
+ softmax_scale=None,
40
+ q_scale=None,
41
+ causal=False,
42
+ window_size=(-1, -1),
43
+ deterministic=False,
44
+ dtype=torch.bfloat16,
45
+ version=None,
46
+ ):
47
+ """
48
+ q: [B, Lq, Nq, C1].
49
+ k: [B, Lk, Nk, C1].
50
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
51
+ q_lens: [B].
52
+ k_lens: [B].
53
+ dropout_p: float. Dropout probability.
54
+ softmax_scale: float. The scaling of QK^T before applying softmax.
55
+ causal: bool. Whether to apply causal attention mask.
56
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
57
+ deterministic: bool. If True, slightly slower and uses more memory.
58
+ dtype: torch.dtype. Applied when the dtype of q/k/v is not float16/bfloat16.
59
+ """
60
+ half_dtypes = (torch.float16, torch.bfloat16)
61
+ assert dtype in half_dtypes
62
+ assert q.device.type == 'cuda' and q.size(-1) <= 256
63
+
64
+ # params
65
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
66
+
67
+ def half(x):
68
+ return x if x.dtype in half_dtypes else x.to(dtype)
69
+
70
+ # preprocess query
71
+ if q_lens is None:
72
+ q = half(q.flatten(0, 1))
73
+ q_lens = torch.tensor(
74
+ [lq] * b, dtype=torch.int32).to(
75
+ device=q.device, non_blocking=True)
76
+ else:
77
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
78
+
79
+ # preprocess key, value
80
+ if k_lens is None:
81
+ k = half(k.flatten(0, 1))
82
+ v = half(v.flatten(0, 1))
83
+ k_lens = torch.tensor(
84
+ [lk] * b, dtype=torch.int32).to(
85
+ device=k.device, non_blocking=True)
86
+ else:
87
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
88
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
89
+
90
+ q = q.to(v.dtype)
91
+ k = k.to(v.dtype)
92
+
93
+ if q_scale is not None:
94
+ q = q * q_scale
95
+
96
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
97
+ warnings.warn(
98
+ 'Flash attention 3 is not available, use flash attention 2 instead.'
99
+ )
100
+
101
+ # apply attention
102
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
103
+ # Note: dropout_p, window_size are not supported in FA3 now.
104
+ x = flash_attn_interface.flash_attn_varlen_func(
105
+ q=q,
106
+ k=k,
107
+ v=v,
108
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
109
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
110
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
111
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
112
+ max_seqlen_q=lq,
113
+ max_seqlen_k=lk,
114
+ softmax_scale=softmax_scale,
115
+ causal=causal,
116
+ deterministic=deterministic)[0].unflatten(0, (b, lq))
117
+ else:
118
+ assert FLASH_ATTN_2_AVAILABLE
119
+ x = flash_attn.flash_attn_varlen_func(
120
+ q=q,
121
+ k=k,
122
+ v=v,
123
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
124
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
125
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
126
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
127
+ max_seqlen_q=lq,
128
+ max_seqlen_k=lk,
129
+ dropout_p=dropout_p,
130
+ softmax_scale=softmax_scale,
131
+ causal=causal,
132
+ window_size=window_size,
133
+ deterministic=deterministic).unflatten(0, (b, lq))
134
+
135
+ # output
136
+ return x.type(out_dtype)
137
+
138
+
139
+ def attention(
140
+ q,
141
+ k,
142
+ v,
143
+ q_lens=None,
144
+ k_lens=None,
145
+ dropout_p=0.,
146
+ softmax_scale=None,
147
+ q_scale=None,
148
+ causal=False,
149
+ window_size=(-1, -1),
150
+ deterministic=False,
151
+ dtype=torch.bfloat16,
152
+ fa_version=None,
153
+ ):
154
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
155
+ return flash_attention(
156
+ q=q,
157
+ k=k,
158
+ v=v,
159
+ q_lens=q_lens,
160
+ k_lens=k_lens,
161
+ dropout_p=dropout_p,
162
+ softmax_scale=softmax_scale,
163
+ q_scale=q_scale,
164
+ causal=causal,
165
+ window_size=window_size,
166
+ deterministic=deterministic,
167
+ dtype=dtype,
168
+ version=fa_version,
169
+ )
170
+ else:
171
+ if q_lens is not None or k_lens is not None:
172
+ warnings.warn(
173
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
174
+ )
175
+ attn_mask = None
176
+
177
+ q = q.transpose(1, 2).to(dtype)
178
+ k = k.transpose(1, 2).to(dtype)
179
+ v = v.transpose(1, 2).to(dtype)
180
+
181
+ out = torch.nn.functional.scaled_dot_product_attention(
182
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
183
+
184
+ out = out.transpose(1, 2).contiguous()
185
+ return out
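A hedged smoke test for the `attention` wrapper above; it requires a CUDA device and dispatches to FlashAttention when installed, otherwise to `scaled_dot_product_attention`:

import torch
from wan.modules.attention import attention

q = torch.randn(1, 128, 12, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 128, 12, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 128, 12, 128, device="cuda", dtype=torch.bfloat16)
out = attention(q, k, v, causal=True)
print(out.shape)            # torch.Size([1, 128, 12, 128])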
wan/modules/causal_model.py ADDED
@@ -0,0 +1,1127 @@
1
+ from wan.modules.attention import attention
2
+ from wan.modules.model import (
3
+ WanRMSNorm,
4
+ rope_apply,
5
+ WanLayerNorm,
6
+ WAN_CROSSATTENTION_CLASSES,
7
+ rope_params,
8
+ MLPProj,
9
+ sinusoidal_embedding_1d
10
+ )
11
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ # from torch.nn.attention.flex_attention import BlockMask
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+ import torch.nn as nn
16
+ import torch
17
+ import math
18
+ import torch.distributed as dist
19
+
20
+ # the Wan 1.3B model has an unusual channel / head configuration and requires max-autotune to work with flex_attention
21
+ # see https://github.com/pytorch/pytorch/issues/133254
22
+ # change to default for other models
23
+ # flex_attention = torch.compile(
24
+ # flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")
25
+
26
+
27
+ def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
28
+ n, c = x.size(2), x.size(3) // 2
29
+
30
+ # split freqs
31
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
32
+
33
+ # loop over samples
34
+ output = []
35
+
36
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
37
+ seq_len = f * h * w
38
+
39
+ # precompute multipliers
40
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
41
+ seq_len, n, -1, 2))
42
+ freqs_i = torch.cat([
43
+ freqs[0][start_frame:start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1),
44
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
45
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
46
+ ],
47
+ dim=-1).reshape(seq_len, 1, -1)
48
+
49
+ # apply rotary embedding
50
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
51
+ x_i = torch.cat([x_i, x[i, seq_len:]])
52
+
53
+ # append to collection
54
+ output.append(x_i)
55
+ return torch.stack(output).type_as(x)
56
+
57
+
58
+ class CausalWanSelfAttention(nn.Module):
59
+
60
+ def __init__(self,
61
+ dim,
62
+ num_heads,
63
+ local_attn_size=-1,
64
+ sink_size=1,
65
+ qk_norm=True,
66
+ eps=1e-6):
67
+ assert dim % num_heads == 0
68
+ super().__init__()
69
+ self.dim = dim
70
+ self.num_heads = num_heads
71
+ self.head_dim = dim // num_heads
72
+ self.local_attn_size = local_attn_size
73
+ self.qk_norm = qk_norm
74
+ self.eps = eps
75
+ self.frame_length = 1560
76
+ self.max_attention_size = 21 * self.frame_length
77
+ self.block_length = 3 * self.frame_length
78
+
79
+ # layers
80
+ self.q = nn.Linear(dim, dim)
81
+ self.k = nn.Linear(dim, dim)
82
+ self.v = nn.Linear(dim, dim)
83
+ self.o = nn.Linear(dim, dim)
84
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
85
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
86
+
87
+ def forward(
88
+ self,
89
+ x,
90
+ seq_lens,
91
+ grid_sizes,
92
+ freqs,
93
+ block_mask,
94
+ kv_cache=None,
95
+ current_start=0,
96
+ cache_start=None,
97
+ updating_cache=False
98
+ ):
99
+ r"""
100
+ Args:
101
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
102
+ seq_lens(Tensor): Shape [B]
103
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
104
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
105
+ block_mask (BlockMask)
106
+ """
107
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
108
+ if cache_start is None:
109
+ cache_start = current_start
110
+
111
+ # query, key, value function
112
+ def qkv_fn(x):
113
+ q = self.norm_q(self.q(x)).view(b, s, n, d) # [B, L, 12, 128]
114
+ k = self.norm_k(self.k(x)).view(b, s, n, d) # [B, L, 12, 128]
115
+ v = self.v(x).view(b, s, n, d) # [B, L, 12, 128]
116
+ return q, k, v
117
+
118
+ q, k, v = qkv_fn(x)
119
+
120
+ if kv_cache is None:
121
+ # check whether this is teacher-forcing training (sequence holds clean + noisy halves)
122
+ is_tf = (s == seq_lens[0].item() * 2)
123
+ if is_tf:
124
+ q_chunk = torch.chunk(q, 2, dim=1)
125
+ k_chunk = torch.chunk(k, 2, dim=1)
126
+ roped_query = []
127
+ roped_key = []
128
+ # rope should be same for clean and noisy parts
129
+ for ii in range(2):
130
+ rq = rope_apply(q_chunk[ii], grid_sizes, freqs).type_as(v)
131
+ rk = rope_apply(k_chunk[ii], grid_sizes, freqs).type_as(v)
132
+ roped_query.append(rq)
133
+ roped_key.append(rk)
134
+
135
+ roped_query = torch.cat(roped_query, dim=1)
136
+ roped_key = torch.cat(roped_key, dim=1)
137
+
138
+ padded_length = math.ceil(q.shape[1] / 128) * 128 - q.shape[1]
139
+ padded_roped_query = torch.cat(
140
+ [roped_query,
141
+ torch.zeros([q.shape[0], padded_length, q.shape[2], q.shape[3]],
142
+ device=q.device, dtype=v.dtype)],
143
+ dim=1
144
+ )
145
+
146
+ padded_roped_key = torch.cat(
147
+ [roped_key, torch.zeros([k.shape[0], padded_length, k.shape[2], k.shape[3]],
148
+ device=k.device, dtype=v.dtype)],
149
+ dim=1
150
+ )
151
+
152
+ padded_v = torch.cat(
153
+ [v, torch.zeros([v.shape[0], padded_length, v.shape[2], v.shape[3]],
154
+ device=v.device, dtype=v.dtype)],
155
+ dim=1
156
+ )
157
+
158
+ x = flex_attention(
159
+ query=padded_roped_query.transpose(2, 1),
160
+ key=padded_roped_key.transpose(2, 1),
161
+ value=padded_v.transpose(2, 1),
162
+ block_mask=block_mask
163
+ )[:, :, :-padded_length].transpose(2, 1)
164
+
165
+ else:
166
+ roped_query = rope_apply(q, grid_sizes, freqs).type_as(v)
167
+ roped_key = rope_apply(k, grid_sizes, freqs).type_as(v)
168
+
169
+ padded_length = math.ceil(q.shape[1] / 128) * 128 - q.shape[1]
170
+ padded_roped_query = torch.cat(
171
+ [roped_query,
172
+ torch.zeros([q.shape[0], padded_length, q.shape[2], q.shape[3]],
173
+ device=q.device, dtype=v.dtype)],
174
+ dim=1
175
+ )
176
+
177
+ padded_roped_key = torch.cat(
178
+ [roped_key, torch.zeros([k.shape[0], padded_length, k.shape[2], k.shape[3]],
179
+ device=k.device, dtype=v.dtype)],
180
+ dim=1
181
+ )
182
+
183
+ padded_v = torch.cat(
184
+ [v, torch.zeros([v.shape[0], padded_length, v.shape[2], v.shape[3]],
185
+ device=v.device, dtype=v.dtype)],
186
+ dim=1
187
+ )
188
+
189
+ x = flex_attention(
190
+ query=padded_roped_query.transpose(2, 1),
191
+ key=padded_roped_key.transpose(2, 1),
192
+ value=padded_v.transpose(2, 1),
193
+ block_mask=block_mask
194
+ )[:, :, :-padded_length].transpose(2, 1)
195
+ else:
196
+ frame_seqlen = math.prod(grid_sizes[0][1:]).item()
197
+ current_start_frame = current_start // frame_seqlen
198
+ roped_query = causal_rope_apply(
199
+ q, grid_sizes, freqs, start_frame=current_start_frame).type_as(v) # [B, L, 12, 128]
200
+ roped_key = causal_rope_apply(
201
+ k, grid_sizes, freqs, start_frame=current_start_frame).type_as(v) # [B, L, 12, 128]
202
+
203
+ grid_sizes_one_block = grid_sizes.clone()
204
+ grid_sizes_one_block[:,0] = 3
205
+
206
+ # only caching the first block
207
+ cache_end = cache_start + self.block_length
208
+ num_new_tokens = cache_end - kv_cache["global_end_index"].item()
209
+ kv_cache_size = kv_cache["k"].shape[1]
210
+
211
+ sink_tokens = 1 * self.block_length # we keep the first block in the cache
212
+
213
+ if (num_new_tokens > 0) and (
214
+ num_new_tokens + kv_cache["local_end_index"].item() > kv_cache_size):
215
+ num_evicted_tokens = num_new_tokens + kv_cache["local_end_index"].item() - kv_cache_size
216
+ num_rolled_tokens = kv_cache["local_end_index"].item() - num_evicted_tokens - sink_tokens
217
+ kv_cache["k"][:, sink_tokens:sink_tokens + num_rolled_tokens] = \
218
+ kv_cache["k"][:, sink_tokens + num_evicted_tokens:sink_tokens + num_evicted_tokens + num_rolled_tokens].clone()
219
+ kv_cache["v"][:, sink_tokens:sink_tokens + num_rolled_tokens] = \
220
+ kv_cache["v"][:, sink_tokens + num_evicted_tokens:sink_tokens + num_evicted_tokens + num_rolled_tokens].clone()
221
+
222
+ local_end_index = kv_cache["local_end_index"].item() + cache_end - \
223
+ kv_cache["global_end_index"].item() - num_evicted_tokens
224
+ local_start_index = local_end_index - self.block_length
225
+ kv_cache["k"][:, local_start_index:local_end_index] = roped_key[:, :self.block_length]
226
+ kv_cache["v"][:, local_start_index:local_end_index] = v[:, :self.block_length]
227
+ else:
228
+ local_end_index = kv_cache["local_end_index"].item() + cache_end - kv_cache["global_end_index"].item()
229
+ local_start_index = local_end_index - self.block_length
230
+ if local_start_index == 0: # first block is not roped in the cache
231
+ kv_cache["k"][:, local_start_index:local_end_index] = k[:, :self.block_length]
232
+ else:
233
+ kv_cache["k"][:, local_start_index:local_end_index] = roped_key[:, :self.block_length]
234
+
235
+ kv_cache["v"][:, local_start_index:local_end_index] = v[:, :self.block_length]
236
+
237
+ if num_new_tokens > 0: # prevent updating when caching clean frame
238
+ kv_cache["global_end_index"].fill_(cache_end)
239
+ kv_cache["local_end_index"].fill_(local_end_index)
240
+
241
+ if local_start_index == 0:
242
+ # no kv attn with cache
243
+ x = attention(
244
+ roped_query,
245
+ roped_key,
246
+ v)
247
+ else:
248
+ if updating_cache: # updating working cache with clean frame
249
+ extract_cache_end = local_end_index
250
+ extract_cache_start = max(0, local_end_index-self.max_attention_size)
251
+ working_cache_key = kv_cache["k"][:, extract_cache_start:extract_cache_end].clone()
252
+ working_cache_v = kv_cache["v"][:, extract_cache_start:extract_cache_end]
253
+
254
+ if extract_cache_start == 0: # rope the global first block in working cache
255
+ working_cache_key[:,:self.block_length] = causal_rope_apply(
256
+ working_cache_key[:,:self.block_length], grid_sizes_one_block, freqs, start_frame=0).type_as(v)
257
+
258
+ x = attention(
259
+ roped_query,
260
+ working_cache_key,
261
+ working_cache_v
262
+ )
263
+
264
+ else:
265
+ # 1. extract working cache
266
+ # calculate the length of working cache
267
+ query_length = roped_query.shape[1]
268
+ working_cache_max_length = self.max_attention_size - query_length - self.block_length
269
+
270
+ extract_cache_end = local_start_index
271
+ extract_cache_start = max(self.block_length, local_start_index - working_cache_max_length) # working cache does not include the first anchor block
272
+ working_cache_key = kv_cache["k"][:, extract_cache_start:extract_cache_end]
273
+ working_cache_v = kv_cache["v"][:, extract_cache_start:extract_cache_end]
274
+
275
+ # 2. extract anchor cache, roped as the past frame
276
+ working_cache_frame_length = working_cache_key.shape[1] // self.frame_length
277
+ rope_start_frame = current_start_frame - working_cache_frame_length - 3
278
+
279
+ anchor_cache_key = causal_rope_apply(
280
+ kv_cache["k"][:, :self.block_length], grid_sizes_one_block, freqs, start_frame=rope_start_frame).type_as(v)
281
+ anchor_cache_v = kv_cache["v"][:, :self.block_length]
282
+
283
+ # 3. attention with working cache and anchor cache
284
+ input_key = torch.cat([
285
+ anchor_cache_key,
286
+ working_cache_key,
287
+ roped_key
288
+ ], dim=1)
289
+
290
+ input_v = torch.cat([
291
+ anchor_cache_v,
292
+ working_cache_v,
293
+ v
294
+ ], dim=1)
295
+
296
+ x = attention(
297
+ roped_query,
298
+ input_key,
299
+ input_v
300
+ )
301
+
302
+
303
+ # output
304
+ x = x.flatten(2)
305
+ x = self.o(x)
306
+ return x
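A worked example of the cache-rolling arithmetic above, under the assumption that the KV cache is allocated to exactly the 21-frame attention window:

frame_length = 1560
block_length = 3 * frame_length            # 4680 tokens per 3-frame block
kv_cache_size = 21 * frame_length          # assumption: cache == attention window
sink_tokens = block_length                 # the first (anchor) block is never evicted

# a full cache receives one new 3-frame block:
local_end_index = kv_cache_size
num_new_tokens = block_length
num_evicted = num_new_tokens + local_end_index - kv_cache_size    # 4680 evicted
num_rolled = local_end_index - num_evicted - sink_tokens          # 23400 shifted left
assert (num_evicted, num_rolled) == (4680, 23400)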
307
+
308
+
309
+ class CausalWanAttentionBlock(nn.Module):
310
+
311
+ def __init__(self,
312
+ cross_attn_type,
313
+ dim,
314
+ ffn_dim,
315
+ num_heads,
316
+ local_attn_size=-1,
317
+ sink_size=0,
318
+ qk_norm=True,
319
+ cross_attn_norm=False,
320
+ eps=1e-6):
321
+ super().__init__()
322
+ self.dim = dim
323
+ self.ffn_dim = ffn_dim
324
+ self.num_heads = num_heads
325
+ self.local_attn_size = local_attn_size
326
+ self.qk_norm = qk_norm
327
+ self.cross_attn_norm = cross_attn_norm
328
+ self.eps = eps
329
+
330
+ # layers
331
+ self.norm1 = WanLayerNorm(dim, eps)
332
+ self.self_attn = CausalWanSelfAttention(dim, num_heads, local_attn_size, sink_size, qk_norm, eps)
333
+ self.norm3 = WanLayerNorm(
334
+ dim, eps,
335
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
336
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
337
+ num_heads,
338
+ (-1, -1),
339
+ qk_norm,
340
+ eps)
341
+ self.norm2 = WanLayerNorm(dim, eps)
342
+ self.ffn = nn.Sequential(
343
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
344
+ nn.Linear(ffn_dim, dim))
345
+
346
+ # modulation
347
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
348
+
349
+ def forward(
350
+ self,
351
+ x,
352
+ e,
353
+ seq_lens,
354
+ grid_sizes,
355
+ freqs,
356
+ context,
357
+ context_lens,
358
+ block_mask,
359
+ updating_cache=False,
360
+ kv_cache=None,
361
+ crossattn_cache=None,
362
+ current_start=0,
363
+ cache_start=None
364
+ ):
365
+ r"""
366
+ Args:
367
+ x(Tensor): Shape [B, L, C]
368
+ e(Tensor): Shape [B, F, 6, C]
369
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
370
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
371
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
372
+ """
373
+ num_frames, frame_seqlen = e.shape[1], x.shape[1] // e.shape[1]
374
+ # assert e.dtype == torch.float32
375
+ # with amp.autocast(dtype=torch.float32):
376
+ e = (self.modulation.unsqueeze(1) + e).chunk(6, dim=2)
377
+ # assert e[0].dtype == torch.float32
378
+
379
+ # self-attention
380
+ y = self.self_attn(
381
+ (self.norm1(x).unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * (1 + e[1]) + e[0]).flatten(1, 2),
382
+ seq_lens, grid_sizes,
383
+ freqs, block_mask, kv_cache, current_start, cache_start, updating_cache=updating_cache)
384
+
385
+ # with amp.autocast(dtype=torch.float32):
386
+ x = x + (y.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * e[2]).flatten(1, 2)
387
+
388
+ # cross-attention & ffn function
389
+ def cross_attn_ffn(x, context, context_lens, e, crossattn_cache=None):
390
+ x = x + self.cross_attn(self.norm3(x), context,
391
+ context_lens, crossattn_cache=crossattn_cache)
392
+ y = self.ffn(
393
+ (self.norm2(x).unflatten(dim=1, sizes=(num_frames,
394
+ frame_seqlen)) * (1 + e[4]) + e[3]).flatten(1, 2)
395
+ )
396
+ # with amp.autocast(dtype=torch.float32):
397
+ x = x + (y.unflatten(dim=1, sizes=(num_frames,
398
+ frame_seqlen)) * e[5]).flatten(1, 2)
399
+ return x
400
+
401
+ x = cross_attn_ffn(x, context, context_lens, e, crossattn_cache)
402
+ return x
403
+
404
+
405
+ class CausalHead(nn.Module):
406
+
407
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
408
+ super().__init__()
409
+ self.dim = dim
410
+ self.out_dim = out_dim
411
+ self.patch_size = patch_size
412
+ self.eps = eps
413
+
414
+ # layers
415
+ out_dim = math.prod(patch_size) * out_dim
416
+ self.norm = WanLayerNorm(dim, eps)
417
+ self.head = nn.Linear(dim, out_dim)
418
+
419
+ # modulation
420
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
421
+
422
+ def forward(self, x, e):
423
+ r"""
424
+ Args:
425
+ x(Tensor): Shape [B, L1, C]
426
+ e(Tensor): Shape [B, F, 1, C]
427
+ """
428
+ # assert e.dtype == torch.float32
429
+ # with amp.autocast(dtype=torch.float32):
430
+ num_frames, frame_seqlen = e.shape[1], x.shape[1] // e.shape[1]
431
+ e = (self.modulation.unsqueeze(1) + e).chunk(2, dim=2)
432
+ x = (self.head(self.norm(x).unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * (1 + e[1]) + e[0]))
433
+ return x
434
+
435
+
436
+ class CausalWanModel(ModelMixin, ConfigMixin):
437
+ r"""
438
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
439
+ """
440
+
441
+ ignore_for_config = [
442
+ 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim'
443
+ ]
444
+ _no_split_modules = ['WanAttentionBlock']
445
+ _supports_gradient_checkpointing = True
446
+
447
+ @register_to_config
448
+ def __init__(self,
449
+ model_type='t2v',
450
+ patch_size=(1, 2, 2),
451
+ text_len=512,
452
+ in_dim=16,
453
+ dim=2048,
454
+ ffn_dim=8192,
455
+ freq_dim=256,
456
+ text_dim=4096,
457
+ out_dim=16,
458
+ num_heads=16,
459
+ num_layers=32,
460
+ local_attn_size=-1,
461
+ sink_size=0,
462
+ qk_norm=True,
463
+ cross_attn_norm=True,
464
+ eps=1e-6):
465
+ r"""
466
+ Initialize the diffusion model backbone.
467
+
468
+ Args:
469
+ model_type (`str`, *optional*, defaults to 't2v'):
470
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
471
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
472
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
473
+ text_len (`int`, *optional*, defaults to 512):
474
+ Fixed length for text embeddings
475
+ in_dim (`int`, *optional*, defaults to 16):
476
+ Input video channels (C_in)
477
+ dim (`int`, *optional*, defaults to 2048):
478
+ Hidden dimension of the transformer
479
+ ffn_dim (`int`, *optional*, defaults to 8192):
480
+ Intermediate dimension in feed-forward network
481
+ freq_dim (`int`, *optional*, defaults to 256):
482
+ Dimension for sinusoidal time embeddings
483
+ text_dim (`int`, *optional*, defaults to 4096):
484
+ Input dimension for text embeddings
485
+ out_dim (`int`, *optional*, defaults to 16):
486
+ Output video channels (C_out)
487
+ num_heads (`int`, *optional*, defaults to 16):
488
+ Number of attention heads
489
+ num_layers (`int`, *optional*, defaults to 32):
490
+ Number of transformer blocks
491
+ local_attn_size (`int`, *optional*, defaults to -1):
492
+ Window size for temporal local attention (-1 indicates global attention)
493
+ sink_size (`int`, *optional*, defaults to 0):
494
+ Size of the attention sink, we keep the first `sink_size` frames unchanged when rolling the KV cache
495
+ qk_norm (`bool`, *optional*, defaults to True):
496
+ Enable query/key normalization
497
+ cross_attn_norm (`bool`, *optional*, defaults to True):
498
+ Enable cross-attention normalization
499
+ eps (`float`, *optional*, defaults to 1e-6):
500
+ Epsilon value for normalization layers
501
+ """
502
+
503
+ super().__init__()
504
+
505
+ assert model_type in ['t2v', 'i2v']
506
+ self.model_type = model_type
507
+
508
+ self.patch_size = patch_size
509
+ self.text_len = text_len
510
+ self.in_dim = in_dim
511
+ self.dim = dim
512
+ self.ffn_dim = ffn_dim
513
+ self.freq_dim = freq_dim
514
+ self.text_dim = text_dim
515
+ self.out_dim = out_dim
516
+ self.num_heads = num_heads
517
+ self.num_layers = num_layers
518
+ self.local_attn_size = local_attn_size
519
+ self.qk_norm = qk_norm
520
+ self.cross_attn_norm = cross_attn_norm
521
+ self.eps = eps
522
+
523
+ # embeddings
524
+ self.patch_embedding = nn.Conv3d(
525
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
526
+ self.text_embedding = nn.Sequential(
527
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
528
+ nn.Linear(dim, dim))
529
+
530
+ self.time_embedding = nn.Sequential(
531
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
532
+ self.time_projection = nn.Sequential(
533
+ nn.SiLU(), nn.Linear(dim, dim * 6))
534
+
535
+ # blocks
536
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
537
+ self.blocks = nn.ModuleList([
538
+ CausalWanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
539
+ local_attn_size, sink_size, qk_norm, cross_attn_norm, eps)
540
+ for _ in range(num_layers)
541
+ ])
542
+
543
+ # head
544
+ self.head = CausalHead(dim, out_dim, patch_size, eps)
545
+
546
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
547
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
548
+ d = dim // num_heads
549
+ self.freqs = torch.cat([
550
+ rope_params(1024, d - 4 * (d // 6)),
551
+ rope_params(1024, 2 * (d // 6)),
552
+ rope_params(1024, 2 * (d // 6))
553
+ ],
554
+ dim=1)
555
+
556
+ if model_type == 'i2v':
557
+ self.img_emb = MLPProj(1280, dim)
558
+
559
+ # initialize weights
560
+ self.init_weights()
561
+
562
+ self.gradient_checkpointing = False
563
+
564
+ self.block_mask = None
565
+
566
+ self.num_frame_per_block = 1
567
+ self.independent_first_frame = False
568
+
569
+ def _set_gradient_checkpointing(self, module, value=False):
570
+ self.gradient_checkpointing = value
571
+
572
+ @staticmethod
573
+ def _prepare_blockwise_causal_attn_mask(
574
+ device: torch.device | str, num_frames: int = 21,
575
+ frame_seqlen: int = 1560, num_frame_per_block=1, local_attn_size=-1
576
+ ):
577
+ """
578
+ we will divide the token sequence into the following format
579
+ [1 latent frame] [1 latent frame] ... [1 latent frame]
580
+ We use flexattention to construct the attention mask
581
+ """
582
+ total_length = num_frames * frame_seqlen
583
+
584
+ # we do right padding to get to a multiple of 128
585
+ padded_length = math.ceil(total_length / 128) * 128 - total_length
586
+
587
+ ends = torch.zeros(total_length + padded_length,
588
+ device=device, dtype=torch.long)
589
+
590
+ # Block-wise causal mask will attend to all elements that are before the end of the current chunk
591
+ frame_indices = torch.arange(
592
+ start=0,
593
+ end=total_length,
594
+ step=frame_seqlen * num_frame_per_block,
595
+ device=device
596
+ )
597
+
598
+ for tmp in frame_indices:
599
+ ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + \
600
+ frame_seqlen * num_frame_per_block
601
+
602
+ def attention_mask(b, h, q_idx, kv_idx):
603
+ if local_attn_size == -1:
604
+ return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
605
+ else:
606
+ return ((kv_idx < ends[q_idx]) & (kv_idx >= (ends[q_idx] - local_attn_size * frame_seqlen))) | (q_idx == kv_idx)
607
+ # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
608
+
609
+ block_mask = create_block_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
610
+ KV_LEN=total_length + padded_length, _compile=False, device=device)
611
+
612
+ import torch.distributed as dist
613
+ if not dist.is_initialized() or dist.get_rank() == 0:
614
+ print(
615
+ f" cache a block wise causal mask with block size of {num_frame_per_block} frames")
616
+ print(block_mask)
617
+
618
+ # import imageio
619
+ # import numpy as np
620
+ # from torch.nn.attention.flex_attention import create_mask
621
+
622
+ # mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length +
623
+ # padded_length, KV_LEN=total_length + padded_length, device=device)
624
+ # import cv2
625
+ # mask = cv2.resize(mask[0, 0].cpu().float().numpy(), (1024, 1024))
626
+ # imageio.imwrite("mask_%d.jpg" % (0), np.uint8(255. * mask))
627
+
628
+ return block_mask
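A toy illustration of the block-wise causal predicate above, shrunk to frame_seqlen=4 and num_frame_per_block=2 so the pattern is visible:

import torch

frame_seqlen, num_frame_per_block, num_frames = 4, 2, 4
total = num_frames * frame_seqlen
ends = torch.zeros(total, dtype=torch.long)
for s in range(0, total, frame_seqlen * num_frame_per_block):
    ends[s:s + frame_seqlen * num_frame_per_block] = s + frame_seqlen * num_frame_per_block
# with global attention, token q may attend to kv iff kv_idx < ends[q_idx]
print(ends.tolist())   # [8]*8 + [16]*8: two causal blocks of two frames each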
629
+
630
+ @staticmethod
631
+ def _prepare_teacher_forcing_mask(
632
+ device: torch.device | str, num_frames: int = 21,
633
+ frame_seqlen: int = 1560, num_frame_per_block=1
634
+ ):
635
+ """
636
+ we will divide the token sequence into the following format
637
+ [1 latent frame] [1 latent frame] ... [1 latent frame]
638
+ We use flexattention to construct the attention mask
639
+ """
640
+ # debug
641
+ DEBUG = False
642
+ if DEBUG:
643
+ num_frames = 9
644
+ frame_seqlen = 256
645
+
646
+ total_length = num_frames * frame_seqlen * 2
647
+
648
+ # we do right padding to get to a multiple of 128
649
+ padded_length = math.ceil(total_length / 128) * 128 - total_length
650
+
651
+ clean_ends = num_frames * frame_seqlen
652
+ # for clean context frames, we can construct their flex attention mask based on a [start, end] interval
653
+ context_ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
654
+ # for noisy frames, we need two intervals to construct the flex attention mask [context_start, context_end] [noisy_start, noisy_end]
655
+ noise_context_starts = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
656
+ noise_context_ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
657
+ noise_noise_starts = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
658
+ noise_noise_ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
659
+
660
+ # Block-wise causal mask will attend to all elements that are before the end of the current chunk
661
+ attention_block_size = frame_seqlen * num_frame_per_block
662
+ frame_indices = torch.arange(
663
+ start=0,
664
+ end=num_frames * frame_seqlen,
665
+ step=attention_block_size,
666
+ device=device, dtype=torch.long
667
+ )
668
+
669
+ # attention for clean context frames
670
+ for start in frame_indices:
671
+ context_ends[start:start + attention_block_size] = start + attention_block_size
672
+
673
+ noisy_image_start_list = torch.arange(
674
+ num_frames * frame_seqlen, total_length,
675
+ step=attention_block_size,
676
+ device=device, dtype=torch.long
677
+ )
678
+ noisy_image_end_list = noisy_image_start_list + attention_block_size
679
+
680
+ # attention for noisy frames
681
+ for block_index, (start, end) in enumerate(zip(noisy_image_start_list, noisy_image_end_list)):
682
+ # attend to noisy tokens within the same block
683
+ noise_noise_starts[start:end] = start
684
+ noise_noise_ends[start:end] = end
685
+ # attend to context tokens in previous blocks
686
+ # noise_context_starts[start:end] = 0
687
+ noise_context_ends[start:end] = block_index * attention_block_size
688
+
689
+ def attention_mask(b, h, q_idx, kv_idx):
690
+ # first design the mask for clean frames
691
+ clean_mask = (q_idx < clean_ends) & (kv_idx < context_ends[q_idx])
692
+ # then design the mask for noisy frames
693
+ # noisy frames attend to all preceding clean frames + themselves
694
+ C1 = (kv_idx < noise_noise_ends[q_idx]) & (kv_idx >= noise_noise_starts[q_idx])
695
+ C2 = (kv_idx < noise_context_ends[q_idx]) & (kv_idx >= noise_context_starts[q_idx])
696
+ noise_mask = (q_idx >= clean_ends) & (C1 | C2)
697
+
698
+ eye_mask = q_idx == kv_idx
699
+ return eye_mask | clean_mask | noise_mask
700
+
701
+ block_mask = create_block_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
702
+ KV_LEN=total_length + padded_length, _compile=False, device=device)
703
+
704
+ if DEBUG:
705
+ print(block_mask)
706
+ import imageio
707
+ import numpy as np
708
+ from torch.nn.attention.flex_attention import create_mask
709
+
710
+ mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length +
711
+ padded_length, KV_LEN=total_length + padded_length, device=device)
712
+ import cv2
713
+ mask = cv2.resize(mask[0, 0].cpu().float().numpy(), (1024, 1024))
714
+ imageio.imwrite("mask_%d.jpg" % (0), np.uint8(255. * mask))
715
+
716
+ return block_mask
717
+
718
+ @staticmethod
719
+ def _prepare_blockwise_causal_attn_mask_i2v(
720
+ device: torch.device | str, num_frames: int = 21,
721
+ frame_seqlen: int = 1560, num_frame_per_block=4, local_attn_size=-1
722
+ ):
723
+ """
724
+ we will divide the token sequence into the following format
725
+ [1 latent frame] [N latent frame] ... [N latent frame]
726
+ The first frame is separated out to support I2V generation
727
+ We use flexattention to construct the attention mask
728
+ """
729
+ total_length = num_frames * frame_seqlen
730
+
731
+ # we do right padding to get to a multiple of 128
732
+ padded_length = math.ceil(total_length / 128) * 128 - total_length
733
+
734
+ ends = torch.zeros(total_length + padded_length,
735
+ device=device, dtype=torch.long)
736
+
737
+ # special handling for the first frame
738
+ ends[:frame_seqlen] = frame_seqlen
739
+
740
+ # Block-wise causal mask will attend to all elements that are before the end of the current chunk
741
+ frame_indices = torch.arange(
742
+ start=frame_seqlen,
743
+ end=total_length,
744
+ step=frame_seqlen * num_frame_per_block,
745
+ device=device
746
+ )
747
+
748
+ for idx, tmp in enumerate(frame_indices):
749
+ ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + \
750
+ frame_seqlen * num_frame_per_block
751
+
752
+ def attention_mask(b, h, q_idx, kv_idx):
753
+ if local_attn_size == -1:
754
+ return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
755
+ else:
756
+ return ((kv_idx < ends[q_idx]) & (kv_idx >= (ends[q_idx] - local_attn_size * frame_seqlen))) | \
757
+ (q_idx == kv_idx)
758
+
759
+ block_mask = create_block_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
760
+ KV_LEN=total_length + padded_length, _compile=False, device=device)
761
+
762
+ if not dist.is_initialized() or dist.get_rank() == 0:
763
+ print(
764
+ f" cache a block wise causal mask with block size of {num_frame_per_block} frames")
765
+ print(block_mask)
766
+
767
+ # import imageio
768
+ # import numpy as np
769
+ # from torch.nn.attention.flex_attention import create_mask
770
+
771
+ # mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length +
772
+ # padded_length, KV_LEN=total_length + padded_length, device=device)
773
+ # import cv2
774
+ # mask = cv2.resize(mask[0, 0].cpu().float().numpy(), (1024, 1024))
775
+ # imageio.imwrite("mask_%d.jpg" % (0), np.uint8(255. * mask))
776
+
777
+ return block_mask
778
+
779
+ def _forward_inference(
780
+ self,
781
+ x,
782
+ t,
783
+ context,
784
+ seq_len,
785
+ updating_cache=False,
786
+ clip_fea=None,
787
+ y=None,
788
+ kv_cache: dict = None,
789
+ crossattn_cache: dict = None,
790
+ current_start: int = 0,
791
+ cache_start: int = 0,
792
+ ):
793
+ r"""
794
+ Run the diffusion model with kv caching.
795
+ See Algorithm 2 of CausVid paper https://arxiv.org/abs/2412.07772 for details.
796
+ This function is run num_frame times.
797
+ It processes the latent frames one by one (1560 tokens each).
798
+
799
+ Args:
800
+ x (List[Tensor]):
801
+ List of input video tensors, each with shape [C_in, F, H, W]
802
+ t (Tensor):
803
+ Diffusion timesteps tensor of shape [B]
804
+ context (List[Tensor]):
805
+ List of text embeddings each with shape [L, C]
806
+ seq_len (`int`):
807
+ Maximum sequence length for positional encoding
808
+ clip_fea (Tensor, *optional*):
809
+ CLIP image features for image-to-video mode
810
+ y (List[Tensor], *optional*):
811
+ Conditional video inputs for image-to-video mode, same shape as x
812
+
813
+ Returns:
814
+ List[Tensor]:
815
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
816
+ """
817
+
+        if self.model_type == 'i2v':
+            assert clip_fea is not None and y is not None
+        # params
+        device = self.patch_embedding.weight.device
+        if self.freqs.device != device:
+            self.freqs = self.freqs.to(device)
+
+        if y is not None:
+            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+
+        # embeddings
+        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+        grid_sizes = torch.stack(
+            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+        x = [u.flatten(2).transpose(1, 2) for u in x]
+        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+        assert seq_lens.max() <= seq_len
+        x = torch.cat(x)
+        """
+        torch.cat([
+            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+                      dim=1) for u in x
+        ])
+        """
+
+        # time embeddings
+        # with amp.autocast(dtype=torch.float32):
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x))
+        e0 = self.time_projection(e).unflatten(
+            1, (6, self.dim)).unflatten(dim=0, sizes=t.shape)
+        # assert e.dtype == torch.float32 and e0.dtype == torch.float32
+
+        # context
+        context_lens = None
+        context = self.text_embedding(
+            torch.stack([
+                torch.cat(
+                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+                for u in context
+            ]))
+
+        if clip_fea is not None:
+            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
+            context = torch.concat([context_clip, context], dim=1)
+
+        # arguments
+        kwargs = dict(
+            e=e0,
+            seq_lens=seq_lens,
+            grid_sizes=grid_sizes,
+            freqs=self.freqs,
+            context=context,
+            context_lens=context_lens,
+            block_mask=self.block_mask,
+            updating_cache=updating_cache,
+        )
+
+        def create_custom_forward(module):
+            def custom_forward(*inputs, **kwargs):
+                return module(*inputs, **kwargs)
+            return custom_forward
+
+        for block_index, block in enumerate(self.blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                kwargs.update(
+                    {
+                        "kv_cache": kv_cache[block_index],
+                        "current_start": current_start,
+                        "cache_start": cache_start
+                    }
+                )
+                x = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    x, **kwargs,
+                    use_reentrant=False,
+                )
+            else:
+                kwargs.update(
+                    {
+                        "kv_cache": kv_cache[block_index],
+                        "crossattn_cache": crossattn_cache[block_index],
+                        "current_start": current_start,
+                        "cache_start": cache_start
+                    }
+                )
+                x = block(x, **kwargs)
+
+        # head
+        x = self.head(x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2))
+        # unpatchify
+        x = self.unpatchify(x, grid_sizes)
+        return torch.stack(x)
+
+    def _forward_train(
+        self,
+        x,
+        t,
+        context,
+        seq_len,
+        clean_x=None,
+        aug_t=None,
+        clip_fea=None,
+        y=None,
+    ):
+        r"""
+        Forward pass through the diffusion model.
+
+        Args:
+            x (List[Tensor]):
+                List of input video tensors, each with shape [C_in, F, H, W]
+            t (Tensor):
+                Diffusion timesteps tensor of shape [B]
+            context (List[Tensor]):
+                List of text embeddings each with shape [L, C]
+            seq_len (`int`):
+                Maximum sequence length for positional encoding
+            clean_x (List[Tensor], *optional*):
+                Clean context latents for teacher forcing, same shape as x;
+                prepended to x along the token dimension and dropped from the output
+            aug_t (Tensor, *optional*):
+                Timesteps used for the time embedding of clean_x (defaults to zeros)
+            clip_fea (Tensor, *optional*):
+                CLIP image features for image-to-video mode
+            y (List[Tensor], *optional*):
+                Conditional video inputs for image-to-video mode, same shape as x
+
+        Returns:
+            List[Tensor]:
+                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
+        """
+        if self.model_type == 'i2v':
+            assert clip_fea is not None and y is not None
+        # params
+        device = self.patch_embedding.weight.device
+        if self.freqs.device != device:
+            self.freqs = self.freqs.to(device)
+
+        # Construct blockwise causal attn mask
+        if self.block_mask is None:
+            if clean_x is not None:
+                if self.independent_first_frame:
+                    raise NotImplementedError()
+                else:
+                    self.block_mask = self._prepare_teacher_forcing_mask(
+                        device, num_frames=x.shape[2],
+                        frame_seqlen=x.shape[-2] * x.shape[-1] // (self.patch_size[1] * self.patch_size[2]),
+                        num_frame_per_block=self.num_frame_per_block
+                    )
+            else:
+                if self.independent_first_frame:
+                    self.block_mask = self._prepare_blockwise_causal_attn_mask_i2v(
+                        device, num_frames=x.shape[2],
+                        frame_seqlen=x.shape[-2] * x.shape[-1] // (self.patch_size[1] * self.patch_size[2]),
+                        num_frame_per_block=self.num_frame_per_block,
+                        local_attn_size=self.local_attn_size
+                    )
+                else:
+                    self.block_mask = self._prepare_blockwise_causal_attn_mask(
+                        device, num_frames=x.shape[2],
+                        frame_seqlen=x.shape[-2] * x.shape[-1] // (self.patch_size[1] * self.patch_size[2]),
+                        num_frame_per_block=self.num_frame_per_block,
+                        local_attn_size=self.local_attn_size
+                    )
+
+        if y is not None:
+            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+
+        # embeddings
+        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+
+        grid_sizes = torch.stack(
+            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+        x = [u.flatten(2).transpose(1, 2) for u in x]
+
+        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+        assert seq_lens.max() <= seq_len
+        x = torch.cat([
+            torch.cat([u, u.new_zeros(1, seq_lens[0] - u.size(1), u.size(2))],
+                      dim=1) for u in x
+        ])
+
+        # time embeddings
+        # with amp.autocast(dtype=torch.float32):
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x))
+        e0 = self.time_projection(e).unflatten(
+            1, (6, self.dim)).unflatten(dim=0, sizes=t.shape)
+        # assert e.dtype == torch.float32 and e0.dtype == torch.float32
+
+        # context
+        context_lens = None
+        context = self.text_embedding(
+            torch.stack([
+                torch.cat(
+                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+                for u in context
+            ]))
+
+        if clip_fea is not None:
+            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
+            context = torch.concat([context_clip, context], dim=1)
+
+        if clean_x is not None:
+            clean_x = [self.patch_embedding(u.unsqueeze(0)) for u in clean_x]
+            clean_x = [u.flatten(2).transpose(1, 2) for u in clean_x]
+
+            seq_lens_clean = torch.tensor([u.size(1) for u in clean_x], dtype=torch.long)
+            assert seq_lens_clean.max() <= seq_len
+            clean_x = torch.cat([
+                torch.cat([u, u.new_zeros(1, seq_lens_clean[0] - u.size(1), u.size(2))], dim=1) for u in clean_x
+            ])
+
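+            # Teacher-forcing layout: the clean context tokens are prepended to
+            # the noisy tokens along the sequence dimension and get their own
+            # time embeddings (from aug_t, zeros by default); after the
+            # transformer blocks the clean half of the sequence is discarded.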
+            x = torch.cat([clean_x, x], dim=1)
+            if aug_t is None:
+                aug_t = torch.zeros_like(t)
+            e_clean = self.time_embedding(
+                sinusoidal_embedding_1d(self.freq_dim, aug_t.flatten()).type_as(x))
+            e0_clean = self.time_projection(e_clean).unflatten(
+                1, (6, self.dim)).unflatten(dim=0, sizes=t.shape)
+            e0 = torch.cat([e0_clean, e0], dim=1)
+
+        # arguments
+        kwargs = dict(
+            e=e0,
+            seq_lens=seq_lens,
+            grid_sizes=grid_sizes,
+            freqs=self.freqs,
+            context=context,
+            context_lens=context_lens,
+            block_mask=self.block_mask)
+
+        def create_custom_forward(module):
+            def custom_forward(*inputs, **kwargs):
+                return module(*inputs, **kwargs)
+            return custom_forward
+
+        for block in self.blocks:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                x = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    x, **kwargs,
+                    use_reentrant=False,
+                )
+            else:
+                x = block(x, **kwargs)
+
+        if clean_x is not None:
+            x = x[:, x.shape[1] // 2:]
+
+        # head
+        x = self.head(x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2))
+
+        # unpatchify
+        x = self.unpatchify(x, grid_sizes)
+        return torch.stack(x)
+
+    def forward(
+        self,
+        *args,
+        **kwargs
+    ):
+        if kwargs.get('kv_cache', None) is not None:
+            return self._forward_inference(*args, **kwargs)
+        else:
+            return self._forward_train(*args, **kwargs)
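+
+    # Usage sketch (illustrative; the variable names and cache construction are
+    # assumptions, not part of this file). A call without a kv_cache goes
+    # through _forward_train, while passing a kv_cache routes the call to the
+    # KV-cached _forward_inference path:
+    #
+    #   pred = model(noisy_latents, t=timesteps, context=text_embs,
+    #                seq_len=seq_len)  # -> _forward_train
+    #
+    #   pred = model(noisy_block, t=timesteps, context=text_embs,
+    #                seq_len=seq_len, kv_cache=kv_cache,
+    #                crossattn_cache=crossattn_cache,
+    #                current_start=start_token,
+    #                cache_start=cache_start)  # -> _forward_inference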
+
+    def unpatchify(self, x, grid_sizes):
+        r"""
+        Reconstruct video tensors from patch embeddings.
+
+        Args:
+            x (List[Tensor]):
+                List of patchified features, each with shape [L, C_out * prod(patch_size)]
+            grid_sizes (Tensor):
+                Original spatial-temporal grid dimensions before patching,
+                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
+
+        Returns:
+            List[Tensor]:
+                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
+        """
+
+        c = self.out_dim
+        out = []
+        for u, v in zip(x, grid_sizes.tolist()):
+            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
+            u = torch.einsum('fhwpqrc->cfphqwr', u)
+            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
+            out.append(u)
+        return out
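+
+    # Shape walk-through (assuming a patch_size of (1, 2, 2); the actual value
+    # comes from the model config): for a grid_sizes row v = (F, H', W'), each
+    # u of shape [F*H'*W', C_out*1*2*2] is viewed as (F, H', W', 1, 2, 2, C_out),
+    # permuted by the einsum to (C_out, F, 1, H', 2, W', 2), and collapsed to
+    # (C_out, F*1, H'*2, W'*2).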
+
+    def init_weights(self):
+        r"""
+        Initialize model parameters using Xavier initialization.
+        """
+
+        # basic init
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+        # init embeddings
+        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
+        for m in self.text_embedding.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, std=.02)
+        for m in self.time_embedding.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, std=.02)
+
+        # init output layer
+        nn.init.zeros_(self.head.head.weight)