Spaces:
Running
on
Zero
Running
on
Zero
kunhaokhliu
commited on
Commit
·
5d2a97a
1
Parent(s):
21a626f
Add application file
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +8 -0
- LICENSE +81 -0
- README.md +1 -1
- app.py +187 -0
- configs/default_config.yaml +20 -0
- configs/rolling_forcing_dmd.yaml +48 -0
- inference.py +197 -0
- model/__init__.py +14 -0
- model/base.py +230 -0
- model/causvid.py +391 -0
- model/diffusion.py +125 -0
- model/dmd.py +332 -0
- model/gan.py +295 -0
- model/ode_regression.py +138 -0
- model/sid.py +283 -0
- pipeline/__init__.py +13 -0
- pipeline/bidirectional_diffusion_inference.py +110 -0
- pipeline/bidirectional_inference.py +71 -0
- pipeline/causal_diffusion_inference.py +342 -0
- pipeline/rolling_forcing_inference.py +372 -0
- pipeline/rolling_forcing_training.py +464 -0
- prompts/example_prompts.txt +16 -0
- requirements.txt +45 -0
- train.py +45 -0
- trainer/__init__.py +11 -0
- trainer/diffusion.py +265 -0
- trainer/distillation.py +398 -0
- trainer/gan.py +464 -0
- trainer/ode.py +242 -0
- utils/dataset.py +220 -0
- utils/distributed.py +125 -0
- utils/lmdb.py +72 -0
- utils/loss.py +81 -0
- utils/misc.py +39 -0
- utils/scheduler.py +194 -0
- utils/wan_wrapper.py +313 -0
- wan/README.md +2 -0
- wan/__init__.py +3 -0
- wan/configs/__init__.py +42 -0
- wan/configs/shared_config.py +19 -0
- wan/configs/wan_i2v_14B.py +35 -0
- wan/configs/wan_t2v_14B.py +29 -0
- wan/configs/wan_t2v_1_3B.py +29 -0
- wan/distributed/__init__.py +0 -0
- wan/distributed/fsdp.py +33 -0
- wan/distributed/xdit_context_parallel.py +192 -0
- wan/image2video.py +347 -0
- wan/modules/__init__.py +16 -0
- wan/modules/attention.py +185 -0
- wan/modules/causal_model.py +1127 -0
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
*.egg-info
|
| 3 |
+
.cache
|
| 4 |
+
|
| 5 |
+
wan_models
|
| 6 |
+
checkpoints
|
| 7 |
+
videos
|
| 8 |
+
logs
|
LICENSE
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Tencent is pleased to support the community by making RollingForcing available.
|
| 2 |
+
|
| 3 |
+
Copyright (C) 2025 Tencent. All rights reserved.
|
| 4 |
+
|
| 5 |
+
The open-source software and/or models included in this distribution may have been modified by Tencent (“Tencent Modifications”). All Tencent Modifications are Copyright (C) Tencent.
|
| 6 |
+
|
| 7 |
+
RollingForcing is licensed under the License Terms of RollingForcing, except for the third-party components listed below, which remain licensed under their respective original terms. RollingForcing does not impose any additional restrictions beyond those specified in the original licenses of these third-party components. Users are required to comply with all applicable terms and conditions of the original licenses and to ensure that the use of these third-party components conforms to all relevant laws and regulations.
|
| 8 |
+
|
| 9 |
+
For the avoidance of doubt, RollingForcing refers solely to training code, inference code, parameters, and weights made publicly available by Tencent in accordance with the License Terms of RollingForcing.
|
| 10 |
+
|
| 11 |
+
Terms of the License Terms of RollingForcing:
|
| 12 |
+
--------------------------------------------------------------------
|
| 13 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and /or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
- You agree to use RollingForcing only for academic purposes, and refrain from using it for any commercial or production purposes under any circumstances.
|
| 16 |
+
|
| 17 |
+
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
| 18 |
+
|
| 19 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
Dependencies and Licenses:
|
| 24 |
+
|
| 25 |
+
This open-source project, RollingForcing, builds upon the following open-source models and/or software components, each of which remains licensed under its original license. Certain models or software may include modifications made by Tencent (“Tencent Modifications”), which are Copyright (C) Tencent.
|
| 26 |
+
|
| 27 |
+
In case you believe there have been errors in the attribution below, you may submit the concerns to us for review and correction.
|
| 28 |
+
|
| 29 |
+
Open Source Model Licensed under the Apache-2.0:
|
| 30 |
+
--------------------------------------------------------------------
|
| 31 |
+
1. Wan-AI/Wan2.1-T2V-1.3B
|
| 32 |
+
Copyright (c) 2025 Wan Team
|
| 33 |
+
|
| 34 |
+
Terms of the Apache-2.0:
|
| 35 |
+
--------------------------------------------------------------------
|
| 36 |
+
Apache License
|
| 37 |
+
Version 2.0, January 2004
|
| 38 |
+
http://www.apache.org/licenses/
|
| 39 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 40 |
+
|
| 41 |
+
Definitions.
|
| 42 |
+
|
| 43 |
+
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
| 44 |
+
|
| 45 |
+
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
| 46 |
+
|
| 47 |
+
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
| 48 |
+
|
| 49 |
+
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
| 50 |
+
|
| 51 |
+
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
| 52 |
+
|
| 53 |
+
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
| 54 |
+
|
| 55 |
+
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
| 56 |
+
|
| 57 |
+
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
| 58 |
+
|
| 59 |
+
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
| 60 |
+
|
| 61 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
| 62 |
+
|
| 63 |
+
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
| 64 |
+
|
| 65 |
+
Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
| 66 |
+
|
| 67 |
+
Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
| 68 |
+
|
| 69 |
+
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
| 70 |
+
|
| 71 |
+
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
| 72 |
+
|
| 73 |
+
Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
| 74 |
+
|
| 75 |
+
Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
| 76 |
+
|
| 77 |
+
Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
| 78 |
+
|
| 79 |
+
Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
| 80 |
+
|
| 81 |
+
END OF TERMS AND CONDITIONS
|
README.md
CHANGED
|
@@ -8,7 +8,7 @@ sdk_version: 5.49.1
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: other
|
| 11 |
-
short_description: 'Rolling Forcing: Autoregressive Long Video Diffusion in Real'
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: other
|
| 11 |
+
short_description: 'Rolling Forcing: Autoregressive Long Video Diffusion in Real Time'
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import time
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torchvision.io import write_video
|
| 8 |
+
from omegaconf import OmegaConf
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import app as gr
|
| 11 |
+
|
| 12 |
+
from pipeline import CausalInferencePipeline
|
| 13 |
+
from huggingface_hub import snapshot_download, hf_hub_download
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# -----------------------------
|
| 17 |
+
# Globals (loaded once per process)
|
| 18 |
+
# -----------------------------
|
| 19 |
+
_PIPELINE: Optional[torch.nn.Module] = None
|
| 20 |
+
_DEVICE: Optional[torch.device] = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _ensure_gpu():
|
| 24 |
+
if not torch.cuda.is_available():
|
| 25 |
+
raise gr.Error("CUDA GPU is required to run this demo. Please run on a machine with an NVIDIA GPU.")
|
| 26 |
+
# Bind to GPU:0 by default
|
| 27 |
+
torch.cuda.set_device(0)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _load_pipeline(config_path: str, checkpoint_path: Optional[str], use_ema: bool) -> torch.nn.Module:
|
| 31 |
+
global _PIPELINE, _DEVICE
|
| 32 |
+
if _PIPELINE is not None:
|
| 33 |
+
return _PIPELINE
|
| 34 |
+
|
| 35 |
+
_ensure_gpu()
|
| 36 |
+
_DEVICE = torch.device("cuda:0")
|
| 37 |
+
|
| 38 |
+
# Load and merge configs
|
| 39 |
+
config = OmegaConf.load(config_path)
|
| 40 |
+
default_config = OmegaConf.load("configs/default_config.yaml")
|
| 41 |
+
config = OmegaConf.merge(default_config, config)
|
| 42 |
+
|
| 43 |
+
# Choose pipeline type based on config
|
| 44 |
+
pipeline = CausalInferencePipeline(config, device=_DEVICE)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Load checkpoint if provided
|
| 48 |
+
if checkpoint_path and os.path.exists(checkpoint_path):
|
| 49 |
+
state_dict = torch.load(checkpoint_path, map_location="cpu")
|
| 50 |
+
if use_ema and 'generator_ema' in state_dict:
|
| 51 |
+
state_dict_to_load = state_dict['generator_ema']
|
| 52 |
+
# Remove possible FSDP prefix
|
| 53 |
+
from collections import OrderedDict
|
| 54 |
+
new_state_dict = OrderedDict()
|
| 55 |
+
for k, v in state_dict_to_load.items():
|
| 56 |
+
new_state_dict[k.replace("_fsdp_wrapped_module.", "")] = v
|
| 57 |
+
state_dict_to_load = new_state_dict
|
| 58 |
+
else:
|
| 59 |
+
state_dict_to_load = state_dict.get('generator', state_dict)
|
| 60 |
+
pipeline.generator.load_state_dict(state_dict_to_load, strict=False)
|
| 61 |
+
|
| 62 |
+
# The codebase assumes bfloat16 on GPU
|
| 63 |
+
pipeline = pipeline.to(device=_DEVICE, dtype=torch.bfloat16)
|
| 64 |
+
pipeline.eval()
|
| 65 |
+
|
| 66 |
+
# Quick sanity path check for Wan models to give friendly errors
|
| 67 |
+
wan_dir = os.path.join('wan_models', 'Wan2.1-T2V-1.3B')
|
| 68 |
+
if not os.path.isdir(wan_dir):
|
| 69 |
+
raise gr.Error(
|
| 70 |
+
"Wan2.1-T2V-1.3B not found at 'wan_models/Wan2.1-T2V-1.3B'.\n"
|
| 71 |
+
"Please download it first, e.g.:\n"
|
| 72 |
+
"huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir wan_models/Wan2.1-T2V-1.3B"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
_PIPELINE = pipeline
|
| 76 |
+
return _PIPELINE
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def build_predict(config_path: str, checkpoint_path: Optional[str], output_dir: str, use_ema: bool):
|
| 80 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 81 |
+
|
| 82 |
+
def predict(prompt: str, num_frames: int) -> str:
|
| 83 |
+
if not prompt or not prompt.strip():
|
| 84 |
+
raise gr.Error("Please enter a non-empty text prompt.")
|
| 85 |
+
|
| 86 |
+
num_frames = int(num_frames)
|
| 87 |
+
if num_frames % 3 != 0 or not (21 <= num_frames <= 252):
|
| 88 |
+
raise gr.Error("Number of frames must be a multiple of 3 between 21 and 252.")
|
| 89 |
+
|
| 90 |
+
pipeline = _load_pipeline(config_path, checkpoint_path, use_ema)
|
| 91 |
+
|
| 92 |
+
# Prepare inputs
|
| 93 |
+
prompts = [prompt.strip()]
|
| 94 |
+
noise = torch.randn([1, num_frames, 16, 60, 104], device=_DEVICE, dtype=torch.bfloat16)
|
| 95 |
+
|
| 96 |
+
torch.set_grad_enabled(False)
|
| 97 |
+
with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 98 |
+
video = pipeline.inference_rolling_forcing(
|
| 99 |
+
noise=noise,
|
| 100 |
+
text_prompts=prompts,
|
| 101 |
+
return_latents=False,
|
| 102 |
+
initial_latent=None,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# video: [B=1, T, C, H, W] in [0,1]
|
| 106 |
+
video = rearrange(video, 'b t c h w -> b t h w c')[0]
|
| 107 |
+
video_uint8 = (video * 255.0).clamp(0, 255).to(torch.uint8).cpu()
|
| 108 |
+
|
| 109 |
+
# Save to a unique filepath
|
| 110 |
+
safe_stub = prompt[:60].replace(' ', '_').replace('/', '_')
|
| 111 |
+
ts = int(time.time())
|
| 112 |
+
filepath = os.path.join(output_dir, f"{safe_stub or 'video'}_{ts}.mp4")
|
| 113 |
+
write_video(filepath, video_uint8, fps=16)
|
| 114 |
+
print(f"Saved generated video to {filepath}")
|
| 115 |
+
|
| 116 |
+
return filepath
|
| 117 |
+
|
| 118 |
+
return predict
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def main():
|
| 122 |
+
parser = argparse.ArgumentParser()
|
| 123 |
+
parser.add_argument('--config_path', type=str, default='configs/rolling_forcing_dmd.yaml',
|
| 124 |
+
help='Path to the model config')
|
| 125 |
+
parser.add_argument('--checkpoint_path', type=str, default='checkpoints/rolling_forcing_dmd.pt',
|
| 126 |
+
help='Path to rolling forcing checkpoint (.pt). If missing, will run with base weights only if available.')
|
| 127 |
+
parser.add_argument('--output_dir', type=str, default='videos/gradio', help='Where to save generated videos')
|
| 128 |
+
parser.add_argument('--no_ema', action='store_true', help='Disable EMA weights when loading checkpoint')
|
| 129 |
+
args = parser.parse_args()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# Download checkpoint from HuggingFace if not present
|
| 133 |
+
# 1️⃣ Equivalent to:
|
| 134 |
+
# huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
|
| 135 |
+
wan_model_dir = snapshot_download(
|
| 136 |
+
repo_id="Wan-AI/Wan2.1-T2V-1.3B",
|
| 137 |
+
local_dir="wan_models/Wan2.1-T2V-1.3B",
|
| 138 |
+
local_dir_use_symlinks=False, # same as --local-dir-use-symlinks False
|
| 139 |
+
)
|
| 140 |
+
print("Wan model downloaded to:", wan_model_dir)
|
| 141 |
+
|
| 142 |
+
# 2️⃣ Equivalent to:
|
| 143 |
+
# huggingface-cli download TencentARC/RollingForcing checkpoints/rolling_forcing_dmd.pt --local-dir .
|
| 144 |
+
rolling_ckpt_path = hf_hub_download(
|
| 145 |
+
repo_id="TencentARC/RollingForcing",
|
| 146 |
+
filename="checkpoints/rolling_forcing_dmd.pt",
|
| 147 |
+
local_dir=".", # where to store it
|
| 148 |
+
local_dir_use_symlinks=False,
|
| 149 |
+
)
|
| 150 |
+
print("RollingForcing checkpoint downloaded to:", rolling_ckpt_path)
|
| 151 |
+
|
| 152 |
+
predict = build_predict(
|
| 153 |
+
config_path=args.config_path,
|
| 154 |
+
checkpoint_path=args.checkpoint_path,
|
| 155 |
+
output_dir=args.output_dir,
|
| 156 |
+
use_ema=not args.no_ema,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
demo = gr.Interface(
|
| 160 |
+
fn=predict,
|
| 161 |
+
inputs=[
|
| 162 |
+
gr.Textbox(label="Text Prompt", lines=2, placeholder="A cinematic shot of a girl dancing in the sunset."),
|
| 163 |
+
gr.Slider(label="Number of Latent Frames", minimum=21, maximum=252, step=3, value=21),
|
| 164 |
+
],
|
| 165 |
+
outputs=gr.Video(label="Generated Video", format="mp4"),
|
| 166 |
+
title="Rolling Forcing: Autoregressive Long Video Diffusion in Real Time",
|
| 167 |
+
description=(
|
| 168 |
+
"Enter a prompt and generate a video using the Rolling Forcing pipeline.\n"
|
| 169 |
+
"**Note:** although Rolling Forcing generates videos autoregressivelty, current Gradio demo does not support streaming outputs, so the entire video will be generated before it is displayed.\n"
|
| 170 |
+
"\n"
|
| 171 |
+
"If you find this demo useful, please consider giving it a ⭐ star on [GitHub](https://github.com/TencentARC/RollingForcing)--your support is crucial for sustaining this open-source project. "
|
| 172 |
+
"You can also dive deeper by reading the [paper](https://arxiv.org/abs/2509.25161) or exploring the [project page](https://kunhao-liu.github.io/Rolling_Forcing_Webpage) for more details."
|
| 173 |
+
),
|
| 174 |
+
allow_flagging='never',
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
try:
|
| 178 |
+
# Gradio <= 3.x
|
| 179 |
+
demo.queue(concurrency_count=1, max_size=2)
|
| 180 |
+
except TypeError:
|
| 181 |
+
# Gradio >= 4.x
|
| 182 |
+
demo.queue(max_size=2)
|
| 183 |
+
demo.launch(show_error=True)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
if __name__ == "__main__":
|
| 187 |
+
main()
|
configs/default_config.yaml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
independent_first_frame: false
|
| 2 |
+
warp_denoising_step: false
|
| 3 |
+
weight_decay: 0.01
|
| 4 |
+
same_step_across_blocks: true
|
| 5 |
+
discriminator_lr_multiplier: 1.0
|
| 6 |
+
last_step_only: false
|
| 7 |
+
i2v: false
|
| 8 |
+
num_training_frames: 27
|
| 9 |
+
gc_interval: 100
|
| 10 |
+
context_noise: 0
|
| 11 |
+
causal: true
|
| 12 |
+
|
| 13 |
+
ckpt_step: 0
|
| 14 |
+
prompt_name: MovieGenVideoBench
|
| 15 |
+
prompt_path: prompts/MovieGenVideoBench.txt
|
| 16 |
+
eval_first_n: 64
|
| 17 |
+
num_samples: 1
|
| 18 |
+
height: 480
|
| 19 |
+
width: 832
|
| 20 |
+
num_frames: 81
|
configs/rolling_forcing_dmd.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
generator_ckpt: checkpoints/ode_init.pt
|
| 2 |
+
generator_fsdp_wrap_strategy: size
|
| 3 |
+
real_score_fsdp_wrap_strategy: size
|
| 4 |
+
fake_score_fsdp_wrap_strategy: size
|
| 5 |
+
real_name: Wan2.1-T2V-14B
|
| 6 |
+
text_encoder_fsdp_wrap_strategy: size
|
| 7 |
+
denoising_step_list:
|
| 8 |
+
- 1000
|
| 9 |
+
- 800
|
| 10 |
+
- 600
|
| 11 |
+
- 400
|
| 12 |
+
- 200
|
| 13 |
+
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
|
| 14 |
+
ts_schedule: false
|
| 15 |
+
num_train_timestep: 1000
|
| 16 |
+
timestep_shift: 5.0
|
| 17 |
+
guidance_scale: 3.0
|
| 18 |
+
denoising_loss_type: flow
|
| 19 |
+
mixed_precision: true
|
| 20 |
+
seed: 0
|
| 21 |
+
sharding_strategy: hybrid_full
|
| 22 |
+
lr: 1.5e-06
|
| 23 |
+
lr_critic: 4.0e-07
|
| 24 |
+
beta1: 0.0
|
| 25 |
+
beta2: 0.999
|
| 26 |
+
beta1_critic: 0.0
|
| 27 |
+
beta2_critic: 0.999
|
| 28 |
+
data_path: prompts/vidprom_filtered_extended.txt
|
| 29 |
+
batch_size: 1
|
| 30 |
+
ema_weight: 0.99
|
| 31 |
+
ema_start_step: 200
|
| 32 |
+
total_batch_size: 64
|
| 33 |
+
log_iters: 100
|
| 34 |
+
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
|
| 35 |
+
dfake_gen_update_ratio: 5
|
| 36 |
+
image_or_video_shape:
|
| 37 |
+
- 1
|
| 38 |
+
- 21
|
| 39 |
+
- 16
|
| 40 |
+
- 60
|
| 41 |
+
- 104
|
| 42 |
+
distribution_loss: dmd
|
| 43 |
+
trainer: score_distillation
|
| 44 |
+
gradient_checkpointing: true
|
| 45 |
+
num_frame_per_block: 3
|
| 46 |
+
load_raw_video: false
|
| 47 |
+
model_kwargs:
|
| 48 |
+
timestep_shift: 5.0
|
inference.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
from omegaconf import OmegaConf
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
from torchvision.io import write_video
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
import imageio
|
| 12 |
+
from torch.utils.data import DataLoader, SequentialSampler
|
| 13 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 14 |
+
|
| 15 |
+
from pipeline import (
|
| 16 |
+
CausalDiffusionInferencePipeline,
|
| 17 |
+
CausalInferencePipeline
|
| 18 |
+
)
|
| 19 |
+
from utils.dataset import TextDataset, TextImagePairDataset
|
| 20 |
+
from utils.misc import set_seed
|
| 21 |
+
|
| 22 |
+
parser = argparse.ArgumentParser()
|
| 23 |
+
parser.add_argument("--config_path", type=str, help="Path to the config file")
|
| 24 |
+
parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint folder")
|
| 25 |
+
parser.add_argument("--data_path", type=str, help="Path to the dataset")
|
| 26 |
+
parser.add_argument("--extended_prompt_path", type=str, help="Path to the extended prompt")
|
| 27 |
+
parser.add_argument("--output_folder", type=str, help="Output folder")
|
| 28 |
+
parser.add_argument("--num_output_frames", type=int, default=21,
|
| 29 |
+
help="Number of overlap frames between sliding windows")
|
| 30 |
+
parser.add_argument("--i2v", action="store_true", help="Whether to perform I2V (or T2V by default)")
|
| 31 |
+
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA parameters")
|
| 32 |
+
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
| 33 |
+
parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate per prompt")
|
| 34 |
+
parser.add_argument("--save_with_index", action="store_true",
|
| 35 |
+
help="Whether to save the video using the index or prompt as the filename")
|
| 36 |
+
args = parser.parse_args()
|
| 37 |
+
|
| 38 |
+
# Initialize distributed inference
|
| 39 |
+
if "LOCAL_RANK" in os.environ:
|
| 40 |
+
dist.init_process_group(backend='nccl')
|
| 41 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
| 42 |
+
torch.cuda.set_device(local_rank)
|
| 43 |
+
device = torch.device(f"cuda:{local_rank}")
|
| 44 |
+
world_size = dist.get_world_size()
|
| 45 |
+
set_seed(args.seed + local_rank)
|
| 46 |
+
else:
|
| 47 |
+
device = torch.device("cuda")
|
| 48 |
+
local_rank = 0
|
| 49 |
+
world_size = 1
|
| 50 |
+
set_seed(args.seed)
|
| 51 |
+
|
| 52 |
+
torch.set_grad_enabled(False)
|
| 53 |
+
|
| 54 |
+
config = OmegaConf.load(args.config_path)
|
| 55 |
+
default_config = OmegaConf.load("configs/default_config.yaml")
|
| 56 |
+
config = OmegaConf.merge(default_config, config)
|
| 57 |
+
|
| 58 |
+
# Initialize pipeline
|
| 59 |
+
if hasattr(config, 'denoising_step_list'):
|
| 60 |
+
# Few-step inference
|
| 61 |
+
pipeline = CausalInferencePipeline(config, device=device)
|
| 62 |
+
else:
|
| 63 |
+
# Multi-step diffusion inference
|
| 64 |
+
pipeline = CausalDiffusionInferencePipeline(config, device=device)
|
| 65 |
+
|
| 66 |
+
if args.checkpoint_path:
|
| 67 |
+
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
|
| 68 |
+
if args.use_ema:
|
| 69 |
+
state_dict_to_load = state_dict['generator_ema']
|
| 70 |
+
def remove_fsdp_prefix(state_dict):
|
| 71 |
+
new_state_dict = OrderedDict()
|
| 72 |
+
for key, value in state_dict.items():
|
| 73 |
+
if "_fsdp_wrapped_module." in key:
|
| 74 |
+
new_key = key.replace("_fsdp_wrapped_module.", "")
|
| 75 |
+
new_state_dict[new_key] = value
|
| 76 |
+
else:
|
| 77 |
+
new_state_dict[key] = value
|
| 78 |
+
return new_state_dict
|
| 79 |
+
state_dict_to_load = remove_fsdp_prefix(state_dict_to_load)
|
| 80 |
+
else:
|
| 81 |
+
state_dict_to_load = state_dict['generator']
|
| 82 |
+
pipeline.generator.load_state_dict(state_dict_to_load)
|
| 83 |
+
|
| 84 |
+
pipeline = pipeline.to(device=device, dtype=torch.bfloat16)
|
| 85 |
+
|
| 86 |
+
# Create dataset
|
| 87 |
+
if args.i2v:
|
| 88 |
+
assert not dist.is_initialized(), "I2V does not support distributed inference yet"
|
| 89 |
+
transform = transforms.Compose([
|
| 90 |
+
transforms.Resize((480, 832)),
|
| 91 |
+
transforms.ToTensor(),
|
| 92 |
+
transforms.Normalize([0.5], [0.5])
|
| 93 |
+
])
|
| 94 |
+
dataset = TextImagePairDataset(args.data_path, transform=transform)
|
| 95 |
+
else:
|
| 96 |
+
dataset = TextDataset(prompt_path=args.data_path, extended_prompt_path=args.extended_prompt_path)
|
| 97 |
+
num_prompts = len(dataset)
|
| 98 |
+
print(f"Number of prompts: {num_prompts}")
|
| 99 |
+
|
| 100 |
+
if dist.is_initialized():
|
| 101 |
+
sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
|
| 102 |
+
else:
|
| 103 |
+
sampler = SequentialSampler(dataset)
|
| 104 |
+
dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)
|
| 105 |
+
|
| 106 |
+
# Create output directory (only on main process to avoid race conditions)
|
| 107 |
+
if local_rank == 0:
|
| 108 |
+
os.makedirs(args.output_folder, exist_ok=True)
|
| 109 |
+
|
| 110 |
+
if dist.is_initialized():
|
| 111 |
+
dist.barrier()
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def encode(self, videos: torch.Tensor) -> torch.Tensor:
|
| 115 |
+
device, dtype = videos[0].device, videos[0].dtype
|
| 116 |
+
scale = [self.mean.to(device=device, dtype=dtype),
|
| 117 |
+
1.0 / self.std.to(device=device, dtype=dtype)]
|
| 118 |
+
output = [
|
| 119 |
+
self.model.encode(u.unsqueeze(0), scale).float().squeeze(0)
|
| 120 |
+
for u in videos
|
| 121 |
+
]
|
| 122 |
+
|
| 123 |
+
output = torch.stack(output, dim=0)
|
| 124 |
+
return output
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)):
|
| 128 |
+
idx = batch_data['idx'].item()
|
| 129 |
+
|
| 130 |
+
# For DataLoader batch_size=1, the batch_data is already a single item, but in a batch container
|
| 131 |
+
# Unpack the batch data for convenience
|
| 132 |
+
if isinstance(batch_data, dict):
|
| 133 |
+
batch = batch_data
|
| 134 |
+
elif isinstance(batch_data, list):
|
| 135 |
+
batch = batch_data[0] # First (and only) item in the batch
|
| 136 |
+
|
| 137 |
+
all_video = []
|
| 138 |
+
num_generated_frames = 0 # Number of generated (latent) frames
|
| 139 |
+
|
| 140 |
+
if args.i2v:
|
| 141 |
+
# For image-to-video, batch contains image and caption
|
| 142 |
+
prompt = batch['prompts'][0] # Get caption from batch
|
| 143 |
+
prompts = [prompt] * args.num_samples
|
| 144 |
+
|
| 145 |
+
# Process the image
|
| 146 |
+
image = batch['image'].squeeze(0).unsqueeze(0).unsqueeze(2).to(device=device, dtype=torch.bfloat16)
|
| 147 |
+
|
| 148 |
+
# Encode the input image as the first latent
|
| 149 |
+
initial_latent = pipeline.vae.encode_to_latent(image).to(device=device, dtype=torch.bfloat16)
|
| 150 |
+
initial_latent = initial_latent.repeat(args.num_samples, 1, 1, 1, 1)
|
| 151 |
+
|
| 152 |
+
sampled_noise = torch.randn(
|
| 153 |
+
[args.num_samples, args.num_output_frames - 1, 16, 60, 104], device=device, dtype=torch.bfloat16
|
| 154 |
+
)
|
| 155 |
+
else:
|
| 156 |
+
# For text-to-video, batch is just the text prompt
|
| 157 |
+
prompt = batch['prompts'][0]
|
| 158 |
+
extended_prompt = batch['extended_prompts'][0] if 'extended_prompts' in batch else None
|
| 159 |
+
if extended_prompt is not None:
|
| 160 |
+
prompts = [extended_prompt] * args.num_samples
|
| 161 |
+
else:
|
| 162 |
+
prompts = [prompt] * args.num_samples
|
| 163 |
+
initial_latent = None
|
| 164 |
+
|
| 165 |
+
sampled_noise = torch.randn(
|
| 166 |
+
[args.num_samples, args.num_output_frames, 16, 60, 104], device=device, dtype=torch.bfloat16
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# Generate 81 frames
|
| 170 |
+
video, latents = pipeline.inference_rolling_forcing(
|
| 171 |
+
noise=sampled_noise,
|
| 172 |
+
text_prompts=prompts,
|
| 173 |
+
return_latents=True,
|
| 174 |
+
initial_latent=initial_latent,
|
| 175 |
+
)
|
| 176 |
+
current_video = rearrange(video, 'b t c h w -> b t h w c').cpu()
|
| 177 |
+
all_video.append(current_video)
|
| 178 |
+
num_generated_frames += latents.shape[1]
|
| 179 |
+
|
| 180 |
+
# Final output video
|
| 181 |
+
video = 255.0 * torch.cat(all_video, dim=1)
|
| 182 |
+
|
| 183 |
+
# Clear VAE cache
|
| 184 |
+
pipeline.vae.model.clear_cache()
|
| 185 |
+
|
| 186 |
+
# Save the video if the current prompt is not a dummy prompt
|
| 187 |
+
if idx < num_prompts:
|
| 188 |
+
model = "regular" if not args.use_ema else "ema"
|
| 189 |
+
for seed_idx in range(args.num_samples):
|
| 190 |
+
# All processes save their videos
|
| 191 |
+
if args.save_with_index:
|
| 192 |
+
output_path = os.path.join(args.output_folder, f'{idx}-{seed_idx}_{model}.mp4')
|
| 193 |
+
else:
|
| 194 |
+
output_path = os.path.join(args.output_folder, f'{prompt[:100]}-{seed_idx}.mp4')
|
| 195 |
+
write_video(output_path, video[seed_idx], fps=16)
|
| 196 |
+
# imageio.mimwrite(output_path, video[seed_idx], fps=16, quality=8, output_params=["-loglevel", "error"])
|
| 197 |
+
|
model/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .diffusion import CausalDiffusion
|
| 2 |
+
from .causvid import CausVid
|
| 3 |
+
from .dmd import DMD
|
| 4 |
+
from .gan import GAN
|
| 5 |
+
from .sid import SiD
|
| 6 |
+
from .ode_regression import ODERegression
|
| 7 |
+
__all__ = [
|
| 8 |
+
"CausalDiffusion",
|
| 9 |
+
"CausVid",
|
| 10 |
+
"DMD",
|
| 11 |
+
"GAN",
|
| 12 |
+
"SiD",
|
| 13 |
+
"ODERegression"
|
| 14 |
+
]
|
model/base.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
from torch import nn
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from pipeline import RollingForcingTrainingPipeline
|
| 8 |
+
from utils.loss import get_denoising_loss
|
| 9 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BaseModel(nn.Module):
|
| 13 |
+
def __init__(self, args, device):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self._initialize_models(args, device)
|
| 16 |
+
|
| 17 |
+
self.device = device
|
| 18 |
+
self.args = args
|
| 19 |
+
self.dtype = torch.bfloat16 if args.mixed_precision else torch.float32
|
| 20 |
+
if hasattr(args, "denoising_step_list"):
|
| 21 |
+
self.denoising_step_list = torch.tensor(args.denoising_step_list, dtype=torch.long)
|
| 22 |
+
if args.warp_denoising_step:
|
| 23 |
+
timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
|
| 24 |
+
self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
|
| 25 |
+
|
| 26 |
+
def _initialize_models(self, args, device):
|
| 27 |
+
self.real_model_name = getattr(args, "real_name", "Wan2.1-T2V-1.3B")
|
| 28 |
+
self.fake_model_name = getattr(args, "fake_name", "Wan2.1-T2V-1.3B")
|
| 29 |
+
|
| 30 |
+
self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
|
| 31 |
+
self.generator.model.requires_grad_(True)
|
| 32 |
+
|
| 33 |
+
self.real_score = WanDiffusionWrapper(model_name=self.real_model_name, is_causal=False)
|
| 34 |
+
self.real_score.model.requires_grad_(False)
|
| 35 |
+
|
| 36 |
+
self.fake_score = WanDiffusionWrapper(model_name=self.fake_model_name, is_causal=False)
|
| 37 |
+
self.fake_score.model.requires_grad_(True)
|
| 38 |
+
|
| 39 |
+
self.text_encoder = WanTextEncoder()
|
| 40 |
+
self.text_encoder.requires_grad_(False)
|
| 41 |
+
|
| 42 |
+
self.vae = WanVAEWrapper()
|
| 43 |
+
self.vae.requires_grad_(False)
|
| 44 |
+
|
| 45 |
+
self.scheduler = self.generator.get_scheduler()
|
| 46 |
+
self.scheduler.timesteps = self.scheduler.timesteps.to(device)
|
| 47 |
+
|
| 48 |
+
def _get_timestep(
|
| 49 |
+
self,
|
| 50 |
+
min_timestep: int,
|
| 51 |
+
max_timestep: int,
|
| 52 |
+
batch_size: int,
|
| 53 |
+
num_frame: int,
|
| 54 |
+
num_frame_per_block: int,
|
| 55 |
+
uniform_timestep: bool = False
|
| 56 |
+
) -> torch.Tensor:
|
| 57 |
+
"""
|
| 58 |
+
Randomly generate a timestep tensor based on the generator's task type. It uniformly samples a timestep
|
| 59 |
+
from the range [min_timestep, max_timestep], and returns a tensor of shape [batch_size, num_frame].
|
| 60 |
+
- If uniform_timestep, it will use the same timestep for all frames.
|
| 61 |
+
- If not uniform_timestep, it will use a different timestep for each block.
|
| 62 |
+
"""
|
| 63 |
+
if uniform_timestep:
|
| 64 |
+
timestep = torch.randint(
|
| 65 |
+
min_timestep,
|
| 66 |
+
max_timestep,
|
| 67 |
+
[batch_size, 1],
|
| 68 |
+
device=self.device,
|
| 69 |
+
dtype=torch.long
|
| 70 |
+
).repeat(1, num_frame)
|
| 71 |
+
return timestep
|
| 72 |
+
else:
|
| 73 |
+
timestep = torch.randint(
|
| 74 |
+
min_timestep,
|
| 75 |
+
max_timestep,
|
| 76 |
+
[batch_size, num_frame],
|
| 77 |
+
device=self.device,
|
| 78 |
+
dtype=torch.long
|
| 79 |
+
)
|
| 80 |
+
# make the noise level the same within every block
|
| 81 |
+
if self.independent_first_frame:
|
| 82 |
+
# the first frame is always kept the same
|
| 83 |
+
timestep_from_second = timestep[:, 1:]
|
| 84 |
+
timestep_from_second = timestep_from_second.reshape(
|
| 85 |
+
timestep_from_second.shape[0], -1, num_frame_per_block)
|
| 86 |
+
timestep_from_second[:, :, 1:] = timestep_from_second[:, :, 0:1]
|
| 87 |
+
timestep_from_second = timestep_from_second.reshape(
|
| 88 |
+
timestep_from_second.shape[0], -1)
|
| 89 |
+
timestep = torch.cat([timestep[:, 0:1], timestep_from_second], dim=1)
|
| 90 |
+
else:
|
| 91 |
+
timestep = timestep.reshape(
|
| 92 |
+
timestep.shape[0], -1, num_frame_per_block)
|
| 93 |
+
timestep[:, :, 1:] = timestep[:, :, 0:1]
|
| 94 |
+
timestep = timestep.reshape(timestep.shape[0], -1)
|
| 95 |
+
return timestep
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class RollingForcingModel(BaseModel):
|
| 99 |
+
def __init__(self, args, device):
|
| 100 |
+
super().__init__(args, device)
|
| 101 |
+
self.denoising_loss_func = get_denoising_loss(args.denoising_loss_type)()
|
| 102 |
+
|
| 103 |
+
def _run_generator(
|
| 104 |
+
self,
|
| 105 |
+
image_or_video_shape,
|
| 106 |
+
conditional_dict: dict,
|
| 107 |
+
initial_latent: torch.tensor = None
|
| 108 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 109 |
+
"""
|
| 110 |
+
Optionally simulate the generator's input from noise using backward simulation
|
| 111 |
+
and then run the generator for one-step.
|
| 112 |
+
Input:
|
| 113 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
| 114 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 115 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 116 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
| 117 |
+
- initial_latent: a tensor containing the initial latents [B, F, C, H, W].
|
| 118 |
+
Output:
|
| 119 |
+
- pred_image: a tensor with shape [B, F, C, H, W].
|
| 120 |
+
- denoised_timestep: an integer
|
| 121 |
+
"""
|
| 122 |
+
# Step 1: Sample noise and backward simulate the generator's input
|
| 123 |
+
assert getattr(self.args, "backward_simulation", True), "Backward simulation needs to be enabled"
|
| 124 |
+
if initial_latent is not None:
|
| 125 |
+
conditional_dict["initial_latent"] = initial_latent
|
| 126 |
+
if self.args.i2v:
|
| 127 |
+
noise_shape = [image_or_video_shape[0], image_or_video_shape[1] - 1, *image_or_video_shape[2:]]
|
| 128 |
+
else:
|
| 129 |
+
noise_shape = image_or_video_shape.copy()
|
| 130 |
+
|
| 131 |
+
# During training, the number of generated frames should be uniformly sampled from
|
| 132 |
+
# [21, self.num_training_frames], but still being a multiple of self.num_frame_per_block
|
| 133 |
+
min_num_frames = 20 if self.args.independent_first_frame else 21
|
| 134 |
+
max_num_frames = self.num_training_frames - 1 if self.args.independent_first_frame else self.num_training_frames
|
| 135 |
+
assert max_num_frames % self.num_frame_per_block == 0
|
| 136 |
+
assert min_num_frames % self.num_frame_per_block == 0
|
| 137 |
+
max_num_blocks = max_num_frames // self.num_frame_per_block
|
| 138 |
+
min_num_blocks = min_num_frames // self.num_frame_per_block
|
| 139 |
+
num_generated_blocks = torch.randint(min_num_blocks, max_num_blocks + 1, (1,), device=self.device)
|
| 140 |
+
dist.broadcast(num_generated_blocks, src=0)
|
| 141 |
+
num_generated_blocks = num_generated_blocks.item()
|
| 142 |
+
num_generated_frames = num_generated_blocks * self.num_frame_per_block
|
| 143 |
+
if self.args.independent_first_frame and initial_latent is None:
|
| 144 |
+
num_generated_frames += 1
|
| 145 |
+
min_num_frames += 1
|
| 146 |
+
# Sync num_generated_frames across all processes
|
| 147 |
+
noise_shape[1] = num_generated_frames
|
| 148 |
+
|
| 149 |
+
pred_image_or_video, denoised_timestep_from, denoised_timestep_to = self._consistency_backward_simulation(
|
| 150 |
+
noise=torch.randn(noise_shape,
|
| 151 |
+
device=self.device, dtype=self.dtype),
|
| 152 |
+
**conditional_dict,
|
| 153 |
+
)
|
| 154 |
+
# Slice last 21 frames
|
| 155 |
+
if pred_image_or_video.shape[1] > 21:
|
| 156 |
+
with torch.no_grad():
|
| 157 |
+
# Reencode to get image latent
|
| 158 |
+
latent_to_decode = pred_image_or_video[:, :-20, ...]
|
| 159 |
+
# Deccode to video
|
| 160 |
+
pixels = self.vae.decode_to_pixel(latent_to_decode)
|
| 161 |
+
frame = pixels[:, -1:, ...].to(self.dtype)
|
| 162 |
+
frame = rearrange(frame, "b t c h w -> b c t h w")
|
| 163 |
+
# Encode frame to get image latent
|
| 164 |
+
image_latent = self.vae.encode_to_latent(frame).to(self.dtype)
|
| 165 |
+
pred_image_or_video_last_21 = torch.cat([image_latent, pred_image_or_video[:, -20:, ...]], dim=1)
|
| 166 |
+
else:
|
| 167 |
+
pred_image_or_video_last_21 = pred_image_or_video
|
| 168 |
+
|
| 169 |
+
if num_generated_frames != min_num_frames:
|
| 170 |
+
# Currently, we do not use gradient for the first chunk, since it contains image latents
|
| 171 |
+
gradient_mask = torch.ones_like(pred_image_or_video_last_21, dtype=torch.bool)
|
| 172 |
+
if self.args.independent_first_frame:
|
| 173 |
+
gradient_mask[:, :1] = False
|
| 174 |
+
else:
|
| 175 |
+
gradient_mask[:, :self.num_frame_per_block] = False
|
| 176 |
+
else:
|
| 177 |
+
gradient_mask = None
|
| 178 |
+
|
| 179 |
+
pred_image_or_video_last_21 = pred_image_or_video_last_21.to(self.dtype)
|
| 180 |
+
return pred_image_or_video_last_21, gradient_mask, denoised_timestep_from, denoised_timestep_to
|
| 181 |
+
|
| 182 |
+
def _consistency_backward_simulation(
|
| 183 |
+
self,
|
| 184 |
+
noise: torch.Tensor,
|
| 185 |
+
**conditional_dict: dict
|
| 186 |
+
) -> torch.Tensor:
|
| 187 |
+
"""
|
| 188 |
+
Simulate the generator's input from noise to avoid training/inference mismatch.
|
| 189 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
| 190 |
+
Here we use the consistency sampler (https://arxiv.org/abs/2303.01469)
|
| 191 |
+
Input:
|
| 192 |
+
- noise: a tensor sampled from N(0, 1) with shape [B, F, C, H, W] where the number of frame is 1 for images.
|
| 193 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 194 |
+
Output:
|
| 195 |
+
- output: a tensor with shape [B, T, F, C, H, W].
|
| 196 |
+
T is the total number of timesteps. output[0] is a pure noise and output[i] and i>0
|
| 197 |
+
represents the x0 prediction at each timestep.
|
| 198 |
+
"""
|
| 199 |
+
if self.inference_pipeline is None:
|
| 200 |
+
self._initialize_inference_pipeline()
|
| 201 |
+
|
| 202 |
+
infer_w_rolling = torch.rand(1, device=self.device) > 0.5
|
| 203 |
+
dist.broadcast(infer_w_rolling, src=0)
|
| 204 |
+
|
| 205 |
+
if infer_w_rolling:
|
| 206 |
+
return self.inference_pipeline.inference_with_rolling_forcing(
|
| 207 |
+
noise=noise, **conditional_dict
|
| 208 |
+
)
|
| 209 |
+
else:
|
| 210 |
+
return self.inference_pipeline.inference_with_self_forcing(
|
| 211 |
+
noise=noise, **conditional_dict
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
def _initialize_inference_pipeline(self):
|
| 215 |
+
"""
|
| 216 |
+
Lazy initialize the inference pipeline during the first backward simulation run.
|
| 217 |
+
Here we encapsulate the inference code with a model-dependent outside function.
|
| 218 |
+
We pass our FSDP-wrapped modules into the pipeline to save memory.
|
| 219 |
+
"""
|
| 220 |
+
self.inference_pipeline = RollingForcingTrainingPipeline(
|
| 221 |
+
denoising_step_list=self.denoising_step_list,
|
| 222 |
+
scheduler=self.scheduler,
|
| 223 |
+
generator=self.generator,
|
| 224 |
+
num_frame_per_block=self.num_frame_per_block,
|
| 225 |
+
independent_first_frame=self.args.independent_first_frame,
|
| 226 |
+
same_step_across_blocks=self.args.same_step_across_blocks,
|
| 227 |
+
last_step_only=self.args.last_step_only,
|
| 228 |
+
num_max_frames=self.num_training_frames,
|
| 229 |
+
context_noise=self.args.context_noise
|
| 230 |
+
)
|
model/causvid.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn.functional as F
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from model.base import BaseModel
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CausVid(BaseModel):
|
| 9 |
+
def __init__(self, args, device):
|
| 10 |
+
"""
|
| 11 |
+
Initialize the DMD (Distribution Matching Distillation) module.
|
| 12 |
+
This class is self-contained and compute generator and fake score losses
|
| 13 |
+
in the forward pass.
|
| 14 |
+
"""
|
| 15 |
+
super().__init__(args, device)
|
| 16 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
| 17 |
+
self.num_training_frames = getattr(args, "num_training_frames", 21)
|
| 18 |
+
|
| 19 |
+
if self.num_frame_per_block > 1:
|
| 20 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
| 21 |
+
|
| 22 |
+
self.independent_first_frame = getattr(args, "independent_first_frame", False)
|
| 23 |
+
if self.independent_first_frame:
|
| 24 |
+
self.generator.model.independent_first_frame = True
|
| 25 |
+
if args.gradient_checkpointing:
|
| 26 |
+
self.generator.enable_gradient_checkpointing()
|
| 27 |
+
self.fake_score.enable_gradient_checkpointing()
|
| 28 |
+
|
| 29 |
+
# Step 2: Initialize all dmd hyperparameters
|
| 30 |
+
self.num_train_timestep = args.num_train_timestep
|
| 31 |
+
self.min_step = int(0.02 * self.num_train_timestep)
|
| 32 |
+
self.max_step = int(0.98 * self.num_train_timestep)
|
| 33 |
+
if hasattr(args, "real_guidance_scale"):
|
| 34 |
+
self.real_guidance_scale = args.real_guidance_scale
|
| 35 |
+
self.fake_guidance_scale = args.fake_guidance_scale
|
| 36 |
+
else:
|
| 37 |
+
self.real_guidance_scale = args.guidance_scale
|
| 38 |
+
self.fake_guidance_scale = 0.0
|
| 39 |
+
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
|
| 40 |
+
self.teacher_forcing = getattr(args, "teacher_forcing", False)
|
| 41 |
+
|
| 42 |
+
if getattr(self.scheduler, "alphas_cumprod", None) is not None:
|
| 43 |
+
self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
|
| 44 |
+
else:
|
| 45 |
+
self.scheduler.alphas_cumprod = None
|
| 46 |
+
|
| 47 |
+
def _compute_kl_grad(
|
| 48 |
+
self, noisy_image_or_video: torch.Tensor,
|
| 49 |
+
estimated_clean_image_or_video: torch.Tensor,
|
| 50 |
+
timestep: torch.Tensor,
|
| 51 |
+
conditional_dict: dict, unconditional_dict: dict,
|
| 52 |
+
normalization: bool = True
|
| 53 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 54 |
+
"""
|
| 55 |
+
Compute the KL grad (eq 7 in https://arxiv.org/abs/2311.18828).
|
| 56 |
+
Input:
|
| 57 |
+
- noisy_image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
|
| 58 |
+
- estimated_clean_image_or_video: a tensor with shape [B, F, C, H, W] representing the estimated clean image or video.
|
| 59 |
+
- timestep: a tensor with shape [B, F] containing the randomly generated timestep.
|
| 60 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 61 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 62 |
+
- normalization: a boolean indicating whether to normalize the gradient.
|
| 63 |
+
Output:
|
| 64 |
+
- kl_grad: a tensor representing the KL grad.
|
| 65 |
+
- kl_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 66 |
+
"""
|
| 67 |
+
# Step 1: Compute the fake score
|
| 68 |
+
_, pred_fake_image_cond = self.fake_score(
|
| 69 |
+
noisy_image_or_video=noisy_image_or_video,
|
| 70 |
+
conditional_dict=conditional_dict,
|
| 71 |
+
timestep=timestep
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
if self.fake_guidance_scale != 0.0:
|
| 75 |
+
_, pred_fake_image_uncond = self.fake_score(
|
| 76 |
+
noisy_image_or_video=noisy_image_or_video,
|
| 77 |
+
conditional_dict=unconditional_dict,
|
| 78 |
+
timestep=timestep
|
| 79 |
+
)
|
| 80 |
+
pred_fake_image = pred_fake_image_cond + (
|
| 81 |
+
pred_fake_image_cond - pred_fake_image_uncond
|
| 82 |
+
) * self.fake_guidance_scale
|
| 83 |
+
else:
|
| 84 |
+
pred_fake_image = pred_fake_image_cond
|
| 85 |
+
|
| 86 |
+
# Step 2: Compute the real score
|
| 87 |
+
# We compute the conditional and unconditional prediction
|
| 88 |
+
# and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
|
| 89 |
+
_, pred_real_image_cond = self.real_score(
|
| 90 |
+
noisy_image_or_video=noisy_image_or_video,
|
| 91 |
+
conditional_dict=conditional_dict,
|
| 92 |
+
timestep=timestep
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
_, pred_real_image_uncond = self.real_score(
|
| 96 |
+
noisy_image_or_video=noisy_image_or_video,
|
| 97 |
+
conditional_dict=unconditional_dict,
|
| 98 |
+
timestep=timestep
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
pred_real_image = pred_real_image_cond + (
|
| 102 |
+
pred_real_image_cond - pred_real_image_uncond
|
| 103 |
+
) * self.real_guidance_scale
|
| 104 |
+
|
| 105 |
+
# Step 3: Compute the DMD gradient (DMD paper eq. 7).
|
| 106 |
+
grad = (pred_fake_image - pred_real_image)
|
| 107 |
+
|
| 108 |
+
# TODO: Change the normalizer for causal teacher
|
| 109 |
+
if normalization:
|
| 110 |
+
# Step 4: Gradient normalization (DMD paper eq. 8).
|
| 111 |
+
p_real = (estimated_clean_image_or_video - pred_real_image)
|
| 112 |
+
normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
|
| 113 |
+
grad = grad / normalizer
|
| 114 |
+
grad = torch.nan_to_num(grad)
|
| 115 |
+
|
| 116 |
+
return grad, {
|
| 117 |
+
"dmdtrain_gradient_norm": torch.mean(torch.abs(grad)).detach(),
|
| 118 |
+
"timestep": timestep.detach()
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
def compute_distribution_matching_loss(
|
| 122 |
+
self,
|
| 123 |
+
image_or_video: torch.Tensor,
|
| 124 |
+
conditional_dict: dict,
|
| 125 |
+
unconditional_dict: dict,
|
| 126 |
+
gradient_mask: torch.Tensor = None,
|
| 127 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 128 |
+
"""
|
| 129 |
+
Compute the DMD loss (eq 7 in https://arxiv.org/abs/2311.18828).
|
| 130 |
+
Input:
|
| 131 |
+
- image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
|
| 132 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 133 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 134 |
+
- gradient_mask: a boolean tensor with the same shape as image_or_video indicating which pixels to compute loss .
|
| 135 |
+
Output:
|
| 136 |
+
- dmd_loss: a scalar tensor representing the DMD loss.
|
| 137 |
+
- dmd_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 138 |
+
"""
|
| 139 |
+
original_latent = image_or_video
|
| 140 |
+
|
| 141 |
+
batch_size, num_frame = image_or_video.shape[:2]
|
| 142 |
+
|
| 143 |
+
with torch.no_grad():
|
| 144 |
+
# Step 1: Randomly sample timestep based on the given schedule and corresponding noise
|
| 145 |
+
timestep = self._get_timestep(
|
| 146 |
+
0,
|
| 147 |
+
self.num_train_timestep,
|
| 148 |
+
batch_size,
|
| 149 |
+
num_frame,
|
| 150 |
+
self.num_frame_per_block,
|
| 151 |
+
uniform_timestep=True
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
if self.timestep_shift > 1:
|
| 155 |
+
timestep = self.timestep_shift * \
|
| 156 |
+
(timestep / 1000) / \
|
| 157 |
+
(1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
|
| 158 |
+
timestep = timestep.clamp(self.min_step, self.max_step)
|
| 159 |
+
|
| 160 |
+
noise = torch.randn_like(image_or_video)
|
| 161 |
+
noisy_latent = self.scheduler.add_noise(
|
| 162 |
+
image_or_video.flatten(0, 1),
|
| 163 |
+
noise.flatten(0, 1),
|
| 164 |
+
timestep.flatten(0, 1)
|
| 165 |
+
).detach().unflatten(0, (batch_size, num_frame))
|
| 166 |
+
|
| 167 |
+
# Step 2: Compute the KL grad
|
| 168 |
+
grad, dmd_log_dict = self._compute_kl_grad(
|
| 169 |
+
noisy_image_or_video=noisy_latent,
|
| 170 |
+
estimated_clean_image_or_video=original_latent,
|
| 171 |
+
timestep=timestep,
|
| 172 |
+
conditional_dict=conditional_dict,
|
| 173 |
+
unconditional_dict=unconditional_dict
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
if gradient_mask is not None:
|
| 177 |
+
dmd_loss = 0.5 * F.mse_loss(original_latent.double(
|
| 178 |
+
)[gradient_mask], (original_latent.double() - grad.double()).detach()[gradient_mask], reduction="mean")
|
| 179 |
+
else:
|
| 180 |
+
dmd_loss = 0.5 * F.mse_loss(original_latent.double(
|
| 181 |
+
), (original_latent.double() - grad.double()).detach(), reduction="mean")
|
| 182 |
+
return dmd_loss, dmd_log_dict
|
| 183 |
+
|
| 184 |
+
def _run_generator(
|
| 185 |
+
self,
|
| 186 |
+
image_or_video_shape,
|
| 187 |
+
conditional_dict: dict,
|
| 188 |
+
clean_latent: torch.tensor
|
| 189 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 190 |
+
"""
|
| 191 |
+
Optionally simulate the generator's input from noise using backward simulation
|
| 192 |
+
and then run the generator for one-step.
|
| 193 |
+
Input:
|
| 194 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
| 195 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 196 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 197 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
| 198 |
+
- initial_latent: a tensor containing the initial latents [B, F, C, H, W].
|
| 199 |
+
Output:
|
| 200 |
+
- pred_image: a tensor with shape [B, F, C, H, W].
|
| 201 |
+
"""
|
| 202 |
+
simulated_noisy_input = []
|
| 203 |
+
for timestep in self.denoising_step_list:
|
| 204 |
+
noise = torch.randn(
|
| 205 |
+
image_or_video_shape, device=self.device, dtype=self.dtype)
|
| 206 |
+
|
| 207 |
+
noisy_timestep = timestep * torch.ones(
|
| 208 |
+
image_or_video_shape[:2], device=self.device, dtype=torch.long)
|
| 209 |
+
|
| 210 |
+
if timestep != 0:
|
| 211 |
+
noisy_image = self.scheduler.add_noise(
|
| 212 |
+
clean_latent.flatten(0, 1),
|
| 213 |
+
noise.flatten(0, 1),
|
| 214 |
+
noisy_timestep.flatten(0, 1)
|
| 215 |
+
).unflatten(0, image_or_video_shape[:2])
|
| 216 |
+
else:
|
| 217 |
+
noisy_image = clean_latent
|
| 218 |
+
|
| 219 |
+
simulated_noisy_input.append(noisy_image)
|
| 220 |
+
|
| 221 |
+
simulated_noisy_input = torch.stack(simulated_noisy_input, dim=1)
|
| 222 |
+
|
| 223 |
+
# Step 2: Randomly sample a timestep and pick the corresponding input
|
| 224 |
+
index = self._get_timestep(
|
| 225 |
+
0,
|
| 226 |
+
len(self.denoising_step_list),
|
| 227 |
+
image_or_video_shape[0],
|
| 228 |
+
image_or_video_shape[1],
|
| 229 |
+
self.num_frame_per_block,
|
| 230 |
+
uniform_timestep=False
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# select the corresponding timestep's noisy input from the stacked tensor [B, T, F, C, H, W]
|
| 234 |
+
noisy_input = torch.gather(
|
| 235 |
+
simulated_noisy_input, dim=1,
|
| 236 |
+
index=index.reshape(index.shape[0], 1, index.shape[1], 1, 1, 1).expand(
|
| 237 |
+
-1, -1, -1, *image_or_video_shape[2:]).to(self.device)
|
| 238 |
+
).squeeze(1)
|
| 239 |
+
|
| 240 |
+
timestep = self.denoising_step_list[index].to(self.device)
|
| 241 |
+
|
| 242 |
+
_, pred_image_or_video = self.generator(
|
| 243 |
+
noisy_image_or_video=noisy_input,
|
| 244 |
+
conditional_dict=conditional_dict,
|
| 245 |
+
timestep=timestep,
|
| 246 |
+
clean_x=clean_latent if self.teacher_forcing else None,
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
gradient_mask = None # timestep != 0
|
| 250 |
+
|
| 251 |
+
pred_image_or_video = pred_image_or_video.type_as(noisy_input)
|
| 252 |
+
|
| 253 |
+
return pred_image_or_video, gradient_mask
|
| 254 |
+
|
| 255 |
+
def generator_loss(
|
| 256 |
+
self,
|
| 257 |
+
image_or_video_shape,
|
| 258 |
+
conditional_dict: dict,
|
| 259 |
+
unconditional_dict: dict,
|
| 260 |
+
clean_latent: torch.Tensor,
|
| 261 |
+
initial_latent: torch.Tensor = None
|
| 262 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 263 |
+
"""
|
| 264 |
+
Generate image/videos from noise and compute the DMD loss.
|
| 265 |
+
The noisy input to the generator is backward simulated.
|
| 266 |
+
This removes the need of any datasets during distillation.
|
| 267 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
| 268 |
+
Input:
|
| 269 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
| 270 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 271 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 272 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
| 273 |
+
Output:
|
| 274 |
+
- loss: a scalar tensor representing the generator loss.
|
| 275 |
+
- generator_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 276 |
+
"""
|
| 277 |
+
# Step 1: Run generator on backward simulated noisy input
|
| 278 |
+
pred_image, gradient_mask = self._run_generator(
|
| 279 |
+
image_or_video_shape=image_or_video_shape,
|
| 280 |
+
conditional_dict=conditional_dict,
|
| 281 |
+
clean_latent=clean_latent
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
# Step 2: Compute the DMD loss
|
| 285 |
+
dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
|
| 286 |
+
image_or_video=pred_image,
|
| 287 |
+
conditional_dict=conditional_dict,
|
| 288 |
+
unconditional_dict=unconditional_dict,
|
| 289 |
+
gradient_mask=gradient_mask
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
# Step 3: TODO: Implement the GAN loss
|
| 293 |
+
|
| 294 |
+
return dmd_loss, dmd_log_dict
|
| 295 |
+
|
| 296 |
+
def critic_loss(
|
| 297 |
+
self,
|
| 298 |
+
image_or_video_shape,
|
| 299 |
+
conditional_dict: dict,
|
| 300 |
+
unconditional_dict: dict,
|
| 301 |
+
clean_latent: torch.Tensor,
|
| 302 |
+
initial_latent: torch.Tensor = None
|
| 303 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 304 |
+
"""
|
| 305 |
+
Generate image/videos from noise and train the critic with generated samples.
|
| 306 |
+
The noisy input to the generator is backward simulated.
|
| 307 |
+
This removes the need of any datasets during distillation.
|
| 308 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
| 309 |
+
Input:
|
| 310 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
| 311 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 312 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 313 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
| 314 |
+
Output:
|
| 315 |
+
- loss: a scalar tensor representing the generator loss.
|
| 316 |
+
- critic_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 317 |
+
"""
|
| 318 |
+
|
| 319 |
+
# Step 1: Run generator on backward simulated noisy input
|
| 320 |
+
with torch.no_grad():
|
| 321 |
+
generated_image, _ = self._run_generator(
|
| 322 |
+
image_or_video_shape=image_or_video_shape,
|
| 323 |
+
conditional_dict=conditional_dict,
|
| 324 |
+
clean_latent=clean_latent
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# Step 2: Compute the fake prediction
|
| 328 |
+
critic_timestep = self._get_timestep(
|
| 329 |
+
0,
|
| 330 |
+
self.num_train_timestep,
|
| 331 |
+
image_or_video_shape[0],
|
| 332 |
+
image_or_video_shape[1],
|
| 333 |
+
self.num_frame_per_block,
|
| 334 |
+
uniform_timestep=True
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
if self.timestep_shift > 1:
|
| 338 |
+
critic_timestep = self.timestep_shift * \
|
| 339 |
+
(critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
|
| 340 |
+
|
| 341 |
+
critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
|
| 342 |
+
|
| 343 |
+
critic_noise = torch.randn_like(generated_image)
|
| 344 |
+
noisy_generated_image = self.scheduler.add_noise(
|
| 345 |
+
generated_image.flatten(0, 1),
|
| 346 |
+
critic_noise.flatten(0, 1),
|
| 347 |
+
critic_timestep.flatten(0, 1)
|
| 348 |
+
).unflatten(0, image_or_video_shape[:2])
|
| 349 |
+
|
| 350 |
+
_, pred_fake_image = self.fake_score(
|
| 351 |
+
noisy_image_or_video=noisy_generated_image,
|
| 352 |
+
conditional_dict=conditional_dict,
|
| 353 |
+
timestep=critic_timestep
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# Step 3: Compute the denoising loss for the fake critic
|
| 357 |
+
if self.args.denoising_loss_type == "flow":
|
| 358 |
+
from utils.wan_wrapper import WanDiffusionWrapper
|
| 359 |
+
flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
|
| 360 |
+
scheduler=self.scheduler,
|
| 361 |
+
x0_pred=pred_fake_image.flatten(0, 1),
|
| 362 |
+
xt=noisy_generated_image.flatten(0, 1),
|
| 363 |
+
timestep=critic_timestep.flatten(0, 1)
|
| 364 |
+
)
|
| 365 |
+
pred_fake_noise = None
|
| 366 |
+
else:
|
| 367 |
+
flow_pred = None
|
| 368 |
+
pred_fake_noise = self.scheduler.convert_x0_to_noise(
|
| 369 |
+
x0=pred_fake_image.flatten(0, 1),
|
| 370 |
+
xt=noisy_generated_image.flatten(0, 1),
|
| 371 |
+
timestep=critic_timestep.flatten(0, 1)
|
| 372 |
+
).unflatten(0, image_or_video_shape[:2])
|
| 373 |
+
|
| 374 |
+
denoising_loss = self.denoising_loss_func(
|
| 375 |
+
x=generated_image.flatten(0, 1),
|
| 376 |
+
x_pred=pred_fake_image.flatten(0, 1),
|
| 377 |
+
noise=critic_noise.flatten(0, 1),
|
| 378 |
+
noise_pred=pred_fake_noise,
|
| 379 |
+
alphas_cumprod=self.scheduler.alphas_cumprod,
|
| 380 |
+
timestep=critic_timestep.flatten(0, 1),
|
| 381 |
+
flow_pred=flow_pred
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
# Step 4: TODO: Compute the GAN loss
|
| 385 |
+
|
| 386 |
+
# Step 5: Debugging Log
|
| 387 |
+
critic_log_dict = {
|
| 388 |
+
"critic_timestep": critic_timestep.detach()
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
return denoising_loss, critic_log_dict
|
model/diffusion.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from model.base import BaseModel
|
| 5 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CausalDiffusion(BaseModel):
|
| 9 |
+
def __init__(self, args, device):
|
| 10 |
+
"""
|
| 11 |
+
Initialize the Diffusion loss module.
|
| 12 |
+
"""
|
| 13 |
+
super().__init__(args, device)
|
| 14 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
| 15 |
+
if self.num_frame_per_block > 1:
|
| 16 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
| 17 |
+
self.independent_first_frame = getattr(args, "independent_first_frame", False)
|
| 18 |
+
if self.independent_first_frame:
|
| 19 |
+
self.generator.model.independent_first_frame = True
|
| 20 |
+
|
| 21 |
+
if args.gradient_checkpointing:
|
| 22 |
+
self.generator.enable_gradient_checkpointing()
|
| 23 |
+
|
| 24 |
+
# Step 2: Initialize all hyperparameters
|
| 25 |
+
self.num_train_timestep = args.num_train_timestep
|
| 26 |
+
self.min_step = int(0.02 * self.num_train_timestep)
|
| 27 |
+
self.max_step = int(0.98 * self.num_train_timestep)
|
| 28 |
+
self.guidance_scale = args.guidance_scale
|
| 29 |
+
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
|
| 30 |
+
self.teacher_forcing = getattr(args, "teacher_forcing", False)
|
| 31 |
+
# Noise augmentation in teacher forcing, we add small noise to clean context latents
|
| 32 |
+
self.noise_augmentation_max_timestep = getattr(args, "noise_augmentation_max_timestep", 0)
|
| 33 |
+
|
| 34 |
+
def _initialize_models(self, args):
|
| 35 |
+
self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
|
| 36 |
+
self.generator.model.requires_grad_(True)
|
| 37 |
+
|
| 38 |
+
self.text_encoder = WanTextEncoder()
|
| 39 |
+
self.text_encoder.requires_grad_(False)
|
| 40 |
+
|
| 41 |
+
self.vae = WanVAEWrapper()
|
| 42 |
+
self.vae.requires_grad_(False)
|
| 43 |
+
|
| 44 |
+
def generator_loss(
|
| 45 |
+
self,
|
| 46 |
+
image_or_video_shape,
|
| 47 |
+
conditional_dict: dict,
|
| 48 |
+
unconditional_dict: dict,
|
| 49 |
+
clean_latent: torch.Tensor,
|
| 50 |
+
initial_latent: torch.Tensor = None
|
| 51 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 52 |
+
"""
|
| 53 |
+
Generate image/videos from noise and compute the DMD loss.
|
| 54 |
+
The noisy input to the generator is backward simulated.
|
| 55 |
+
This removes the need of any datasets during distillation.
|
| 56 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
| 57 |
+
Input:
|
| 58 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
| 59 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 60 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 61 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
| 62 |
+
Output:
|
| 63 |
+
- loss: a scalar tensor representing the generator loss.
|
| 64 |
+
- generator_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 65 |
+
"""
|
| 66 |
+
noise = torch.randn_like(clean_latent)
|
| 67 |
+
batch_size, num_frame = image_or_video_shape[:2]
|
| 68 |
+
|
| 69 |
+
# Step 2: Randomly sample a timestep and add noise to denoiser inputs
|
| 70 |
+
index = self._get_timestep(
|
| 71 |
+
0,
|
| 72 |
+
self.scheduler.num_train_timesteps,
|
| 73 |
+
image_or_video_shape[0],
|
| 74 |
+
image_or_video_shape[1],
|
| 75 |
+
self.num_frame_per_block,
|
| 76 |
+
uniform_timestep=False
|
| 77 |
+
)
|
| 78 |
+
timestep = self.scheduler.timesteps[index].to(dtype=self.dtype, device=self.device)
|
| 79 |
+
noisy_latents = self.scheduler.add_noise(
|
| 80 |
+
clean_latent.flatten(0, 1),
|
| 81 |
+
noise.flatten(0, 1),
|
| 82 |
+
timestep.flatten(0, 1)
|
| 83 |
+
).unflatten(0, (batch_size, num_frame))
|
| 84 |
+
training_target = self.scheduler.training_target(clean_latent, noise, timestep)
|
| 85 |
+
|
| 86 |
+
# Step 3: Noise augmentation, also add small noise to clean context latents
|
| 87 |
+
if self.noise_augmentation_max_timestep > 0:
|
| 88 |
+
index_clean_aug = self._get_timestep(
|
| 89 |
+
0,
|
| 90 |
+
self.noise_augmentation_max_timestep,
|
| 91 |
+
image_or_video_shape[0],
|
| 92 |
+
image_or_video_shape[1],
|
| 93 |
+
self.num_frame_per_block,
|
| 94 |
+
uniform_timestep=False
|
| 95 |
+
)
|
| 96 |
+
timestep_clean_aug = self.scheduler.timesteps[index_clean_aug].to(dtype=self.dtype, device=self.device)
|
| 97 |
+
clean_latent_aug = self.scheduler.add_noise(
|
| 98 |
+
clean_latent.flatten(0, 1),
|
| 99 |
+
noise.flatten(0, 1),
|
| 100 |
+
timestep_clean_aug.flatten(0, 1)
|
| 101 |
+
).unflatten(0, (batch_size, num_frame))
|
| 102 |
+
else:
|
| 103 |
+
clean_latent_aug = clean_latent
|
| 104 |
+
timestep_clean_aug = None
|
| 105 |
+
|
| 106 |
+
# Compute loss
|
| 107 |
+
flow_pred, x0_pred = self.generator(
|
| 108 |
+
noisy_image_or_video=noisy_latents,
|
| 109 |
+
conditional_dict=conditional_dict,
|
| 110 |
+
timestep=timestep,
|
| 111 |
+
clean_x=clean_latent_aug if self.teacher_forcing else None,
|
| 112 |
+
aug_t=timestep_clean_aug if self.teacher_forcing else None
|
| 113 |
+
)
|
| 114 |
+
# loss = torch.nn.functional.mse_loss(flow_pred.float(), training_target.float())
|
| 115 |
+
loss = torch.nn.functional.mse_loss(
|
| 116 |
+
flow_pred.float(), training_target.float(), reduction='none'
|
| 117 |
+
).mean(dim=(2, 3, 4))
|
| 118 |
+
loss = loss * self.scheduler.training_weight(timestep).unflatten(0, (batch_size, num_frame))
|
| 119 |
+
loss = loss.mean()
|
| 120 |
+
|
| 121 |
+
log_dict = {
|
| 122 |
+
"x0": clean_latent.detach(),
|
| 123 |
+
"x0_pred": x0_pred.detach()
|
| 124 |
+
}
|
| 125 |
+
return loss, log_dict
|
model/dmd.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pipeline import RollingForcingTrainingPipeline
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from model.base import RollingForcingModel
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DMD(RollingForcingModel):
|
| 10 |
+
def __init__(self, args, device):
|
| 11 |
+
"""
|
| 12 |
+
Initialize the DMD (Distribution Matching Distillation) module.
|
| 13 |
+
This class is self-contained and compute generator and fake score losses
|
| 14 |
+
in the forward pass.
|
| 15 |
+
"""
|
| 16 |
+
super().__init__(args, device)
|
| 17 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
| 18 |
+
self.same_step_across_blocks = getattr(args, "same_step_across_blocks", True)
|
| 19 |
+
self.num_training_frames = getattr(args, "num_training_frames", 21)
|
| 20 |
+
|
| 21 |
+
if self.num_frame_per_block > 1:
|
| 22 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
| 23 |
+
|
| 24 |
+
self.independent_first_frame = getattr(args, "independent_first_frame", False)
|
| 25 |
+
if self.independent_first_frame:
|
| 26 |
+
self.generator.model.independent_first_frame = True
|
| 27 |
+
if args.gradient_checkpointing:
|
| 28 |
+
self.generator.enable_gradient_checkpointing()
|
| 29 |
+
self.fake_score.enable_gradient_checkpointing()
|
| 30 |
+
|
| 31 |
+
# this will be init later with fsdp-wrapped modules
|
| 32 |
+
self.inference_pipeline: RollingForcingTrainingPipeline = None
|
| 33 |
+
|
| 34 |
+
# Step 2: Initialize all dmd hyperparameters
|
| 35 |
+
self.num_train_timestep = args.num_train_timestep
|
| 36 |
+
self.min_step = int(0.02 * self.num_train_timestep)
|
| 37 |
+
self.max_step = int(0.98 * self.num_train_timestep)
|
| 38 |
+
if hasattr(args, "real_guidance_scale"):
|
| 39 |
+
self.real_guidance_scale = args.real_guidance_scale
|
| 40 |
+
self.fake_guidance_scale = args.fake_guidance_scale
|
| 41 |
+
else:
|
| 42 |
+
self.real_guidance_scale = args.guidance_scale
|
| 43 |
+
self.fake_guidance_scale = 0.0
|
| 44 |
+
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
|
| 45 |
+
self.ts_schedule = getattr(args, "ts_schedule", True)
|
| 46 |
+
self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
|
| 47 |
+
self.min_score_timestep = getattr(args, "min_score_timestep", 0)
|
| 48 |
+
|
| 49 |
+
if getattr(self.scheduler, "alphas_cumprod", None) is not None:
|
| 50 |
+
self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
|
| 51 |
+
else:
|
| 52 |
+
self.scheduler.alphas_cumprod = None
|
| 53 |
+
|
| 54 |
+
def _compute_kl_grad(
|
| 55 |
+
self, noisy_image_or_video: torch.Tensor,
|
| 56 |
+
estimated_clean_image_or_video: torch.Tensor,
|
| 57 |
+
timestep: torch.Tensor,
|
| 58 |
+
conditional_dict: dict, unconditional_dict: dict,
|
| 59 |
+
normalization: bool = True
|
| 60 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 61 |
+
"""
|
| 62 |
+
Compute the KL grad (eq 7 in https://arxiv.org/abs/2311.18828).
|
| 63 |
+
Input:
|
| 64 |
+
- noisy_image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
|
| 65 |
+
- estimated_clean_image_or_video: a tensor with shape [B, F, C, H, W] representing the estimated clean image or video.
|
| 66 |
+
- timestep: a tensor with shape [B, F] containing the randomly generated timestep.
|
| 67 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 68 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 69 |
+
- normalization: a boolean indicating whether to normalize the gradient.
|
| 70 |
+
Output:
|
| 71 |
+
- kl_grad: a tensor representing the KL grad.
|
| 72 |
+
- kl_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 73 |
+
"""
|
| 74 |
+
# Step 1: Compute the fake score
|
| 75 |
+
_, pred_fake_image_cond = self.fake_score(
|
| 76 |
+
noisy_image_or_video=noisy_image_or_video,
|
| 77 |
+
conditional_dict=conditional_dict,
|
| 78 |
+
timestep=timestep
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
if self.fake_guidance_scale != 0.0:
|
| 82 |
+
_, pred_fake_image_uncond = self.fake_score(
|
| 83 |
+
noisy_image_or_video=noisy_image_or_video,
|
| 84 |
+
conditional_dict=unconditional_dict,
|
| 85 |
+
timestep=timestep
|
| 86 |
+
)
|
| 87 |
+
pred_fake_image = pred_fake_image_cond + (
|
| 88 |
+
pred_fake_image_cond - pred_fake_image_uncond
|
| 89 |
+
) * self.fake_guidance_scale
|
| 90 |
+
else:
|
| 91 |
+
pred_fake_image = pred_fake_image_cond
|
| 92 |
+
|
| 93 |
+
# Step 2: Compute the real score
|
| 94 |
+
# We compute the conditional and unconditional prediction
|
| 95 |
+
# and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
|
| 96 |
+
_, pred_real_image_cond = self.real_score(
|
| 97 |
+
noisy_image_or_video=noisy_image_or_video,
|
| 98 |
+
conditional_dict=conditional_dict,
|
| 99 |
+
timestep=timestep
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
_, pred_real_image_uncond = self.real_score(
|
| 103 |
+
noisy_image_or_video=noisy_image_or_video,
|
| 104 |
+
conditional_dict=unconditional_dict,
|
| 105 |
+
timestep=timestep
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
pred_real_image = pred_real_image_cond + (
|
| 109 |
+
pred_real_image_cond - pred_real_image_uncond
|
| 110 |
+
) * self.real_guidance_scale
|
| 111 |
+
|
| 112 |
+
# Step 3: Compute the DMD gradient (DMD paper eq. 7).
|
| 113 |
+
grad = (pred_fake_image - pred_real_image)
|
| 114 |
+
|
| 115 |
+
# TODO: Change the normalizer for causal teacher
|
| 116 |
+
if normalization:
|
| 117 |
+
# Step 4: Gradient normalization (DMD paper eq. 8).
|
| 118 |
+
p_real = (estimated_clean_image_or_video - pred_real_image)
|
| 119 |
+
normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
|
| 120 |
+
grad = grad / normalizer
|
| 121 |
+
grad = torch.nan_to_num(grad)
|
| 122 |
+
|
| 123 |
+
return grad, {
|
| 124 |
+
"dmdtrain_gradient_norm": torch.mean(torch.abs(grad)).detach(),
|
| 125 |
+
"timestep": timestep.detach()
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
def compute_distribution_matching_loss(
|
| 129 |
+
self,
|
| 130 |
+
image_or_video: torch.Tensor,
|
| 131 |
+
conditional_dict: dict,
|
| 132 |
+
unconditional_dict: dict,
|
| 133 |
+
gradient_mask: Optional[torch.Tensor] = None,
|
| 134 |
+
denoised_timestep_from: int = 0,
|
| 135 |
+
denoised_timestep_to: int = 0
|
| 136 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 137 |
+
"""
|
| 138 |
+
Compute the DMD loss (eq 7 in https://arxiv.org/abs/2311.18828).
|
| 139 |
+
Input:
|
| 140 |
+
- image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
|
| 141 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 142 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 143 |
+
- gradient_mask: a boolean tensor with the same shape as image_or_video indicating which pixels to compute loss .
|
| 144 |
+
Output:
|
| 145 |
+
- dmd_loss: a scalar tensor representing the DMD loss.
|
| 146 |
+
- dmd_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 147 |
+
"""
|
| 148 |
+
original_latent = image_or_video
|
| 149 |
+
|
| 150 |
+
batch_size, num_frame = image_or_video.shape[:2]
|
| 151 |
+
|
| 152 |
+
with torch.no_grad():
|
| 153 |
+
# Step 1: Randomly sample timestep based on the given schedule and corresponding noise
|
| 154 |
+
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
|
| 155 |
+
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
|
| 156 |
+
timestep = self._get_timestep(
|
| 157 |
+
min_timestep,
|
| 158 |
+
max_timestep,
|
| 159 |
+
batch_size,
|
| 160 |
+
num_frame,
|
| 161 |
+
self.num_frame_per_block,
|
| 162 |
+
uniform_timestep=True
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# TODO:should we change it to `timestep = self.scheduler.timesteps[timestep]`?
|
| 166 |
+
if self.timestep_shift > 1:
|
| 167 |
+
timestep = self.timestep_shift * \
|
| 168 |
+
(timestep / 1000) / \
|
| 169 |
+
(1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
|
| 170 |
+
timestep = timestep.clamp(self.min_step, self.max_step)
|
| 171 |
+
|
| 172 |
+
noise = torch.randn_like(image_or_video)
|
| 173 |
+
noisy_latent = self.scheduler.add_noise(
|
| 174 |
+
image_or_video.flatten(0, 1),
|
| 175 |
+
noise.flatten(0, 1),
|
| 176 |
+
timestep.flatten(0, 1)
|
| 177 |
+
).detach().unflatten(0, (batch_size, num_frame))
|
| 178 |
+
|
| 179 |
+
# Step 2: Compute the KL grad
|
| 180 |
+
grad, dmd_log_dict = self._compute_kl_grad(
|
| 181 |
+
noisy_image_or_video=noisy_latent,
|
| 182 |
+
estimated_clean_image_or_video=original_latent,
|
| 183 |
+
timestep=timestep,
|
| 184 |
+
conditional_dict=conditional_dict,
|
| 185 |
+
unconditional_dict=unconditional_dict
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
if gradient_mask is not None:
|
| 189 |
+
dmd_loss = 0.5 * F.mse_loss(original_latent.double(
|
| 190 |
+
)[gradient_mask], (original_latent.double() - grad.double()).detach()[gradient_mask], reduction="mean")
|
| 191 |
+
else:
|
| 192 |
+
dmd_loss = 0.5 * F.mse_loss(original_latent.double(
|
| 193 |
+
), (original_latent.double() - grad.double()).detach(), reduction="mean")
|
| 194 |
+
return dmd_loss, dmd_log_dict
|
| 195 |
+
|
| 196 |
+
def generator_loss(
|
| 197 |
+
self,
|
| 198 |
+
image_or_video_shape,
|
| 199 |
+
conditional_dict: dict,
|
| 200 |
+
unconditional_dict: dict,
|
| 201 |
+
clean_latent: torch.Tensor,
|
| 202 |
+
initial_latent: torch.Tensor = None
|
| 203 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 204 |
+
"""
|
| 205 |
+
Generate image/videos from noise and compute the DMD loss.
|
| 206 |
+
The noisy input to the generator is backward simulated.
|
| 207 |
+
This removes the need of any datasets during distillation.
|
| 208 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
| 209 |
+
Input:
|
| 210 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
| 211 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 212 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 213 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
| 214 |
+
Output:
|
| 215 |
+
- loss: a scalar tensor representing the generator loss.
|
| 216 |
+
- generator_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 217 |
+
"""
|
| 218 |
+
# Step 1: Unroll generator to obtain fake videos
|
| 219 |
+
pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
|
| 220 |
+
image_or_video_shape=image_or_video_shape,
|
| 221 |
+
conditional_dict=conditional_dict,
|
| 222 |
+
initial_latent=initial_latent
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
# Step 2: Compute the DMD loss
|
| 226 |
+
dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
|
| 227 |
+
image_or_video=pred_image,
|
| 228 |
+
conditional_dict=conditional_dict,
|
| 229 |
+
unconditional_dict=unconditional_dict,
|
| 230 |
+
gradient_mask=gradient_mask,
|
| 231 |
+
denoised_timestep_from=denoised_timestep_from,
|
| 232 |
+
denoised_timestep_to=denoised_timestep_to
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
return dmd_loss, dmd_log_dict
|
| 236 |
+
|
| 237 |
+
def critic_loss(
|
| 238 |
+
self,
|
| 239 |
+
image_or_video_shape,
|
| 240 |
+
conditional_dict: dict,
|
| 241 |
+
unconditional_dict: dict,
|
| 242 |
+
clean_latent: torch.Tensor,
|
| 243 |
+
initial_latent: torch.Tensor = None
|
| 244 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 245 |
+
"""
|
| 246 |
+
Generate image/videos from noise and train the critic with generated samples.
|
| 247 |
+
The noisy input to the generator is backward simulated.
|
| 248 |
+
This removes the need of any datasets during distillation.
|
| 249 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
| 250 |
+
Input:
|
| 251 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
| 252 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 253 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 254 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
| 255 |
+
Output:
|
| 256 |
+
- loss: a scalar tensor representing the generator loss.
|
| 257 |
+
- critic_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
# Step 1: Run generator on backward simulated noisy input
|
| 261 |
+
with torch.no_grad():
|
| 262 |
+
generated_image, _, denoised_timestep_from, denoised_timestep_to = self._run_generator(
|
| 263 |
+
image_or_video_shape=image_or_video_shape,
|
| 264 |
+
conditional_dict=conditional_dict,
|
| 265 |
+
initial_latent=initial_latent
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# Step 2: Compute the fake prediction
|
| 269 |
+
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
|
| 270 |
+
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
|
| 271 |
+
critic_timestep = self._get_timestep(
|
| 272 |
+
min_timestep,
|
| 273 |
+
max_timestep,
|
| 274 |
+
image_or_video_shape[0],
|
| 275 |
+
image_or_video_shape[1],
|
| 276 |
+
self.num_frame_per_block,
|
| 277 |
+
uniform_timestep=True
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
if self.timestep_shift > 1:
|
| 281 |
+
critic_timestep = self.timestep_shift * \
|
| 282 |
+
(critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
|
| 283 |
+
|
| 284 |
+
critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
|
| 285 |
+
|
| 286 |
+
critic_noise = torch.randn_like(generated_image)
|
| 287 |
+
noisy_generated_image = self.scheduler.add_noise(
|
| 288 |
+
generated_image.flatten(0, 1),
|
| 289 |
+
critic_noise.flatten(0, 1),
|
| 290 |
+
critic_timestep.flatten(0, 1)
|
| 291 |
+
).unflatten(0, image_or_video_shape[:2])
|
| 292 |
+
|
| 293 |
+
_, pred_fake_image = self.fake_score(
|
| 294 |
+
noisy_image_or_video=noisy_generated_image,
|
| 295 |
+
conditional_dict=conditional_dict,
|
| 296 |
+
timestep=critic_timestep
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
# Step 3: Compute the denoising loss for the fake critic
|
| 300 |
+
if self.args.denoising_loss_type == "flow":
|
| 301 |
+
from utils.wan_wrapper import WanDiffusionWrapper
|
| 302 |
+
flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
|
| 303 |
+
scheduler=self.scheduler,
|
| 304 |
+
x0_pred=pred_fake_image.flatten(0, 1),
|
| 305 |
+
xt=noisy_generated_image.flatten(0, 1),
|
| 306 |
+
timestep=critic_timestep.flatten(0, 1)
|
| 307 |
+
)
|
| 308 |
+
pred_fake_noise = None
|
| 309 |
+
else:
|
| 310 |
+
flow_pred = None
|
| 311 |
+
pred_fake_noise = self.scheduler.convert_x0_to_noise(
|
| 312 |
+
x0=pred_fake_image.flatten(0, 1),
|
| 313 |
+
xt=noisy_generated_image.flatten(0, 1),
|
| 314 |
+
timestep=critic_timestep.flatten(0, 1)
|
| 315 |
+
).unflatten(0, image_or_video_shape[:2])
|
| 316 |
+
|
| 317 |
+
denoising_loss = self.denoising_loss_func(
|
| 318 |
+
x=generated_image.flatten(0, 1),
|
| 319 |
+
x_pred=pred_fake_image.flatten(0, 1),
|
| 320 |
+
noise=critic_noise.flatten(0, 1),
|
| 321 |
+
noise_pred=pred_fake_noise,
|
| 322 |
+
alphas_cumprod=self.scheduler.alphas_cumprod,
|
| 323 |
+
timestep=critic_timestep.flatten(0, 1),
|
| 324 |
+
flow_pred=flow_pred
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# Step 5: Debugging Log
|
| 328 |
+
critic_log_dict = {
|
| 329 |
+
"critic_timestep": critic_timestep.detach()
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
return denoising_loss, critic_log_dict
|
model/gan.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from pipeline import RollingForcingTrainingPipeline
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from model.base import RollingForcingModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GAN(RollingForcingModel):
|
| 11 |
+
def __init__(self, args, device):
|
| 12 |
+
"""
|
| 13 |
+
Initialize the GAN module.
|
| 14 |
+
This class is self-contained and compute generator and fake score losses
|
| 15 |
+
in the forward pass.
|
| 16 |
+
"""
|
| 17 |
+
super().__init__(args, device)
|
| 18 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
| 19 |
+
self.same_step_across_blocks = getattr(args, "same_step_across_blocks", True)
|
| 20 |
+
self.concat_time_embeddings = getattr(args, "concat_time_embeddings", False)
|
| 21 |
+
self.num_class = args.num_class
|
| 22 |
+
self.relativistic_discriminator = getattr(args, "relativistic_discriminator", False)
|
| 23 |
+
|
| 24 |
+
if self.num_frame_per_block > 1:
|
| 25 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
| 26 |
+
|
| 27 |
+
self.fake_score.adding_cls_branch(
|
| 28 |
+
atten_dim=1536, num_class=args.num_class, time_embed_dim=1536 if self.concat_time_embeddings else 0)
|
| 29 |
+
self.fake_score.model.requires_grad_(True)
|
| 30 |
+
|
| 31 |
+
self.independent_first_frame = getattr(args, "independent_first_frame", False)
|
| 32 |
+
if self.independent_first_frame:
|
| 33 |
+
self.generator.model.independent_first_frame = True
|
| 34 |
+
if args.gradient_checkpointing:
|
| 35 |
+
self.generator.enable_gradient_checkpointing()
|
| 36 |
+
self.fake_score.enable_gradient_checkpointing()
|
| 37 |
+
|
| 38 |
+
# this will be init later with fsdp-wrapped modules
|
| 39 |
+
self.inference_pipeline: RollingForcingTrainingPipeline = None
|
| 40 |
+
|
| 41 |
+
# Step 2: Initialize all dmd hyperparameters
|
| 42 |
+
self.num_train_timestep = args.num_train_timestep
|
| 43 |
+
self.min_step = int(0.02 * self.num_train_timestep)
|
| 44 |
+
self.max_step = int(0.98 * self.num_train_timestep)
|
| 45 |
+
if hasattr(args, "real_guidance_scale"):
|
| 46 |
+
self.real_guidance_scale = args.real_guidance_scale
|
| 47 |
+
self.fake_guidance_scale = args.fake_guidance_scale
|
| 48 |
+
else:
|
| 49 |
+
self.real_guidance_scale = args.guidance_scale
|
| 50 |
+
self.fake_guidance_scale = 0.0
|
| 51 |
+
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
|
| 52 |
+
self.critic_timestep_shift = getattr(args, "critic_timestep_shift", self.timestep_shift)
|
| 53 |
+
self.ts_schedule = getattr(args, "ts_schedule", True)
|
| 54 |
+
self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
|
| 55 |
+
self.min_score_timestep = getattr(args, "min_score_timestep", 0)
|
| 56 |
+
|
| 57 |
+
self.gan_g_weight = getattr(args, "gan_g_weight", 1e-2)
|
| 58 |
+
self.gan_d_weight = getattr(args, "gan_d_weight", 1e-2)
|
| 59 |
+
self.r1_weight = getattr(args, "r1_weight", 0.0)
|
| 60 |
+
self.r2_weight = getattr(args, "r2_weight", 0.0)
|
| 61 |
+
self.r1_sigma = getattr(args, "r1_sigma", 0.01)
|
| 62 |
+
self.r2_sigma = getattr(args, "r2_sigma", 0.01)
|
| 63 |
+
|
| 64 |
+
if getattr(self.scheduler, "alphas_cumprod", None) is not None:
|
| 65 |
+
self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
|
| 66 |
+
else:
|
| 67 |
+
self.scheduler.alphas_cumprod = None
|
| 68 |
+
|
| 69 |
+
def _run_cls_pred_branch(self,
|
| 70 |
+
noisy_image_or_video: torch.Tensor,
|
| 71 |
+
conditional_dict: dict,
|
| 72 |
+
timestep: torch.Tensor) -> torch.Tensor:
|
| 73 |
+
"""
|
| 74 |
+
Run the classifier prediction branch on the generated image or video.
|
| 75 |
+
Input:
|
| 76 |
+
- image_or_video: a tensor with shape [B, F, C, H, W].
|
| 77 |
+
Output:
|
| 78 |
+
- cls_pred: a tensor with shape [B, 1, 1, 1, 1] representing the feature map for classification.
|
| 79 |
+
"""
|
| 80 |
+
_, _, noisy_logit = self.fake_score(
|
| 81 |
+
noisy_image_or_video=noisy_image_or_video,
|
| 82 |
+
conditional_dict=conditional_dict,
|
| 83 |
+
timestep=timestep,
|
| 84 |
+
classify_mode=True,
|
| 85 |
+
concat_time_embeddings=self.concat_time_embeddings
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
return noisy_logit
|
| 89 |
+
|
| 90 |
+
def generator_loss(
|
| 91 |
+
self,
|
| 92 |
+
image_or_video_shape,
|
| 93 |
+
conditional_dict: dict,
|
| 94 |
+
unconditional_dict: dict,
|
| 95 |
+
clean_latent: torch.Tensor,
|
| 96 |
+
initial_latent: torch.Tensor = None
|
| 97 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 98 |
+
"""
|
| 99 |
+
Generate image/videos from noise and compute the DMD loss.
|
| 100 |
+
The noisy input to the generator is backward simulated.
|
| 101 |
+
This removes the need of any datasets during distillation.
|
| 102 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
| 103 |
+
Input:
|
| 104 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
| 105 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 106 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 107 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
| 108 |
+
Output:
|
| 109 |
+
- loss: a scalar tensor representing the generator loss.
|
| 110 |
+
- generator_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 111 |
+
"""
|
| 112 |
+
# Step 1: Unroll generator to obtain fake videos
|
| 113 |
+
pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
|
| 114 |
+
image_or_video_shape=image_or_video_shape,
|
| 115 |
+
conditional_dict=conditional_dict,
|
| 116 |
+
initial_latent=initial_latent
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# Step 2: Get timestep and add noise to generated/real latents
|
| 120 |
+
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
|
| 121 |
+
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
|
| 122 |
+
critic_timestep = self._get_timestep(
|
| 123 |
+
min_timestep,
|
| 124 |
+
max_timestep,
|
| 125 |
+
image_or_video_shape[0],
|
| 126 |
+
image_or_video_shape[1],
|
| 127 |
+
self.num_frame_per_block,
|
| 128 |
+
uniform_timestep=True
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
if self.critic_timestep_shift > 1:
|
| 132 |
+
critic_timestep = self.critic_timestep_shift * \
|
| 133 |
+
(critic_timestep / 1000) / (1 + (self.critic_timestep_shift - 1) * (critic_timestep / 1000)) * 1000
|
| 134 |
+
|
| 135 |
+
critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
|
| 136 |
+
|
| 137 |
+
critic_noise = torch.randn_like(pred_image)
|
| 138 |
+
noisy_fake_latent = self.scheduler.add_noise(
|
| 139 |
+
pred_image.flatten(0, 1),
|
| 140 |
+
critic_noise.flatten(0, 1),
|
| 141 |
+
critic_timestep.flatten(0, 1)
|
| 142 |
+
).unflatten(0, image_or_video_shape[:2])
|
| 143 |
+
|
| 144 |
+
# Step 4: Compute the real GAN discriminator loss
|
| 145 |
+
real_image_or_video = clean_latent.clone()
|
| 146 |
+
critic_noise = torch.randn_like(real_image_or_video)
|
| 147 |
+
noisy_real_latent = self.scheduler.add_noise(
|
| 148 |
+
real_image_or_video.flatten(0, 1),
|
| 149 |
+
critic_noise.flatten(0, 1),
|
| 150 |
+
critic_timestep.flatten(0, 1)
|
| 151 |
+
).unflatten(0, image_or_video_shape[:2])
|
| 152 |
+
|
| 153 |
+
conditional_dict["prompt_embeds"] = torch.concatenate(
|
| 154 |
+
(conditional_dict["prompt_embeds"], conditional_dict["prompt_embeds"]), dim=0)
|
| 155 |
+
critic_timestep = torch.concatenate((critic_timestep, critic_timestep), dim=0)
|
| 156 |
+
noisy_latent = torch.concatenate((noisy_fake_latent, noisy_real_latent), dim=0)
|
| 157 |
+
_, _, noisy_logit = self.fake_score(
|
| 158 |
+
noisy_image_or_video=noisy_latent,
|
| 159 |
+
conditional_dict=conditional_dict,
|
| 160 |
+
timestep=critic_timestep,
|
| 161 |
+
classify_mode=True,
|
| 162 |
+
concat_time_embeddings=self.concat_time_embeddings
|
| 163 |
+
)
|
| 164 |
+
noisy_fake_logit, noisy_real_logit = noisy_logit.chunk(2, dim=0)
|
| 165 |
+
|
| 166 |
+
if not self.relativistic_discriminator:
|
| 167 |
+
gan_G_loss = F.softplus(-noisy_fake_logit.float()).mean() * self.gan_g_weight
|
| 168 |
+
else:
|
| 169 |
+
relative_fake_logit = noisy_fake_logit - noisy_real_logit
|
| 170 |
+
gan_G_loss = F.softplus(-relative_fake_logit.float()).mean() * self.gan_g_weight
|
| 171 |
+
|
| 172 |
+
return gan_G_loss
|
| 173 |
+
|
| 174 |
+
def critic_loss(
|
| 175 |
+
self,
|
| 176 |
+
image_or_video_shape,
|
| 177 |
+
conditional_dict: dict,
|
| 178 |
+
unconditional_dict: dict,
|
| 179 |
+
clean_latent: torch.Tensor,
|
| 180 |
+
real_image_or_video: torch.Tensor,
|
| 181 |
+
initial_latent: torch.Tensor = None
|
| 182 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 183 |
+
"""
|
| 184 |
+
Generate image/videos from noise and train the critic with generated samples.
|
| 185 |
+
The noisy input to the generator is backward simulated.
|
| 186 |
+
This removes the need of any datasets during distillation.
|
| 187 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
| 188 |
+
Input:
|
| 189 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
| 190 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 191 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 192 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
| 193 |
+
Output:
|
| 194 |
+
- loss: a scalar tensor representing the generator loss.
|
| 195 |
+
- critic_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
# Step 1: Run generator on backward simulated noisy input
|
| 199 |
+
with torch.no_grad():
|
| 200 |
+
generated_image, _, denoised_timestep_from, denoised_timestep_to, num_sim_steps = self._run_generator(
|
| 201 |
+
image_or_video_shape=image_or_video_shape,
|
| 202 |
+
conditional_dict=conditional_dict,
|
| 203 |
+
initial_latent=initial_latent
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Step 2: Get timestep and add noise to generated/real latents
|
| 207 |
+
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
|
| 208 |
+
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
|
| 209 |
+
critic_timestep = self._get_timestep(
|
| 210 |
+
min_timestep,
|
| 211 |
+
max_timestep,
|
| 212 |
+
image_or_video_shape[0],
|
| 213 |
+
image_or_video_shape[1],
|
| 214 |
+
self.num_frame_per_block,
|
| 215 |
+
uniform_timestep=True
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
if self.critic_timestep_shift > 1:
|
| 219 |
+
critic_timestep = self.critic_timestep_shift * \
|
| 220 |
+
(critic_timestep / 1000) / (1 + (self.critic_timestep_shift - 1) * (critic_timestep / 1000)) * 1000
|
| 221 |
+
|
| 222 |
+
critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
|
| 223 |
+
|
| 224 |
+
critic_noise = torch.randn_like(generated_image)
|
| 225 |
+
noisy_fake_latent = self.scheduler.add_noise(
|
| 226 |
+
generated_image.flatten(0, 1),
|
| 227 |
+
critic_noise.flatten(0, 1),
|
| 228 |
+
critic_timestep.flatten(0, 1)
|
| 229 |
+
).unflatten(0, image_or_video_shape[:2])
|
| 230 |
+
|
| 231 |
+
# Step 4: Compute the real GAN discriminator loss
|
| 232 |
+
noisy_real_latent = self.scheduler.add_noise(
|
| 233 |
+
real_image_or_video.flatten(0, 1),
|
| 234 |
+
critic_noise.flatten(0, 1),
|
| 235 |
+
critic_timestep.flatten(0, 1)
|
| 236 |
+
).unflatten(0, image_or_video_shape[:2])
|
| 237 |
+
|
| 238 |
+
conditional_dict_cloned = copy.deepcopy(conditional_dict)
|
| 239 |
+
conditional_dict_cloned["prompt_embeds"] = torch.concatenate(
|
| 240 |
+
(conditional_dict_cloned["prompt_embeds"], conditional_dict_cloned["prompt_embeds"]), dim=0)
|
| 241 |
+
_, _, noisy_logit = self.fake_score(
|
| 242 |
+
noisy_image_or_video=torch.concatenate((noisy_fake_latent, noisy_real_latent), dim=0),
|
| 243 |
+
conditional_dict=conditional_dict_cloned,
|
| 244 |
+
timestep=torch.concatenate((critic_timestep, critic_timestep), dim=0),
|
| 245 |
+
classify_mode=True,
|
| 246 |
+
concat_time_embeddings=self.concat_time_embeddings
|
| 247 |
+
)
|
| 248 |
+
noisy_fake_logit, noisy_real_logit = noisy_logit.chunk(2, dim=0)
|
| 249 |
+
|
| 250 |
+
if not self.relativistic_discriminator:
|
| 251 |
+
gan_D_loss = F.softplus(-noisy_real_logit.float()).mean() + F.softplus(noisy_fake_logit.float()).mean()
|
| 252 |
+
else:
|
| 253 |
+
relative_real_logit = noisy_real_logit - noisy_fake_logit
|
| 254 |
+
gan_D_loss = F.softplus(-relative_real_logit.float()).mean()
|
| 255 |
+
gan_D_loss = gan_D_loss * self.gan_d_weight
|
| 256 |
+
|
| 257 |
+
# R1 regularization
|
| 258 |
+
if self.r1_weight > 0.:
|
| 259 |
+
noisy_real_latent_perturbed = noisy_real_latent.clone()
|
| 260 |
+
epison_real = self.r1_sigma * torch.randn_like(noisy_real_latent_perturbed)
|
| 261 |
+
noisy_real_latent_perturbed = noisy_real_latent_perturbed + epison_real
|
| 262 |
+
noisy_real_logit_perturbed = self._run_cls_pred_branch(
|
| 263 |
+
noisy_image_or_video=noisy_real_latent_perturbed,
|
| 264 |
+
conditional_dict=conditional_dict,
|
| 265 |
+
timestep=critic_timestep
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
r1_grad = (noisy_real_logit_perturbed - noisy_real_logit) / self.r1_sigma
|
| 269 |
+
r1_loss = self.r1_weight * torch.mean((r1_grad)**2)
|
| 270 |
+
else:
|
| 271 |
+
r1_loss = torch.zeros_like(gan_D_loss)
|
| 272 |
+
|
| 273 |
+
# R2 regularization
|
| 274 |
+
if self.r2_weight > 0.:
|
| 275 |
+
noisy_fake_latent_perturbed = noisy_fake_latent.clone()
|
| 276 |
+
epison_generated = self.r2_sigma * torch.randn_like(noisy_fake_latent_perturbed)
|
| 277 |
+
noisy_fake_latent_perturbed = noisy_fake_latent_perturbed + epison_generated
|
| 278 |
+
noisy_fake_logit_perturbed = self._run_cls_pred_branch(
|
| 279 |
+
noisy_image_or_video=noisy_fake_latent_perturbed,
|
| 280 |
+
conditional_dict=conditional_dict,
|
| 281 |
+
timestep=critic_timestep
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
r2_grad = (noisy_fake_logit_perturbed - noisy_fake_logit) / self.r2_sigma
|
| 285 |
+
r2_loss = self.r2_weight * torch.mean((r2_grad)**2)
|
| 286 |
+
else:
|
| 287 |
+
r2_loss = torch.zeros_like(r2_loss)
|
| 288 |
+
|
| 289 |
+
critic_log_dict = {
|
| 290 |
+
"critic_timestep": critic_timestep.detach(),
|
| 291 |
+
'noisy_real_logit': noisy_real_logit.detach(),
|
| 292 |
+
'noisy_fake_logit': noisy_fake_logit.detach(),
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
return (gan_D_loss, r1_loss, r2_loss), critic_log_dict
|
model/ode_regression.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn.functional as F
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from model.base import BaseModel
|
| 6 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ODERegression(BaseModel):
|
| 10 |
+
def __init__(self, args, device):
|
| 11 |
+
"""
|
| 12 |
+
Initialize the ODERegression module.
|
| 13 |
+
This class is self-contained and compute generator losses
|
| 14 |
+
in the forward pass given precomputed ode solution pairs.
|
| 15 |
+
This class supports the ode regression loss for both causal and bidirectional models.
|
| 16 |
+
See Sec 4.3 of CausVid https://arxiv.org/abs/2412.07772 for details
|
| 17 |
+
"""
|
| 18 |
+
super().__init__(args, device)
|
| 19 |
+
|
| 20 |
+
# Step 1: Initialize all models
|
| 21 |
+
|
| 22 |
+
self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
|
| 23 |
+
self.generator.model.requires_grad_(True)
|
| 24 |
+
if getattr(args, "generator_ckpt", False):
|
| 25 |
+
print(f"Loading pretrained generator from {args.generator_ckpt}")
|
| 26 |
+
state_dict = torch.load(args.generator_ckpt, map_location="cpu")[
|
| 27 |
+
'generator']
|
| 28 |
+
self.generator.load_state_dict(
|
| 29 |
+
state_dict, strict=True
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
| 33 |
+
|
| 34 |
+
if self.num_frame_per_block > 1:
|
| 35 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
| 36 |
+
|
| 37 |
+
self.independent_first_frame = getattr(args, "independent_first_frame", False)
|
| 38 |
+
if self.independent_first_frame:
|
| 39 |
+
self.generator.model.independent_first_frame = True
|
| 40 |
+
if args.gradient_checkpointing:
|
| 41 |
+
self.generator.enable_gradient_checkpointing()
|
| 42 |
+
|
| 43 |
+
# Step 2: Initialize all hyperparameters
|
| 44 |
+
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
|
| 45 |
+
|
| 46 |
+
def _initialize_models(self, args):
|
| 47 |
+
self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
|
| 48 |
+
self.generator.model.requires_grad_(True)
|
| 49 |
+
|
| 50 |
+
self.text_encoder = WanTextEncoder()
|
| 51 |
+
self.text_encoder.requires_grad_(False)
|
| 52 |
+
|
| 53 |
+
self.vae = WanVAEWrapper()
|
| 54 |
+
self.vae.requires_grad_(False)
|
| 55 |
+
|
| 56 |
+
@torch.no_grad()
|
| 57 |
+
def _prepare_generator_input(self, ode_latent: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 58 |
+
"""
|
| 59 |
+
Given a tensor containing the whole ODE sampling trajectories,
|
| 60 |
+
randomly choose an intermediate timestep and return the latent as well as the corresponding timestep.
|
| 61 |
+
Input:
|
| 62 |
+
- ode_latent: a tensor containing the whole ODE sampling trajectories [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
|
| 63 |
+
Output:
|
| 64 |
+
- noisy_input: a tensor containing the selected latent [batch_size, num_frames, num_channels, height, width].
|
| 65 |
+
- timestep: a tensor containing the corresponding timestep [batch_size].
|
| 66 |
+
"""
|
| 67 |
+
batch_size, num_denoising_steps, num_frames, num_channels, height, width = ode_latent.shape
|
| 68 |
+
|
| 69 |
+
# Step 1: Randomly choose a timestep for each frame
|
| 70 |
+
index = self._get_timestep(
|
| 71 |
+
0,
|
| 72 |
+
len(self.denoising_step_list),
|
| 73 |
+
batch_size,
|
| 74 |
+
num_frames,
|
| 75 |
+
self.num_frame_per_block,
|
| 76 |
+
uniform_timestep=False
|
| 77 |
+
)
|
| 78 |
+
if self.args.i2v:
|
| 79 |
+
index[:, 0] = len(self.denoising_step_list) - 1
|
| 80 |
+
|
| 81 |
+
noisy_input = torch.gather(
|
| 82 |
+
ode_latent, dim=1,
|
| 83 |
+
index=index.reshape(batch_size, 1, num_frames, 1, 1, 1).expand(
|
| 84 |
+
-1, -1, -1, num_channels, height, width).to(self.device)
|
| 85 |
+
).squeeze(1)
|
| 86 |
+
|
| 87 |
+
timestep = self.denoising_step_list[index].to(self.device)
|
| 88 |
+
|
| 89 |
+
# if self.extra_noise_step > 0:
|
| 90 |
+
# random_timestep = torch.randint(0, self.extra_noise_step, [
|
| 91 |
+
# batch_size, num_frames], device=self.device, dtype=torch.long)
|
| 92 |
+
# perturbed_noisy_input = self.scheduler.add_noise(
|
| 93 |
+
# noisy_input.flatten(0, 1),
|
| 94 |
+
# torch.randn_like(noisy_input.flatten(0, 1)),
|
| 95 |
+
# random_timestep.flatten(0, 1)
|
| 96 |
+
# ).detach().unflatten(0, (batch_size, num_frames)).type_as(noisy_input)
|
| 97 |
+
|
| 98 |
+
# noisy_input[timestep == 0] = perturbed_noisy_input[timestep == 0]
|
| 99 |
+
|
| 100 |
+
return noisy_input, timestep
|
| 101 |
+
|
| 102 |
+
def generator_loss(self, ode_latent: torch.Tensor, conditional_dict: dict) -> Tuple[torch.Tensor, dict]:
|
| 103 |
+
"""
|
| 104 |
+
Generate image/videos from noisy latents and compute the ODE regression loss.
|
| 105 |
+
Input:
|
| 106 |
+
- ode_latent: a tensor containing the ODE latents [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
|
| 107 |
+
They are ordered from most noisy to clean latents.
|
| 108 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 109 |
+
Output:
|
| 110 |
+
- loss: a scalar tensor representing the generator loss.
|
| 111 |
+
- log_dict: a dictionary containing additional information for loss timestep breakdown.
|
| 112 |
+
"""
|
| 113 |
+
# Step 1: Run generator on noisy latents
|
| 114 |
+
target_latent = ode_latent[:, -1]
|
| 115 |
+
|
| 116 |
+
noisy_input, timestep = self._prepare_generator_input(
|
| 117 |
+
ode_latent=ode_latent)
|
| 118 |
+
|
| 119 |
+
_, pred_image_or_video = self.generator(
|
| 120 |
+
noisy_image_or_video=noisy_input,
|
| 121 |
+
conditional_dict=conditional_dict,
|
| 122 |
+
timestep=timestep
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Step 2: Compute the regression loss
|
| 126 |
+
mask = timestep != 0
|
| 127 |
+
|
| 128 |
+
loss = F.mse_loss(
|
| 129 |
+
pred_image_or_video[mask], target_latent[mask], reduction="mean")
|
| 130 |
+
|
| 131 |
+
log_dict = {
|
| 132 |
+
"unnormalized_loss": F.mse_loss(pred_image_or_video, target_latent, reduction='none').mean(dim=[1, 2, 3, 4]).detach(),
|
| 133 |
+
"timestep": timestep.float().mean(dim=1).detach(),
|
| 134 |
+
"input": noisy_input.detach(),
|
| 135 |
+
"output": pred_image_or_video.detach(),
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
return loss, log_dict
|
model/sid.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pipeline import RollingForcingTrainingPipeline
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from model.base import RollingForcingModel
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SiD(RollingForcingModel):
|
| 9 |
+
def __init__(self, args, device):
|
| 10 |
+
"""
|
| 11 |
+
Initialize the DMD (Distribution Matching Distillation) module.
|
| 12 |
+
This class is self-contained and compute generator and fake score losses
|
| 13 |
+
in the forward pass.
|
| 14 |
+
"""
|
| 15 |
+
super().__init__(args, device)
|
| 16 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
| 17 |
+
|
| 18 |
+
if self.num_frame_per_block > 1:
|
| 19 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
| 20 |
+
|
| 21 |
+
if args.gradient_checkpointing:
|
| 22 |
+
self.generator.enable_gradient_checkpointing()
|
| 23 |
+
self.fake_score.enable_gradient_checkpointing()
|
| 24 |
+
self.real_score.enable_gradient_checkpointing()
|
| 25 |
+
|
| 26 |
+
# this will be init later with fsdp-wrapped modules
|
| 27 |
+
self.inference_pipeline: RollingForcingTrainingPipeline = None
|
| 28 |
+
|
| 29 |
+
# Step 2: Initialize all dmd hyperparameters
|
| 30 |
+
self.num_train_timestep = args.num_train_timestep
|
| 31 |
+
self.min_step = int(0.02 * self.num_train_timestep)
|
| 32 |
+
self.max_step = int(0.98 * self.num_train_timestep)
|
| 33 |
+
if hasattr(args, "real_guidance_scale"):
|
| 34 |
+
self.real_guidance_scale = args.real_guidance_scale
|
| 35 |
+
else:
|
| 36 |
+
self.real_guidance_scale = args.guidance_scale
|
| 37 |
+
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
|
| 38 |
+
self.sid_alpha = getattr(args, "sid_alpha", 1.0)
|
| 39 |
+
self.ts_schedule = getattr(args, "ts_schedule", True)
|
| 40 |
+
self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
|
| 41 |
+
|
| 42 |
+
if getattr(self.scheduler, "alphas_cumprod", None) is not None:
|
| 43 |
+
self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
|
| 44 |
+
else:
|
| 45 |
+
self.scheduler.alphas_cumprod = None
|
| 46 |
+
|
| 47 |
+
def compute_distribution_matching_loss(
|
| 48 |
+
self,
|
| 49 |
+
image_or_video: torch.Tensor,
|
| 50 |
+
conditional_dict: dict,
|
| 51 |
+
unconditional_dict: dict,
|
| 52 |
+
gradient_mask: Optional[torch.Tensor] = None,
|
| 53 |
+
denoised_timestep_from: int = 0,
|
| 54 |
+
denoised_timestep_to: int = 0
|
| 55 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 56 |
+
"""
|
| 57 |
+
Compute the DMD loss (eq 7 in https://arxiv.org/abs/2311.18828).
|
| 58 |
+
Input:
|
| 59 |
+
- image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
|
| 60 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 61 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 62 |
+
- gradient_mask: a boolean tensor with the same shape as image_or_video indicating which pixels to compute loss .
|
| 63 |
+
Output:
|
| 64 |
+
- dmd_loss: a scalar tensor representing the DMD loss.
|
| 65 |
+
- dmd_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 66 |
+
"""
|
| 67 |
+
original_latent = image_or_video
|
| 68 |
+
|
| 69 |
+
batch_size, num_frame = image_or_video.shape[:2]
|
| 70 |
+
|
| 71 |
+
# Step 1: Randomly sample timestep based on the given schedule and corresponding noise
|
| 72 |
+
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
|
| 73 |
+
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
|
| 74 |
+
timestep = self._get_timestep(
|
| 75 |
+
min_timestep,
|
| 76 |
+
max_timestep,
|
| 77 |
+
batch_size,
|
| 78 |
+
num_frame,
|
| 79 |
+
self.num_frame_per_block,
|
| 80 |
+
uniform_timestep=True
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
if self.timestep_shift > 1:
|
| 84 |
+
timestep = self.timestep_shift * \
|
| 85 |
+
(timestep / 1000) / \
|
| 86 |
+
(1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
|
| 87 |
+
timestep = timestep.clamp(self.min_step, self.max_step)
|
| 88 |
+
|
| 89 |
+
noise = torch.randn_like(image_or_video)
|
| 90 |
+
noisy_latent = self.scheduler.add_noise(
|
| 91 |
+
image_or_video.flatten(0, 1),
|
| 92 |
+
noise.flatten(0, 1),
|
| 93 |
+
timestep.flatten(0, 1)
|
| 94 |
+
).unflatten(0, (batch_size, num_frame))
|
| 95 |
+
|
| 96 |
+
# Step 2: SiD (May be wrap it?)
|
| 97 |
+
noisy_image_or_video = noisy_latent
|
| 98 |
+
# Step 2.1: Compute the fake score
|
| 99 |
+
_, pred_fake_image = self.fake_score(
|
| 100 |
+
noisy_image_or_video=noisy_image_or_video,
|
| 101 |
+
conditional_dict=conditional_dict,
|
| 102 |
+
timestep=timestep
|
| 103 |
+
)
|
| 104 |
+
# Step 2.2: Compute the real score
|
| 105 |
+
# We compute the conditional and unconditional prediction
|
| 106 |
+
# and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
|
| 107 |
+
# NOTE: This step may cause OOM issue, which can be addressed by the CFG-free technique
|
| 108 |
+
|
| 109 |
+
_, pred_real_image_cond = self.real_score(
|
| 110 |
+
noisy_image_or_video=noisy_image_or_video,
|
| 111 |
+
conditional_dict=conditional_dict,
|
| 112 |
+
timestep=timestep
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
_, pred_real_image_uncond = self.real_score(
|
| 116 |
+
noisy_image_or_video=noisy_image_or_video,
|
| 117 |
+
conditional_dict=unconditional_dict,
|
| 118 |
+
timestep=timestep
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
pred_real_image = pred_real_image_cond + (
|
| 122 |
+
pred_real_image_cond - pred_real_image_uncond
|
| 123 |
+
) * self.real_guidance_scale
|
| 124 |
+
|
| 125 |
+
# Step 2.3: SiD Loss
|
| 126 |
+
# TODO: Add alpha
|
| 127 |
+
# TODO: Double?
|
| 128 |
+
sid_loss = (pred_real_image.double() - pred_fake_image.double()) * ((pred_real_image.double() - original_latent.double()) - self.sid_alpha * (pred_real_image.double() - pred_fake_image.double()))
|
| 129 |
+
|
| 130 |
+
# Step 2.4: Loss normalizer
|
| 131 |
+
with torch.no_grad():
|
| 132 |
+
p_real = (original_latent - pred_real_image)
|
| 133 |
+
normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
|
| 134 |
+
sid_loss = sid_loss / normalizer
|
| 135 |
+
|
| 136 |
+
sid_loss = torch.nan_to_num(sid_loss)
|
| 137 |
+
num_frame = sid_loss.shape[1]
|
| 138 |
+
sid_loss = sid_loss.mean()
|
| 139 |
+
|
| 140 |
+
sid_log_dict = {
|
| 141 |
+
"dmdtrain_gradient_norm": torch.zeros_like(sid_loss),
|
| 142 |
+
"timestep": timestep.detach()
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
return sid_loss, sid_log_dict
|
| 146 |
+
|
| 147 |
+
def generator_loss(
|
| 148 |
+
self,
|
| 149 |
+
image_or_video_shape,
|
| 150 |
+
conditional_dict: dict,
|
| 151 |
+
unconditional_dict: dict,
|
| 152 |
+
clean_latent: torch.Tensor,
|
| 153 |
+
initial_latent: torch.Tensor = None
|
| 154 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 155 |
+
"""
|
| 156 |
+
Generate image/videos from noise and compute the DMD loss.
|
| 157 |
+
The noisy input to the generator is backward simulated.
|
| 158 |
+
This removes the need of any datasets during distillation.
|
| 159 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
| 160 |
+
Input:
|
| 161 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
| 162 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 163 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 164 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
| 165 |
+
Output:
|
| 166 |
+
- loss: a scalar tensor representing the generator loss.
|
| 167 |
+
- generator_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 168 |
+
"""
|
| 169 |
+
# Step 1: Unroll generator to obtain fake videos
|
| 170 |
+
pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
|
| 171 |
+
image_or_video_shape=image_or_video_shape,
|
| 172 |
+
conditional_dict=conditional_dict,
|
| 173 |
+
initial_latent=initial_latent
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# Step 2: Compute the DMD loss
|
| 177 |
+
dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
|
| 178 |
+
image_or_video=pred_image,
|
| 179 |
+
conditional_dict=conditional_dict,
|
| 180 |
+
unconditional_dict=unconditional_dict,
|
| 181 |
+
gradient_mask=gradient_mask,
|
| 182 |
+
denoised_timestep_from=denoised_timestep_from,
|
| 183 |
+
denoised_timestep_to=denoised_timestep_to
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
return dmd_loss, dmd_log_dict
|
| 187 |
+
|
| 188 |
+
def critic_loss(
|
| 189 |
+
self,
|
| 190 |
+
image_or_video_shape,
|
| 191 |
+
conditional_dict: dict,
|
| 192 |
+
unconditional_dict: dict,
|
| 193 |
+
clean_latent: torch.Tensor,
|
| 194 |
+
initial_latent: torch.Tensor = None
|
| 195 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 196 |
+
"""
|
| 197 |
+
Generate image/videos from noise and train the critic with generated samples.
|
| 198 |
+
The noisy input to the generator is backward simulated.
|
| 199 |
+
This removes the need of any datasets during distillation.
|
| 200 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
| 201 |
+
Input:
|
| 202 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
| 203 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
| 204 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
| 205 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
| 206 |
+
Output:
|
| 207 |
+
- loss: a scalar tensor representing the generator loss.
|
| 208 |
+
- critic_log_dict: a dictionary containing the intermediate tensors for logging.
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
# Step 1: Run generator on backward simulated noisy input
|
| 212 |
+
with torch.no_grad():
|
| 213 |
+
generated_image, _, denoised_timestep_from, denoised_timestep_to = self._run_generator(
|
| 214 |
+
image_or_video_shape=image_or_video_shape,
|
| 215 |
+
conditional_dict=conditional_dict,
|
| 216 |
+
initial_latent=initial_latent
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# Step 2: Compute the fake prediction
|
| 220 |
+
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
|
| 221 |
+
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
|
| 222 |
+
critic_timestep = self._get_timestep(
|
| 223 |
+
min_timestep,
|
| 224 |
+
max_timestep,
|
| 225 |
+
image_or_video_shape[0],
|
| 226 |
+
image_or_video_shape[1],
|
| 227 |
+
self.num_frame_per_block,
|
| 228 |
+
uniform_timestep=True
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
if self.timestep_shift > 1:
|
| 232 |
+
critic_timestep = self.timestep_shift * \
|
| 233 |
+
(critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
|
| 234 |
+
|
| 235 |
+
critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
|
| 236 |
+
|
| 237 |
+
critic_noise = torch.randn_like(generated_image)
|
| 238 |
+
noisy_generated_image = self.scheduler.add_noise(
|
| 239 |
+
generated_image.flatten(0, 1),
|
| 240 |
+
critic_noise.flatten(0, 1),
|
| 241 |
+
critic_timestep.flatten(0, 1)
|
| 242 |
+
).unflatten(0, image_or_video_shape[:2])
|
| 243 |
+
|
| 244 |
+
_, pred_fake_image = self.fake_score(
|
| 245 |
+
noisy_image_or_video=noisy_generated_image,
|
| 246 |
+
conditional_dict=conditional_dict,
|
| 247 |
+
timestep=critic_timestep
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# Step 3: Compute the denoising loss for the fake critic
|
| 251 |
+
if self.args.denoising_loss_type == "flow":
|
| 252 |
+
from utils.wan_wrapper import WanDiffusionWrapper
|
| 253 |
+
flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
|
| 254 |
+
scheduler=self.scheduler,
|
| 255 |
+
x0_pred=pred_fake_image.flatten(0, 1),
|
| 256 |
+
xt=noisy_generated_image.flatten(0, 1),
|
| 257 |
+
timestep=critic_timestep.flatten(0, 1)
|
| 258 |
+
)
|
| 259 |
+
pred_fake_noise = None
|
| 260 |
+
else:
|
| 261 |
+
flow_pred = None
|
| 262 |
+
pred_fake_noise = self.scheduler.convert_x0_to_noise(
|
| 263 |
+
x0=pred_fake_image.flatten(0, 1),
|
| 264 |
+
xt=noisy_generated_image.flatten(0, 1),
|
| 265 |
+
timestep=critic_timestep.flatten(0, 1)
|
| 266 |
+
).unflatten(0, image_or_video_shape[:2])
|
| 267 |
+
|
| 268 |
+
denoising_loss = self.denoising_loss_func(
|
| 269 |
+
x=generated_image.flatten(0, 1),
|
| 270 |
+
x_pred=pred_fake_image.flatten(0, 1),
|
| 271 |
+
noise=critic_noise.flatten(0, 1),
|
| 272 |
+
noise_pred=pred_fake_noise,
|
| 273 |
+
alphas_cumprod=self.scheduler.alphas_cumprod,
|
| 274 |
+
timestep=critic_timestep.flatten(0, 1),
|
| 275 |
+
flow_pred=flow_pred
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
# Step 5: Debugging Log
|
| 279 |
+
critic_log_dict = {
|
| 280 |
+
"critic_timestep": critic_timestep.detach()
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
return denoising_loss, critic_log_dict
|
pipeline/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .bidirectional_diffusion_inference import BidirectionalDiffusionInferencePipeline
|
| 2 |
+
from .bidirectional_inference import BidirectionalInferencePipeline
|
| 3 |
+
from .causal_diffusion_inference import CausalDiffusionInferencePipeline
|
| 4 |
+
from .rolling_forcing_inference import CausalInferencePipeline
|
| 5 |
+
from .rolling_forcing_training import RollingForcingTrainingPipeline
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"BidirectionalDiffusionInferencePipeline",
|
| 9 |
+
"BidirectionalInferencePipeline",
|
| 10 |
+
"CausalDiffusionInferencePipeline",
|
| 11 |
+
"CausalInferencePipeline",
|
| 12 |
+
"RollingForcingTrainingPipeline"
|
| 13 |
+
]
|
pipeline/bidirectional_diffusion_inference.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tqdm import tqdm
|
| 2 |
+
from typing import List
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
|
| 6 |
+
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 7 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BidirectionalDiffusionInferencePipeline(torch.nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
args,
|
| 14 |
+
device,
|
| 15 |
+
generator=None,
|
| 16 |
+
text_encoder=None,
|
| 17 |
+
vae=None
|
| 18 |
+
):
|
| 19 |
+
super().__init__()
|
| 20 |
+
# Step 1: Initialize all models
|
| 21 |
+
self.generator = WanDiffusionWrapper(
|
| 22 |
+
**getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
|
| 23 |
+
self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
|
| 24 |
+
self.vae = WanVAEWrapper() if vae is None else vae
|
| 25 |
+
|
| 26 |
+
# Step 2: Initialize scheduler
|
| 27 |
+
self.num_train_timesteps = args.num_train_timestep
|
| 28 |
+
self.sampling_steps = 50
|
| 29 |
+
self.sample_solver = 'unipc'
|
| 30 |
+
self.shift = 8.0
|
| 31 |
+
|
| 32 |
+
self.args = args
|
| 33 |
+
|
| 34 |
+
def inference(
|
| 35 |
+
self,
|
| 36 |
+
noise: torch.Tensor,
|
| 37 |
+
text_prompts: List[str],
|
| 38 |
+
return_latents=False
|
| 39 |
+
) -> torch.Tensor:
|
| 40 |
+
"""
|
| 41 |
+
Perform inference on the given noise and text prompts.
|
| 42 |
+
Inputs:
|
| 43 |
+
noise (torch.Tensor): The input noise tensor of shape
|
| 44 |
+
(batch_size, num_frames, num_channels, height, width).
|
| 45 |
+
text_prompts (List[str]): The list of text prompts.
|
| 46 |
+
Outputs:
|
| 47 |
+
video (torch.Tensor): The generated video tensor of shape
|
| 48 |
+
(batch_size, num_frames, num_channels, height, width). It is normalized to be in the range [0, 1].
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
conditional_dict = self.text_encoder(
|
| 52 |
+
text_prompts=text_prompts
|
| 53 |
+
)
|
| 54 |
+
unconditional_dict = self.text_encoder(
|
| 55 |
+
text_prompts=[self.args.negative_prompt] * len(text_prompts)
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
latents = noise
|
| 59 |
+
|
| 60 |
+
sample_scheduler = self._initialize_sample_scheduler(noise)
|
| 61 |
+
for _, t in enumerate(tqdm(sample_scheduler.timesteps)):
|
| 62 |
+
latent_model_input = latents
|
| 63 |
+
timestep = t * torch.ones([latents.shape[0], 21], device=noise.device, dtype=torch.float32)
|
| 64 |
+
|
| 65 |
+
flow_pred_cond, _ = self.generator(latent_model_input, conditional_dict, timestep)
|
| 66 |
+
flow_pred_uncond, _ = self.generator(latent_model_input, unconditional_dict, timestep)
|
| 67 |
+
|
| 68 |
+
flow_pred = flow_pred_uncond + self.args.guidance_scale * (
|
| 69 |
+
flow_pred_cond - flow_pred_uncond)
|
| 70 |
+
|
| 71 |
+
temp_x0 = sample_scheduler.step(
|
| 72 |
+
flow_pred.unsqueeze(0),
|
| 73 |
+
t,
|
| 74 |
+
latents.unsqueeze(0),
|
| 75 |
+
return_dict=False)[0]
|
| 76 |
+
latents = temp_x0.squeeze(0)
|
| 77 |
+
|
| 78 |
+
x0 = latents
|
| 79 |
+
video = self.vae.decode_to_pixel(x0)
|
| 80 |
+
video = (video * 0.5 + 0.5).clamp(0, 1)
|
| 81 |
+
|
| 82 |
+
del sample_scheduler
|
| 83 |
+
|
| 84 |
+
if return_latents:
|
| 85 |
+
return video, latents
|
| 86 |
+
else:
|
| 87 |
+
return video
|
| 88 |
+
|
| 89 |
+
def _initialize_sample_scheduler(self, noise):
|
| 90 |
+
if self.sample_solver == 'unipc':
|
| 91 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
| 92 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 93 |
+
shift=1,
|
| 94 |
+
use_dynamic_shifting=False)
|
| 95 |
+
sample_scheduler.set_timesteps(
|
| 96 |
+
self.sampling_steps, device=noise.device, shift=self.shift)
|
| 97 |
+
self.timesteps = sample_scheduler.timesteps
|
| 98 |
+
elif self.sample_solver == 'dpm++':
|
| 99 |
+
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
| 100 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 101 |
+
shift=1,
|
| 102 |
+
use_dynamic_shifting=False)
|
| 103 |
+
sampling_sigmas = get_sampling_sigmas(self.sampling_steps, self.shift)
|
| 104 |
+
self.timesteps, _ = retrieve_timesteps(
|
| 105 |
+
sample_scheduler,
|
| 106 |
+
device=noise.device,
|
| 107 |
+
sigmas=sampling_sigmas)
|
| 108 |
+
else:
|
| 109 |
+
raise NotImplementedError("Unsupported solver.")
|
| 110 |
+
return sample_scheduler
|
pipeline/bidirectional_inference.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class BidirectionalInferencePipeline(torch.nn.Module):
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
args,
|
| 11 |
+
device,
|
| 12 |
+
generator=None,
|
| 13 |
+
text_encoder=None,
|
| 14 |
+
vae=None
|
| 15 |
+
):
|
| 16 |
+
super().__init__()
|
| 17 |
+
# Step 1: Initialize all models
|
| 18 |
+
self.generator = WanDiffusionWrapper(
|
| 19 |
+
**getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
|
| 20 |
+
self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
|
| 21 |
+
self.vae = WanVAEWrapper() if vae is None else vae
|
| 22 |
+
|
| 23 |
+
# Step 2: Initialize all bidirectional wan hyperparmeters
|
| 24 |
+
self.scheduler = self.generator.get_scheduler()
|
| 25 |
+
self.denoising_step_list = torch.tensor(
|
| 26 |
+
args.denoising_step_list, dtype=torch.long, device=device)
|
| 27 |
+
if self.denoising_step_list[-1] == 0:
|
| 28 |
+
self.denoising_step_list = self.denoising_step_list[:-1] # remove the zero timestep for inference
|
| 29 |
+
if args.warp_denoising_step:
|
| 30 |
+
timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
|
| 31 |
+
self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
|
| 32 |
+
|
| 33 |
+
def inference(self, noise: torch.Tensor, text_prompts: List[str]) -> torch.Tensor:
|
| 34 |
+
"""
|
| 35 |
+
Perform inference on the given noise and text prompts.
|
| 36 |
+
Inputs:
|
| 37 |
+
noise (torch.Tensor): The input noise tensor of shape
|
| 38 |
+
(batch_size, num_frames, num_channels, height, width).
|
| 39 |
+
text_prompts (List[str]): The list of text prompts.
|
| 40 |
+
Outputs:
|
| 41 |
+
video (torch.Tensor): The generated video tensor of shape
|
| 42 |
+
(batch_size, num_frames, num_channels, height, width). It is normalized to be in the range [0, 1].
|
| 43 |
+
"""
|
| 44 |
+
conditional_dict = self.text_encoder(
|
| 45 |
+
text_prompts=text_prompts
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# initial point
|
| 49 |
+
noisy_image_or_video = noise
|
| 50 |
+
|
| 51 |
+
# use the last n-1 timesteps to simulate the generator's input
|
| 52 |
+
for index, current_timestep in enumerate(self.denoising_step_list[:-1]):
|
| 53 |
+
_, pred_image_or_video = self.generator(
|
| 54 |
+
noisy_image_or_video=noisy_image_or_video,
|
| 55 |
+
conditional_dict=conditional_dict,
|
| 56 |
+
timestep=torch.ones(
|
| 57 |
+
noise.shape[:2], dtype=torch.long, device=noise.device) * current_timestep
|
| 58 |
+
) # [B, F, C, H, W]
|
| 59 |
+
|
| 60 |
+
next_timestep = self.denoising_step_list[index + 1] * torch.ones(
|
| 61 |
+
noise.shape[:2], dtype=torch.long, device=noise.device)
|
| 62 |
+
|
| 63 |
+
noisy_image_or_video = self.scheduler.add_noise(
|
| 64 |
+
pred_image_or_video.flatten(0, 1),
|
| 65 |
+
torch.randn_like(pred_image_or_video.flatten(0, 1)),
|
| 66 |
+
next_timestep.flatten(0, 1)
|
| 67 |
+
).unflatten(0, noise.shape[:2])
|
| 68 |
+
|
| 69 |
+
video = self.vae.decode_to_pixel(pred_image_or_video)
|
| 70 |
+
video = (video * 0.5 + 0.5).clamp(0, 1)
|
| 71 |
+
return video
|
pipeline/causal_diffusion_inference.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tqdm import tqdm
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
|
| 6 |
+
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 7 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CausalDiffusionInferencePipeline(torch.nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
args,
|
| 14 |
+
device,
|
| 15 |
+
generator=None,
|
| 16 |
+
text_encoder=None,
|
| 17 |
+
vae=None
|
| 18 |
+
):
|
| 19 |
+
super().__init__()
|
| 20 |
+
# Step 1: Initialize all models
|
| 21 |
+
self.generator = WanDiffusionWrapper(
|
| 22 |
+
**getattr(args, "model_kwargs", {}), is_causal=True) if generator is None else generator
|
| 23 |
+
self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
|
| 24 |
+
self.vae = WanVAEWrapper() if vae is None else vae
|
| 25 |
+
|
| 26 |
+
# Step 2: Initialize scheduler
|
| 27 |
+
self.num_train_timesteps = args.num_train_timestep
|
| 28 |
+
self.sampling_steps = 50
|
| 29 |
+
self.sample_solver = 'unipc'
|
| 30 |
+
self.shift = args.timestep_shift
|
| 31 |
+
|
| 32 |
+
self.num_transformer_blocks = 30
|
| 33 |
+
self.frame_seq_length = 1560
|
| 34 |
+
|
| 35 |
+
self.kv_cache_pos = None
|
| 36 |
+
self.kv_cache_neg = None
|
| 37 |
+
self.crossattn_cache_pos = None
|
| 38 |
+
self.crossattn_cache_neg = None
|
| 39 |
+
self.args = args
|
| 40 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
| 41 |
+
self.independent_first_frame = args.independent_first_frame
|
| 42 |
+
self.local_attn_size = self.generator.model.local_attn_size
|
| 43 |
+
|
| 44 |
+
print(f"KV inference with {self.num_frame_per_block} frames per block")
|
| 45 |
+
|
| 46 |
+
if self.num_frame_per_block > 1:
|
| 47 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
| 48 |
+
|
| 49 |
+
def inference(
|
| 50 |
+
self,
|
| 51 |
+
noise: torch.Tensor,
|
| 52 |
+
text_prompts: List[str],
|
| 53 |
+
initial_latent: Optional[torch.Tensor] = None,
|
| 54 |
+
return_latents: bool = False,
|
| 55 |
+
start_frame_index: Optional[int] = 0
|
| 56 |
+
) -> torch.Tensor:
|
| 57 |
+
"""
|
| 58 |
+
Perform inference on the given noise and text prompts.
|
| 59 |
+
Inputs:
|
| 60 |
+
noise (torch.Tensor): The input noise tensor of shape
|
| 61 |
+
(batch_size, num_output_frames, num_channels, height, width).
|
| 62 |
+
text_prompts (List[str]): The list of text prompts.
|
| 63 |
+
initial_latent (torch.Tensor): The initial latent tensor of shape
|
| 64 |
+
(batch_size, num_input_frames, num_channels, height, width).
|
| 65 |
+
If num_input_frames is 1, perform image to video.
|
| 66 |
+
If num_input_frames is greater than 1, perform video extension.
|
| 67 |
+
return_latents (bool): Whether to return the latents.
|
| 68 |
+
start_frame_index (int): In long video generation, where does the current window start?
|
| 69 |
+
Outputs:
|
| 70 |
+
video (torch.Tensor): The generated video tensor of shape
|
| 71 |
+
(batch_size, num_frames, num_channels, height, width). It is normalized to be in the range [0, 1].
|
| 72 |
+
"""
|
| 73 |
+
batch_size, num_frames, num_channels, height, width = noise.shape
|
| 74 |
+
if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
|
| 75 |
+
# If the first frame is independent and the first frame is provided, then the number of frames in the
|
| 76 |
+
# noise should still be a multiple of num_frame_per_block
|
| 77 |
+
assert num_frames % self.num_frame_per_block == 0
|
| 78 |
+
num_blocks = num_frames // self.num_frame_per_block
|
| 79 |
+
elif self.independent_first_frame and initial_latent is None:
|
| 80 |
+
# Using a [1, 4, 4, 4, 4, 4] model to generate a video without image conditioning
|
| 81 |
+
assert (num_frames - 1) % self.num_frame_per_block == 0
|
| 82 |
+
num_blocks = (num_frames - 1) // self.num_frame_per_block
|
| 83 |
+
num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
|
| 84 |
+
num_output_frames = num_frames + num_input_frames # add the initial latent frames
|
| 85 |
+
conditional_dict = self.text_encoder(
|
| 86 |
+
text_prompts=text_prompts
|
| 87 |
+
)
|
| 88 |
+
unconditional_dict = self.text_encoder(
|
| 89 |
+
text_prompts=[self.args.negative_prompt] * len(text_prompts)
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
output = torch.zeros(
|
| 93 |
+
[batch_size, num_output_frames, num_channels, height, width],
|
| 94 |
+
device=noise.device,
|
| 95 |
+
dtype=noise.dtype
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# Step 1: Initialize KV cache to all zeros
|
| 99 |
+
if self.kv_cache_pos is None:
|
| 100 |
+
self._initialize_kv_cache(
|
| 101 |
+
batch_size=batch_size,
|
| 102 |
+
dtype=noise.dtype,
|
| 103 |
+
device=noise.device
|
| 104 |
+
)
|
| 105 |
+
self._initialize_crossattn_cache(
|
| 106 |
+
batch_size=batch_size,
|
| 107 |
+
dtype=noise.dtype,
|
| 108 |
+
device=noise.device
|
| 109 |
+
)
|
| 110 |
+
else:
|
| 111 |
+
# reset cross attn cache
|
| 112 |
+
for block_index in range(self.num_transformer_blocks):
|
| 113 |
+
self.crossattn_cache_pos[block_index]["is_init"] = False
|
| 114 |
+
self.crossattn_cache_neg[block_index]["is_init"] = False
|
| 115 |
+
# reset kv cache
|
| 116 |
+
for block_index in range(len(self.kv_cache_pos)):
|
| 117 |
+
self.kv_cache_pos[block_index]["global_end_index"] = torch.tensor(
|
| 118 |
+
[0], dtype=torch.long, device=noise.device)
|
| 119 |
+
self.kv_cache_pos[block_index]["local_end_index"] = torch.tensor(
|
| 120 |
+
[0], dtype=torch.long, device=noise.device)
|
| 121 |
+
self.kv_cache_neg[block_index]["global_end_index"] = torch.tensor(
|
| 122 |
+
[0], dtype=torch.long, device=noise.device)
|
| 123 |
+
self.kv_cache_neg[block_index]["local_end_index"] = torch.tensor(
|
| 124 |
+
[0], dtype=torch.long, device=noise.device)
|
| 125 |
+
|
| 126 |
+
# Step 2: Cache context feature
|
| 127 |
+
current_start_frame = start_frame_index
|
| 128 |
+
cache_start_frame = 0
|
| 129 |
+
if initial_latent is not None:
|
| 130 |
+
timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
|
| 131 |
+
if self.independent_first_frame:
|
| 132 |
+
# Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
|
| 133 |
+
assert (num_input_frames - 1) % self.num_frame_per_block == 0
|
| 134 |
+
num_input_blocks = (num_input_frames - 1) // self.num_frame_per_block
|
| 135 |
+
output[:, :1] = initial_latent[:, :1]
|
| 136 |
+
self.generator(
|
| 137 |
+
noisy_image_or_video=initial_latent[:, :1],
|
| 138 |
+
conditional_dict=conditional_dict,
|
| 139 |
+
timestep=timestep * 0,
|
| 140 |
+
kv_cache=self.kv_cache_pos,
|
| 141 |
+
crossattn_cache=self.crossattn_cache_pos,
|
| 142 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 143 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
| 144 |
+
)
|
| 145 |
+
self.generator(
|
| 146 |
+
noisy_image_or_video=initial_latent[:, :1],
|
| 147 |
+
conditional_dict=unconditional_dict,
|
| 148 |
+
timestep=timestep * 0,
|
| 149 |
+
kv_cache=self.kv_cache_neg,
|
| 150 |
+
crossattn_cache=self.crossattn_cache_neg,
|
| 151 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 152 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
| 153 |
+
)
|
| 154 |
+
current_start_frame += 1
|
| 155 |
+
cache_start_frame += 1
|
| 156 |
+
else:
|
| 157 |
+
# Assume num_input_frames is self.num_frame_per_block * num_input_blocks
|
| 158 |
+
assert num_input_frames % self.num_frame_per_block == 0
|
| 159 |
+
num_input_blocks = num_input_frames // self.num_frame_per_block
|
| 160 |
+
|
| 161 |
+
for block_index in range(num_input_blocks):
|
| 162 |
+
current_ref_latents = \
|
| 163 |
+
initial_latent[:, cache_start_frame:cache_start_frame + self.num_frame_per_block]
|
| 164 |
+
output[:, cache_start_frame:cache_start_frame + self.num_frame_per_block] = current_ref_latents
|
| 165 |
+
self.generator(
|
| 166 |
+
noisy_image_or_video=current_ref_latents,
|
| 167 |
+
conditional_dict=conditional_dict,
|
| 168 |
+
timestep=timestep * 0,
|
| 169 |
+
kv_cache=self.kv_cache_pos,
|
| 170 |
+
crossattn_cache=self.crossattn_cache_pos,
|
| 171 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 172 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
| 173 |
+
)
|
| 174 |
+
self.generator(
|
| 175 |
+
noisy_image_or_video=current_ref_latents,
|
| 176 |
+
conditional_dict=unconditional_dict,
|
| 177 |
+
timestep=timestep * 0,
|
| 178 |
+
kv_cache=self.kv_cache_neg,
|
| 179 |
+
crossattn_cache=self.crossattn_cache_neg,
|
| 180 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 181 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
| 182 |
+
)
|
| 183 |
+
current_start_frame += self.num_frame_per_block
|
| 184 |
+
cache_start_frame += self.num_frame_per_block
|
| 185 |
+
|
| 186 |
+
# Step 3: Temporal denoising loop
|
| 187 |
+
all_num_frames = [self.num_frame_per_block] * num_blocks
|
| 188 |
+
if self.independent_first_frame and initial_latent is None:
|
| 189 |
+
all_num_frames = [1] + all_num_frames
|
| 190 |
+
for current_num_frames in all_num_frames:
|
| 191 |
+
noisy_input = noise[
|
| 192 |
+
:, cache_start_frame - num_input_frames:cache_start_frame + current_num_frames - num_input_frames]
|
| 193 |
+
latents = noisy_input
|
| 194 |
+
|
| 195 |
+
# Step 3.1: Spatial denoising loop
|
| 196 |
+
sample_scheduler = self._initialize_sample_scheduler(noise)
|
| 197 |
+
for _, t in enumerate(tqdm(sample_scheduler.timesteps)):
|
| 198 |
+
latent_model_input = latents
|
| 199 |
+
timestep = t * torch.ones(
|
| 200 |
+
[batch_size, current_num_frames], device=noise.device, dtype=torch.float32
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
flow_pred_cond, _ = self.generator(
|
| 204 |
+
noisy_image_or_video=latent_model_input,
|
| 205 |
+
conditional_dict=conditional_dict,
|
| 206 |
+
timestep=timestep,
|
| 207 |
+
kv_cache=self.kv_cache_pos,
|
| 208 |
+
crossattn_cache=self.crossattn_cache_pos,
|
| 209 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 210 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
| 211 |
+
)
|
| 212 |
+
flow_pred_uncond, _ = self.generator(
|
| 213 |
+
noisy_image_or_video=latent_model_input,
|
| 214 |
+
conditional_dict=unconditional_dict,
|
| 215 |
+
timestep=timestep,
|
| 216 |
+
kv_cache=self.kv_cache_neg,
|
| 217 |
+
crossattn_cache=self.crossattn_cache_neg,
|
| 218 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 219 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
flow_pred = flow_pred_uncond + self.args.guidance_scale * (
|
| 223 |
+
flow_pred_cond - flow_pred_uncond)
|
| 224 |
+
|
| 225 |
+
temp_x0 = sample_scheduler.step(
|
| 226 |
+
flow_pred,
|
| 227 |
+
t,
|
| 228 |
+
latents,
|
| 229 |
+
return_dict=False)[0]
|
| 230 |
+
latents = temp_x0
|
| 231 |
+
print(f"kv_cache['local_end_index']: {self.kv_cache_pos[0]['local_end_index']}")
|
| 232 |
+
print(f"kv_cache['global_end_index']: {self.kv_cache_pos[0]['global_end_index']}")
|
| 233 |
+
|
| 234 |
+
# Step 3.2: record the model's output
|
| 235 |
+
output[:, cache_start_frame:cache_start_frame + current_num_frames] = latents
|
| 236 |
+
|
| 237 |
+
# Step 3.3: rerun with timestep zero to update KV cache using clean context
|
| 238 |
+
self.generator(
|
| 239 |
+
noisy_image_or_video=latents,
|
| 240 |
+
conditional_dict=conditional_dict,
|
| 241 |
+
timestep=timestep * 0,
|
| 242 |
+
kv_cache=self.kv_cache_pos,
|
| 243 |
+
crossattn_cache=self.crossattn_cache_pos,
|
| 244 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 245 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
| 246 |
+
)
|
| 247 |
+
self.generator(
|
| 248 |
+
noisy_image_or_video=latents,
|
| 249 |
+
conditional_dict=unconditional_dict,
|
| 250 |
+
timestep=timestep * 0,
|
| 251 |
+
kv_cache=self.kv_cache_neg,
|
| 252 |
+
crossattn_cache=self.crossattn_cache_neg,
|
| 253 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 254 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Step 3.4: update the start and end frame indices
|
| 258 |
+
current_start_frame += current_num_frames
|
| 259 |
+
cache_start_frame += current_num_frames
|
| 260 |
+
|
| 261 |
+
# Step 4: Decode the output
|
| 262 |
+
video = self.vae.decode_to_pixel(output)
|
| 263 |
+
video = (video * 0.5 + 0.5).clamp(0, 1)
|
| 264 |
+
|
| 265 |
+
if return_latents:
|
| 266 |
+
return video, output
|
| 267 |
+
else:
|
| 268 |
+
return video
|
| 269 |
+
|
| 270 |
+
def _initialize_kv_cache(self, batch_size, dtype, device):
|
| 271 |
+
"""
|
| 272 |
+
Initialize a Per-GPU KV cache for the Wan model.
|
| 273 |
+
"""
|
| 274 |
+
kv_cache_pos = []
|
| 275 |
+
kv_cache_neg = []
|
| 276 |
+
if self.local_attn_size != -1:
|
| 277 |
+
# Use the local attention size to compute the KV cache size
|
| 278 |
+
kv_cache_size = self.local_attn_size * self.frame_seq_length
|
| 279 |
+
else:
|
| 280 |
+
# Use the default KV cache size
|
| 281 |
+
kv_cache_size = 32760
|
| 282 |
+
|
| 283 |
+
for _ in range(self.num_transformer_blocks):
|
| 284 |
+
kv_cache_pos.append({
|
| 285 |
+
"k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
| 286 |
+
"v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
| 287 |
+
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
|
| 288 |
+
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
|
| 289 |
+
})
|
| 290 |
+
kv_cache_neg.append({
|
| 291 |
+
"k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
| 292 |
+
"v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
| 293 |
+
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
|
| 294 |
+
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
|
| 295 |
+
})
|
| 296 |
+
|
| 297 |
+
self.kv_cache_pos = kv_cache_pos # always store the clean cache
|
| 298 |
+
self.kv_cache_neg = kv_cache_neg # always store the clean cache
|
| 299 |
+
|
| 300 |
+
def _initialize_crossattn_cache(self, batch_size, dtype, device):
|
| 301 |
+
"""
|
| 302 |
+
Initialize a Per-GPU cross-attention cache for the Wan model.
|
| 303 |
+
"""
|
| 304 |
+
crossattn_cache_pos = []
|
| 305 |
+
crossattn_cache_neg = []
|
| 306 |
+
for _ in range(self.num_transformer_blocks):
|
| 307 |
+
crossattn_cache_pos.append({
|
| 308 |
+
"k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
| 309 |
+
"v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
| 310 |
+
"is_init": False
|
| 311 |
+
})
|
| 312 |
+
crossattn_cache_neg.append({
|
| 313 |
+
"k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
| 314 |
+
"v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
| 315 |
+
"is_init": False
|
| 316 |
+
})
|
| 317 |
+
|
| 318 |
+
self.crossattn_cache_pos = crossattn_cache_pos # always store the clean cache
|
| 319 |
+
self.crossattn_cache_neg = crossattn_cache_neg # always store the clean cache
|
| 320 |
+
|
| 321 |
+
def _initialize_sample_scheduler(self, noise):
|
| 322 |
+
if self.sample_solver == 'unipc':
|
| 323 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
| 324 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 325 |
+
shift=1,
|
| 326 |
+
use_dynamic_shifting=False)
|
| 327 |
+
sample_scheduler.set_timesteps(
|
| 328 |
+
self.sampling_steps, device=noise.device, shift=self.shift)
|
| 329 |
+
self.timesteps = sample_scheduler.timesteps
|
| 330 |
+
elif self.sample_solver == 'dpm++':
|
| 331 |
+
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
| 332 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 333 |
+
shift=1,
|
| 334 |
+
use_dynamic_shifting=False)
|
| 335 |
+
sampling_sigmas = get_sampling_sigmas(self.sampling_steps, self.shift)
|
| 336 |
+
self.timesteps, _ = retrieve_timesteps(
|
| 337 |
+
sample_scheduler,
|
| 338 |
+
device=noise.device,
|
| 339 |
+
sigmas=sampling_sigmas)
|
| 340 |
+
else:
|
| 341 |
+
raise NotImplementedError("Unsupported solver.")
|
| 342 |
+
return sample_scheduler
|
pipeline/rolling_forcing_inference.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CausalInferencePipeline(torch.nn.Module):
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
args,
|
| 11 |
+
device,
|
| 12 |
+
generator=None,
|
| 13 |
+
text_encoder=None,
|
| 14 |
+
vae=None
|
| 15 |
+
):
|
| 16 |
+
super().__init__()
|
| 17 |
+
# Step 1: Initialize all models
|
| 18 |
+
self.generator = WanDiffusionWrapper(
|
| 19 |
+
**getattr(args, "model_kwargs", {}), is_causal=True) if generator is None else generator
|
| 20 |
+
self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
|
| 21 |
+
self.vae = WanVAEWrapper() if vae is None else vae
|
| 22 |
+
|
| 23 |
+
# Step 2: Initialize all causal hyperparmeters
|
| 24 |
+
self.scheduler = self.generator.get_scheduler()
|
| 25 |
+
self.denoising_step_list = torch.tensor(
|
| 26 |
+
args.denoising_step_list, dtype=torch.long)
|
| 27 |
+
if args.warp_denoising_step:
|
| 28 |
+
timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
|
| 29 |
+
self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
|
| 30 |
+
|
| 31 |
+
self.num_transformer_blocks = 30
|
| 32 |
+
self.frame_seq_length = 1560
|
| 33 |
+
|
| 34 |
+
self.kv_cache_clean = None
|
| 35 |
+
self.args = args
|
| 36 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
| 37 |
+
self.independent_first_frame = args.independent_first_frame
|
| 38 |
+
self.local_attn_size = self.generator.model.local_attn_size
|
| 39 |
+
|
| 40 |
+
print(f"KV inference with {self.num_frame_per_block} frames per block")
|
| 41 |
+
|
| 42 |
+
if self.num_frame_per_block > 1:
|
| 43 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
| 44 |
+
|
| 45 |
+
def inference_rolling_forcing(
|
| 46 |
+
self,
|
| 47 |
+
noise: torch.Tensor,
|
| 48 |
+
text_prompts: List[str],
|
| 49 |
+
initial_latent: Optional[torch.Tensor] = None,
|
| 50 |
+
return_latents: bool = False,
|
| 51 |
+
profile: bool = False
|
| 52 |
+
) -> torch.Tensor:
|
| 53 |
+
"""
|
| 54 |
+
Perform inference on the given noise and text prompts.
|
| 55 |
+
Inputs:
|
| 56 |
+
noise (torch.Tensor): The input noise tensor of shape
|
| 57 |
+
(batch_size, num_output_frames, num_channels, height, width).
|
| 58 |
+
text_prompts (List[str]): The list of text prompts.
|
| 59 |
+
initial_latent (torch.Tensor): The initial latent tensor of shape
|
| 60 |
+
(batch_size, num_input_frames, num_channels, height, width).
|
| 61 |
+
If num_input_frames is 1, perform image to video.
|
| 62 |
+
If num_input_frames is greater than 1, perform video extension.
|
| 63 |
+
return_latents (bool): Whether to return the latents.
|
| 64 |
+
Outputs:
|
| 65 |
+
video (torch.Tensor): The generated video tensor of shape
|
| 66 |
+
(batch_size, num_output_frames, num_channels, height, width).
|
| 67 |
+
It is normalized to be in the range [0, 1].
|
| 68 |
+
"""
|
| 69 |
+
batch_size, num_frames, num_channels, height, width = noise.shape
|
| 70 |
+
if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
|
| 71 |
+
# If the first frame is independent and the first frame is provided, then the number of frames in the
|
| 72 |
+
# noise should still be a multiple of num_frame_per_block
|
| 73 |
+
assert num_frames % self.num_frame_per_block == 0
|
| 74 |
+
num_blocks = num_frames // self.num_frame_per_block
|
| 75 |
+
else:
|
| 76 |
+
# Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
|
| 77 |
+
assert (num_frames - 1) % self.num_frame_per_block == 0
|
| 78 |
+
num_blocks = (num_frames - 1) // self.num_frame_per_block
|
| 79 |
+
num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
|
| 80 |
+
num_output_frames = num_frames + num_input_frames # add the initial latent frames
|
| 81 |
+
conditional_dict = self.text_encoder(
|
| 82 |
+
text_prompts=text_prompts
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
output = torch.zeros(
|
| 86 |
+
[batch_size, num_output_frames, num_channels, height, width],
|
| 87 |
+
device=noise.device,
|
| 88 |
+
dtype=noise.dtype
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# Set up profiling if requested
|
| 92 |
+
if profile:
|
| 93 |
+
init_start = torch.cuda.Event(enable_timing=True)
|
| 94 |
+
init_end = torch.cuda.Event(enable_timing=True)
|
| 95 |
+
diffusion_start = torch.cuda.Event(enable_timing=True)
|
| 96 |
+
diffusion_end = torch.cuda.Event(enable_timing=True)
|
| 97 |
+
vae_start = torch.cuda.Event(enable_timing=True)
|
| 98 |
+
vae_end = torch.cuda.Event(enable_timing=True)
|
| 99 |
+
block_times = []
|
| 100 |
+
block_start = torch.cuda.Event(enable_timing=True)
|
| 101 |
+
block_end = torch.cuda.Event(enable_timing=True)
|
| 102 |
+
init_start.record()
|
| 103 |
+
|
| 104 |
+
# Step 1: Initialize KV cache to all zeros
|
| 105 |
+
if self.kv_cache_clean is None:
|
| 106 |
+
self._initialize_kv_cache(
|
| 107 |
+
batch_size=batch_size,
|
| 108 |
+
dtype=noise.dtype,
|
| 109 |
+
device=noise.device
|
| 110 |
+
)
|
| 111 |
+
self._initialize_crossattn_cache(
|
| 112 |
+
batch_size=batch_size,
|
| 113 |
+
dtype=noise.dtype,
|
| 114 |
+
device=noise.device
|
| 115 |
+
)
|
| 116 |
+
else:
|
| 117 |
+
# reset cross attn cache
|
| 118 |
+
for block_index in range(self.num_transformer_blocks):
|
| 119 |
+
self.crossattn_cache[block_index]["is_init"] = False
|
| 120 |
+
# reset kv cache
|
| 121 |
+
for block_index in range(len(self.kv_cache_clean)):
|
| 122 |
+
self.kv_cache_clean[block_index]["global_end_index"] = torch.tensor(
|
| 123 |
+
[0], dtype=torch.long, device=noise.device)
|
| 124 |
+
self.kv_cache_clean[block_index]["local_end_index"] = torch.tensor(
|
| 125 |
+
[0], dtype=torch.long, device=noise.device)
|
| 126 |
+
|
| 127 |
+
# Step 2: Cache context feature
|
| 128 |
+
if initial_latent is not None:
|
| 129 |
+
timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
|
| 130 |
+
if self.independent_first_frame:
|
| 131 |
+
# Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
|
| 132 |
+
assert (num_input_frames - 1) % self.num_frame_per_block == 0
|
| 133 |
+
num_input_blocks = (num_input_frames - 1) // self.num_frame_per_block
|
| 134 |
+
output[:, :1] = initial_latent[:, :1]
|
| 135 |
+
self.generator(
|
| 136 |
+
noisy_image_or_video=initial_latent[:, :1],
|
| 137 |
+
conditional_dict=conditional_dict,
|
| 138 |
+
timestep=timestep * 0,
|
| 139 |
+
kv_cache=self.kv_cache_clean,
|
| 140 |
+
crossattn_cache=self.crossattn_cache,
|
| 141 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 142 |
+
)
|
| 143 |
+
current_start_frame += 1
|
| 144 |
+
else:
|
| 145 |
+
# Assume num_input_frames is self.num_frame_per_block * num_input_blocks
|
| 146 |
+
assert num_input_frames % self.num_frame_per_block == 0
|
| 147 |
+
num_input_blocks = num_input_frames // self.num_frame_per_block
|
| 148 |
+
|
| 149 |
+
for _ in range(num_input_blocks):
|
| 150 |
+
current_ref_latents = \
|
| 151 |
+
initial_latent[:, current_start_frame:current_start_frame + self.num_frame_per_block]
|
| 152 |
+
output[:, current_start_frame:current_start_frame + self.num_frame_per_block] = current_ref_latents
|
| 153 |
+
self.generator(
|
| 154 |
+
noisy_image_or_video=current_ref_latents,
|
| 155 |
+
conditional_dict=conditional_dict,
|
| 156 |
+
timestep=timestep * 0,
|
| 157 |
+
kv_cache=self.kv_cache_clean,
|
| 158 |
+
crossattn_cache=self.crossattn_cache,
|
| 159 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 160 |
+
)
|
| 161 |
+
current_start_frame += self.num_frame_per_block
|
| 162 |
+
|
| 163 |
+
if profile:
|
| 164 |
+
init_end.record()
|
| 165 |
+
torch.cuda.synchronize()
|
| 166 |
+
diffusion_start.record()
|
| 167 |
+
|
| 168 |
+
# implementing rolling forcing
|
| 169 |
+
# construct the rolling forcing windows
|
| 170 |
+
num_denoising_steps = len(self.denoising_step_list)
|
| 171 |
+
rolling_window_length_blocks = num_denoising_steps
|
| 172 |
+
window_start_blocks = []
|
| 173 |
+
window_end_blocks = []
|
| 174 |
+
window_num = num_blocks + rolling_window_length_blocks - 1
|
| 175 |
+
|
| 176 |
+
for window_index in range(window_num):
|
| 177 |
+
start_block = max(0, window_index - rolling_window_length_blocks + 1)
|
| 178 |
+
end_block = min(num_blocks - 1, window_index)
|
| 179 |
+
window_start_blocks.append(start_block)
|
| 180 |
+
window_end_blocks.append(end_block)
|
| 181 |
+
|
| 182 |
+
# init noisy cache
|
| 183 |
+
noisy_cache = torch.zeros(
|
| 184 |
+
[batch_size, num_output_frames, num_channels, height, width],
|
| 185 |
+
device=noise.device,
|
| 186 |
+
dtype=noise.dtype
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# init denosing timestep, same accross windows
|
| 190 |
+
shared_timestep = torch.ones(
|
| 191 |
+
[batch_size, rolling_window_length_blocks * self.num_frame_per_block],
|
| 192 |
+
device=noise.device,
|
| 193 |
+
dtype=torch.float32)
|
| 194 |
+
|
| 195 |
+
for index, current_timestep in enumerate(reversed(self.denoising_step_list)): # from clean to noisy
|
| 196 |
+
shared_timestep[:, index * self.num_frame_per_block:(index + 1) * self.num_frame_per_block] *= current_timestep
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# Denoising loop with rolling forcing
|
| 200 |
+
for window_index in range(window_num):
|
| 201 |
+
|
| 202 |
+
if profile:
|
| 203 |
+
block_start.record()
|
| 204 |
+
|
| 205 |
+
print('window_index:', window_index)
|
| 206 |
+
start_block = window_start_blocks[window_index]
|
| 207 |
+
end_block = window_end_blocks[window_index] # include
|
| 208 |
+
print(f"start_block: {start_block}, end_block: {end_block}")
|
| 209 |
+
|
| 210 |
+
current_start_frame = start_block * self.num_frame_per_block
|
| 211 |
+
current_end_frame = (end_block + 1) * self.num_frame_per_block # not include
|
| 212 |
+
current_num_frames = current_end_frame - current_start_frame
|
| 213 |
+
|
| 214 |
+
# noisy_input: new noise and previous denoised noisy frames, only last block is pure noise
|
| 215 |
+
if current_num_frames == rolling_window_length_blocks * self.num_frame_per_block or current_start_frame == 0:
|
| 216 |
+
noisy_input = torch.cat([
|
| 217 |
+
noisy_cache[:, current_start_frame : current_end_frame - self.num_frame_per_block],
|
| 218 |
+
noise[:, current_end_frame - self.num_frame_per_block : current_end_frame ]
|
| 219 |
+
], dim=1)
|
| 220 |
+
else: # at the end of the video
|
| 221 |
+
noisy_input = noisy_cache[:, current_start_frame:current_end_frame]
|
| 222 |
+
|
| 223 |
+
# init denosing timestep
|
| 224 |
+
if current_num_frames == rolling_window_length_blocks * self.num_frame_per_block:
|
| 225 |
+
current_timestep = shared_timestep
|
| 226 |
+
elif current_start_frame == 0:
|
| 227 |
+
current_timestep = shared_timestep[:,-current_num_frames:]
|
| 228 |
+
elif current_end_frame == num_frames:
|
| 229 |
+
current_timestep = shared_timestep[:,:current_num_frames]
|
| 230 |
+
else:
|
| 231 |
+
raise ValueError("current_num_frames should be equal to rolling_window_length_blocks * self.num_frame_per_block, or the first or last window.")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# calling DiT
|
| 235 |
+
_, denoised_pred = self.generator(
|
| 236 |
+
noisy_image_or_video=noisy_input,
|
| 237 |
+
conditional_dict=conditional_dict,
|
| 238 |
+
timestep=current_timestep,
|
| 239 |
+
kv_cache=self.kv_cache_clean,
|
| 240 |
+
crossattn_cache=self.crossattn_cache,
|
| 241 |
+
current_start=current_start_frame * self.frame_seq_length
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
output[:, current_start_frame:current_end_frame] = denoised_pred
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# update noisy_cache, which is detached from the computation graph
|
| 248 |
+
with torch.no_grad():
|
| 249 |
+
for block_idx in range(start_block, end_block + 1):
|
| 250 |
+
|
| 251 |
+
block_time_step = current_timestep[:,
|
| 252 |
+
(block_idx - start_block)*self.num_frame_per_block :
|
| 253 |
+
(block_idx - start_block+1)*self.num_frame_per_block].mean().item()
|
| 254 |
+
matches = torch.abs(self.denoising_step_list - block_time_step) < 1e-4
|
| 255 |
+
block_timestep_index = torch.nonzero(matches, as_tuple=True)[0]
|
| 256 |
+
|
| 257 |
+
if block_timestep_index == len(self.denoising_step_list) - 1:
|
| 258 |
+
continue
|
| 259 |
+
|
| 260 |
+
next_timestep = self.denoising_step_list[block_timestep_index + 1].to(noise.device)
|
| 261 |
+
|
| 262 |
+
noisy_cache[:, block_idx * self.num_frame_per_block:
|
| 263 |
+
(block_idx+1) * self.num_frame_per_block] = \
|
| 264 |
+
self.scheduler.add_noise(
|
| 265 |
+
denoised_pred.flatten(0, 1),
|
| 266 |
+
torch.randn_like(denoised_pred.flatten(0, 1)),
|
| 267 |
+
next_timestep * torch.ones(
|
| 268 |
+
[batch_size * current_num_frames], device=noise.device, dtype=torch.long)
|
| 269 |
+
).unflatten(0, denoised_pred.shape[:2])[:, (block_idx - start_block)*self.num_frame_per_block:
|
| 270 |
+
(block_idx - start_block+1)*self.num_frame_per_block]
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# rerun with timestep zero to update the clean cache, which is also detached from the computation graph
|
| 274 |
+
with torch.no_grad():
|
| 275 |
+
context_timestep = torch.ones_like(current_timestep) * self.args.context_noise
|
| 276 |
+
# # add context noise
|
| 277 |
+
# denoised_pred = self.scheduler.add_noise(
|
| 278 |
+
# denoised_pred.flatten(0, 1),
|
| 279 |
+
# torch.randn_like(denoised_pred.flatten(0, 1)),
|
| 280 |
+
# context_timestep * torch.ones(
|
| 281 |
+
# [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
|
| 282 |
+
# ).unflatten(0, denoised_pred.shape[:2])
|
| 283 |
+
|
| 284 |
+
# only cache the first block
|
| 285 |
+
denoised_pred = denoised_pred[:,:self.num_frame_per_block]
|
| 286 |
+
context_timestep = context_timestep[:,:self.num_frame_per_block]
|
| 287 |
+
self.generator(
|
| 288 |
+
noisy_image_or_video=denoised_pred,
|
| 289 |
+
conditional_dict=conditional_dict,
|
| 290 |
+
timestep=context_timestep,
|
| 291 |
+
kv_cache=self.kv_cache_clean,
|
| 292 |
+
crossattn_cache=self.crossattn_cache,
|
| 293 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 294 |
+
updating_cache=True,
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
if profile:
|
| 298 |
+
block_end.record()
|
| 299 |
+
torch.cuda.synchronize()
|
| 300 |
+
block_time = block_start.elapsed_time(block_end)
|
| 301 |
+
block_times.append(block_time)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
if profile:
|
| 305 |
+
# End diffusion timing and synchronize CUDA
|
| 306 |
+
diffusion_end.record()
|
| 307 |
+
torch.cuda.synchronize()
|
| 308 |
+
diffusion_time = diffusion_start.elapsed_time(diffusion_end)
|
| 309 |
+
init_time = init_start.elapsed_time(init_end)
|
| 310 |
+
vae_start.record()
|
| 311 |
+
|
| 312 |
+
# Step 4: Decode the output
|
| 313 |
+
video = self.vae.decode_to_pixel(output, use_cache=False)
|
| 314 |
+
video = (video * 0.5 + 0.5).clamp(0, 1)
|
| 315 |
+
|
| 316 |
+
if profile:
|
| 317 |
+
# End VAE timing and synchronize CUDA
|
| 318 |
+
vae_end.record()
|
| 319 |
+
torch.cuda.synchronize()
|
| 320 |
+
vae_time = vae_start.elapsed_time(vae_end)
|
| 321 |
+
total_time = init_time + diffusion_time + vae_time
|
| 322 |
+
|
| 323 |
+
print("Profiling results:")
|
| 324 |
+
print(f" - Initialization/caching time: {init_time:.2f} ms ({100 * init_time / total_time:.2f}%)")
|
| 325 |
+
print(f" - Diffusion generation time: {diffusion_time:.2f} ms ({100 * diffusion_time / total_time:.2f}%)")
|
| 326 |
+
for i, block_time in enumerate(block_times):
|
| 327 |
+
print(f" - Block {i} generation time: {block_time:.2f} ms ({100 * block_time / diffusion_time:.2f}% of diffusion)")
|
| 328 |
+
print(f" - VAE decoding time: {vae_time:.2f} ms ({100 * vae_time / total_time:.2f}%)")
|
| 329 |
+
print(f" - Total time: {total_time:.2f} ms")
|
| 330 |
+
|
| 331 |
+
if return_latents:
|
| 332 |
+
return video, output
|
| 333 |
+
else:
|
| 334 |
+
return video
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def _initialize_kv_cache(self, batch_size, dtype, device):
|
| 339 |
+
"""
|
| 340 |
+
Initialize a Per-GPU KV cache for the Wan model.
|
| 341 |
+
"""
|
| 342 |
+
kv_cache_clean = []
|
| 343 |
+
# if self.local_attn_size != -1:
|
| 344 |
+
# # Use the local attention size to compute the KV cache size
|
| 345 |
+
# kv_cache_size = self.local_attn_size * self.frame_seq_length
|
| 346 |
+
# else:
|
| 347 |
+
# # Use the default KV cache size
|
| 348 |
+
kv_cache_size = 1560 * 24
|
| 349 |
+
|
| 350 |
+
for _ in range(self.num_transformer_blocks):
|
| 351 |
+
kv_cache_clean.append({
|
| 352 |
+
"k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
| 353 |
+
"v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
| 354 |
+
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
|
| 355 |
+
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
|
| 356 |
+
})
|
| 357 |
+
|
| 358 |
+
self.kv_cache_clean = kv_cache_clean # always store the clean cache
|
| 359 |
+
|
| 360 |
+
def _initialize_crossattn_cache(self, batch_size, dtype, device):
|
| 361 |
+
"""
|
| 362 |
+
Initialize a Per-GPU cross-attention cache for the Wan model.
|
| 363 |
+
"""
|
| 364 |
+
crossattn_cache = []
|
| 365 |
+
|
| 366 |
+
for _ in range(self.num_transformer_blocks):
|
| 367 |
+
crossattn_cache.append({
|
| 368 |
+
"k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
| 369 |
+
"v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
| 370 |
+
"is_init": False
|
| 371 |
+
})
|
| 372 |
+
self.crossattn_cache = crossattn_cache
|
pipeline/rolling_forcing_training.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils.wan_wrapper import WanDiffusionWrapper
|
| 2 |
+
from utils.scheduler import SchedulerInterface
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
import torch
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class RollingForcingTrainingPipeline:
|
| 9 |
+
def __init__(self,
|
| 10 |
+
denoising_step_list: List[int],
|
| 11 |
+
scheduler: SchedulerInterface,
|
| 12 |
+
generator: WanDiffusionWrapper,
|
| 13 |
+
num_frame_per_block=3,
|
| 14 |
+
independent_first_frame: bool = False,
|
| 15 |
+
same_step_across_blocks: bool = False,
|
| 16 |
+
last_step_only: bool = False,
|
| 17 |
+
num_max_frames: int = 21,
|
| 18 |
+
context_noise: int = 0,
|
| 19 |
+
**kwargs):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.scheduler = scheduler
|
| 22 |
+
self.generator = generator
|
| 23 |
+
self.denoising_step_list = denoising_step_list
|
| 24 |
+
if self.denoising_step_list[-1] == 0:
|
| 25 |
+
self.denoising_step_list = self.denoising_step_list[:-1] # remove the zero timestep for inference
|
| 26 |
+
|
| 27 |
+
# Wan specific hyperparameters
|
| 28 |
+
self.num_transformer_blocks = 30
|
| 29 |
+
self.frame_seq_length = 1560
|
| 30 |
+
self.num_frame_per_block = num_frame_per_block
|
| 31 |
+
self.context_noise = context_noise
|
| 32 |
+
self.i2v = False
|
| 33 |
+
|
| 34 |
+
self.kv_cache_clean = None
|
| 35 |
+
self.kv_cache2 = None
|
| 36 |
+
self.independent_first_frame = independent_first_frame
|
| 37 |
+
self.same_step_across_blocks = same_step_across_blocks
|
| 38 |
+
self.last_step_only = last_step_only
|
| 39 |
+
self.kv_cache_size = num_max_frames * self.frame_seq_length
|
| 40 |
+
|
| 41 |
+
def generate_and_sync_list(self, num_blocks, num_denoising_steps, device):
|
| 42 |
+
rank = dist.get_rank() if dist.is_initialized() else 0
|
| 43 |
+
|
| 44 |
+
if rank == 0:
|
| 45 |
+
# Generate random indices
|
| 46 |
+
indices = torch.randint(
|
| 47 |
+
low=0,
|
| 48 |
+
high=num_denoising_steps,
|
| 49 |
+
size=(num_blocks,),
|
| 50 |
+
device=device
|
| 51 |
+
)
|
| 52 |
+
if self.last_step_only:
|
| 53 |
+
indices = torch.ones_like(indices) * (num_denoising_steps - 1)
|
| 54 |
+
else:
|
| 55 |
+
indices = torch.empty(num_blocks, dtype=torch.long, device=device)
|
| 56 |
+
|
| 57 |
+
dist.broadcast(indices, src=0) # Broadcast the random indices to all ranks
|
| 58 |
+
return indices.tolist()
|
| 59 |
+
|
| 60 |
+
def generate_list(self, num_blocks, num_denoising_steps, device):
|
| 61 |
+
|
| 62 |
+
# Generate random indices
|
| 63 |
+
indices = torch.randint(
|
| 64 |
+
low=0,
|
| 65 |
+
high=num_denoising_steps,
|
| 66 |
+
size=(num_blocks,),
|
| 67 |
+
device=device
|
| 68 |
+
)
|
| 69 |
+
if self.last_step_only:
|
| 70 |
+
indices = torch.ones_like(indices) * (num_denoising_steps - 1)
|
| 71 |
+
|
| 72 |
+
return indices.tolist()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def inference_with_rolling_forcing(
|
| 76 |
+
self,
|
| 77 |
+
noise: torch.Tensor,
|
| 78 |
+
initial_latent: Optional[torch.Tensor] = None,
|
| 79 |
+
return_sim_step: bool = False,
|
| 80 |
+
**conditional_dict
|
| 81 |
+
) -> torch.Tensor:
|
| 82 |
+
batch_size, num_frames, num_channels, height, width = noise.shape
|
| 83 |
+
if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
|
| 84 |
+
# If the first frame is independent and the first frame is provided, then the number of frames in the
|
| 85 |
+
# noise should still be a multiple of num_frame_per_block
|
| 86 |
+
assert num_frames % self.num_frame_per_block == 0
|
| 87 |
+
num_blocks = num_frames // self.num_frame_per_block
|
| 88 |
+
else:
|
| 89 |
+
# Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
|
| 90 |
+
assert (num_frames - 1) % self.num_frame_per_block == 0
|
| 91 |
+
num_blocks = (num_frames - 1) // self.num_frame_per_block
|
| 92 |
+
num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
|
| 93 |
+
num_output_frames = num_frames + num_input_frames # add the initial latent frames
|
| 94 |
+
output = torch.zeros(
|
| 95 |
+
[batch_size, num_output_frames, num_channels, height, width],
|
| 96 |
+
device=noise.device,
|
| 97 |
+
dtype=noise.dtype
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# Step 1: Initialize KV cache to all zeros
|
| 101 |
+
self._initialize_kv_cache(
|
| 102 |
+
batch_size=batch_size, dtype=noise.dtype, device=noise.device
|
| 103 |
+
)
|
| 104 |
+
self._initialize_crossattn_cache(
|
| 105 |
+
batch_size=batch_size, dtype=noise.dtype, device=noise.device
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# implementing rolling forcing
|
| 109 |
+
# construct the rolling forcing windows
|
| 110 |
+
num_denoising_steps = len(self.denoising_step_list)
|
| 111 |
+
rolling_window_length_blocks = num_denoising_steps
|
| 112 |
+
window_start_blocks = []
|
| 113 |
+
window_end_blocks = []
|
| 114 |
+
window_num = num_blocks + rolling_window_length_blocks - 1
|
| 115 |
+
|
| 116 |
+
for window_index in range(window_num):
|
| 117 |
+
start_block = max(0, window_index - rolling_window_length_blocks + 1)
|
| 118 |
+
end_block = min(num_blocks - 1, window_index)
|
| 119 |
+
window_start_blocks.append(start_block)
|
| 120 |
+
window_end_blocks.append(end_block)
|
| 121 |
+
|
| 122 |
+
# exit_flag indicates the window at which the model will backpropagate gradients.
|
| 123 |
+
exit_flag = torch.randint(high=rolling_window_length_blocks, device=noise.device, size=())
|
| 124 |
+
start_gradient_frame_index = num_output_frames - 21
|
| 125 |
+
|
| 126 |
+
# init noisy cache
|
| 127 |
+
noisy_cache = torch.zeros(
|
| 128 |
+
[batch_size, num_output_frames, num_channels, height, width],
|
| 129 |
+
device=noise.device,
|
| 130 |
+
dtype=noise.dtype
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# init denosing timestep, same accross windows
|
| 134 |
+
shared_timestep = torch.ones(
|
| 135 |
+
[batch_size, rolling_window_length_blocks * self.num_frame_per_block],
|
| 136 |
+
device=noise.device,
|
| 137 |
+
dtype=torch.float32)
|
| 138 |
+
|
| 139 |
+
for index, current_timestep in enumerate(reversed(self.denoising_step_list)): # from clean to noisy
|
| 140 |
+
shared_timestep[:, index * self.num_frame_per_block:(index + 1) * self.num_frame_per_block] *= current_timestep
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# Denoising loop with rolling forcing
|
| 144 |
+
for window_index in range(window_num):
|
| 145 |
+
start_block = window_start_blocks[window_index]
|
| 146 |
+
end_block = window_end_blocks[window_index] # include
|
| 147 |
+
|
| 148 |
+
current_start_frame = start_block * self.num_frame_per_block
|
| 149 |
+
current_end_frame = (end_block + 1) * self.num_frame_per_block # not include
|
| 150 |
+
current_num_frames = current_end_frame - current_start_frame
|
| 151 |
+
|
| 152 |
+
# noisy_input: new noise and previous denoised noisy frames, only last block is pure noise
|
| 153 |
+
if current_num_frames == rolling_window_length_blocks * self.num_frame_per_block or current_start_frame == 0:
|
| 154 |
+
noisy_input = torch.cat([
|
| 155 |
+
noisy_cache[:, current_start_frame : current_end_frame - self.num_frame_per_block],
|
| 156 |
+
noise[:, current_end_frame - self.num_frame_per_block : current_end_frame ]
|
| 157 |
+
], dim=1)
|
| 158 |
+
else: # at the end of the video
|
| 159 |
+
noisy_input = noisy_cache[:, current_start_frame:current_end_frame].clone()
|
| 160 |
+
|
| 161 |
+
# init denosing timestep
|
| 162 |
+
if current_num_frames == rolling_window_length_blocks * self.num_frame_per_block:
|
| 163 |
+
current_timestep = shared_timestep
|
| 164 |
+
elif current_start_frame == 0:
|
| 165 |
+
current_timestep = shared_timestep[:,-current_num_frames:]
|
| 166 |
+
elif current_end_frame == num_frames:
|
| 167 |
+
current_timestep = shared_timestep[:,:current_num_frames]
|
| 168 |
+
else:
|
| 169 |
+
raise ValueError("current_num_frames should be equal to rolling_window_length_blocks * self.num_frame_per_block, or the first or last window.")
|
| 170 |
+
|
| 171 |
+
require_grad = window_index % rolling_window_length_blocks == exit_flag
|
| 172 |
+
if current_end_frame <= start_gradient_frame_index:
|
| 173 |
+
require_grad = False
|
| 174 |
+
|
| 175 |
+
# calling DiT
|
| 176 |
+
if not require_grad:
|
| 177 |
+
with torch.no_grad():
|
| 178 |
+
_, denoised_pred = self.generator(
|
| 179 |
+
noisy_image_or_video=noisy_input,
|
| 180 |
+
conditional_dict=conditional_dict,
|
| 181 |
+
timestep=current_timestep,
|
| 182 |
+
kv_cache=self.kv_cache_clean,
|
| 183 |
+
crossattn_cache=self.crossattn_cache,
|
| 184 |
+
current_start=current_start_frame * self.frame_seq_length
|
| 185 |
+
)
|
| 186 |
+
else:
|
| 187 |
+
_, denoised_pred = self.generator(
|
| 188 |
+
noisy_image_or_video=noisy_input,
|
| 189 |
+
conditional_dict=conditional_dict,
|
| 190 |
+
timestep=current_timestep,
|
| 191 |
+
kv_cache=self.kv_cache_clean,
|
| 192 |
+
crossattn_cache=self.crossattn_cache,
|
| 193 |
+
current_start=current_start_frame * self.frame_seq_length
|
| 194 |
+
)
|
| 195 |
+
output[:, current_start_frame:current_end_frame] = denoised_pred
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# update noisy_cache, which is detached from the computation graph
|
| 199 |
+
with torch.no_grad():
|
| 200 |
+
for block_idx in range(start_block, end_block + 1):
|
| 201 |
+
|
| 202 |
+
block_time_step = current_timestep[:,
|
| 203 |
+
(block_idx - start_block)*self.num_frame_per_block :
|
| 204 |
+
(block_idx - start_block+1)*self.num_frame_per_block].mean().item()
|
| 205 |
+
matches = torch.abs(self.denoising_step_list - block_time_step) < 1e-4
|
| 206 |
+
block_timestep_index = torch.nonzero(matches, as_tuple=True)[0]
|
| 207 |
+
|
| 208 |
+
if block_timestep_index == len(self.denoising_step_list) - 1:
|
| 209 |
+
continue
|
| 210 |
+
|
| 211 |
+
next_timestep = self.denoising_step_list[block_timestep_index + 1].to(noise.device)
|
| 212 |
+
|
| 213 |
+
noisy_cache[:, block_idx * self.num_frame_per_block:
|
| 214 |
+
(block_idx+1) * self.num_frame_per_block] = \
|
| 215 |
+
self.scheduler.add_noise(
|
| 216 |
+
denoised_pred.flatten(0, 1),
|
| 217 |
+
torch.randn_like(denoised_pred.flatten(0, 1)),
|
| 218 |
+
next_timestep * torch.ones(
|
| 219 |
+
[batch_size * current_num_frames], device=noise.device, dtype=torch.long)
|
| 220 |
+
).unflatten(0, denoised_pred.shape[:2])[:, (block_idx - start_block)*self.num_frame_per_block:
|
| 221 |
+
(block_idx - start_block+1)*self.num_frame_per_block]
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# rerun with timestep zero to update the clean cache, which is also detached from the computation graph
|
| 225 |
+
with torch.no_grad():
|
| 226 |
+
context_timestep = torch.ones_like(current_timestep) * self.context_noise
|
| 227 |
+
# # add context noise
|
| 228 |
+
# denoised_pred = self.scheduler.add_noise(
|
| 229 |
+
# denoised_pred.flatten(0, 1),
|
| 230 |
+
# torch.randn_like(denoised_pred.flatten(0, 1)),
|
| 231 |
+
# context_timestep * torch.ones(
|
| 232 |
+
# [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
|
| 233 |
+
# ).unflatten(0, denoised_pred.shape[:2])
|
| 234 |
+
|
| 235 |
+
# only cache the first block
|
| 236 |
+
denoised_pred = denoised_pred[:,:self.num_frame_per_block]
|
| 237 |
+
context_timestep = context_timestep[:,:self.num_frame_per_block]
|
| 238 |
+
self.generator(
|
| 239 |
+
noisy_image_or_video=denoised_pred,
|
| 240 |
+
conditional_dict=conditional_dict,
|
| 241 |
+
timestep=context_timestep,
|
| 242 |
+
kv_cache=self.kv_cache_clean,
|
| 243 |
+
crossattn_cache=self.crossattn_cache,
|
| 244 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 245 |
+
updating_cache=True,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Step 3.5: Return the denoised timestep
|
| 249 |
+
# can ignore since not used
|
| 250 |
+
denoised_timestep_from, denoised_timestep_to = None, None
|
| 251 |
+
|
| 252 |
+
return output, denoised_timestep_from, denoised_timestep_to
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def inference_with_self_forcing(
|
| 257 |
+
self,
|
| 258 |
+
noise: torch.Tensor,
|
| 259 |
+
initial_latent: Optional[torch.Tensor] = None,
|
| 260 |
+
return_sim_step: bool = False,
|
| 261 |
+
**conditional_dict
|
| 262 |
+
) -> torch.Tensor:
|
| 263 |
+
batch_size, num_frames, num_channels, height, width = noise.shape
|
| 264 |
+
if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
|
| 265 |
+
# If the first frame is independent and the first frame is provided, then the number of frames in the
|
| 266 |
+
# noise should still be a multiple of num_frame_per_block
|
| 267 |
+
assert num_frames % self.num_frame_per_block == 0
|
| 268 |
+
num_blocks = num_frames // self.num_frame_per_block
|
| 269 |
+
else:
|
| 270 |
+
# Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
|
| 271 |
+
assert (num_frames - 1) % self.num_frame_per_block == 0
|
| 272 |
+
num_blocks = (num_frames - 1) // self.num_frame_per_block
|
| 273 |
+
num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
|
| 274 |
+
num_output_frames = num_frames + num_input_frames # add the initial latent frames
|
| 275 |
+
output = torch.zeros(
|
| 276 |
+
[batch_size, num_output_frames, num_channels, height, width],
|
| 277 |
+
device=noise.device,
|
| 278 |
+
dtype=noise.dtype
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
# Step 1: Initialize KV cache to all zeros
|
| 282 |
+
self._initialize_kv_cache(
|
| 283 |
+
batch_size=batch_size, dtype=noise.dtype, device=noise.device
|
| 284 |
+
)
|
| 285 |
+
self._initialize_crossattn_cache(
|
| 286 |
+
batch_size=batch_size, dtype=noise.dtype, device=noise.device
|
| 287 |
+
)
|
| 288 |
+
# if self.kv_cache_clean is None:
|
| 289 |
+
# self._initialize_kv_cache(
|
| 290 |
+
# batch_size=batch_size,
|
| 291 |
+
# dtype=noise.dtype,
|
| 292 |
+
# device=noise.device,
|
| 293 |
+
# )
|
| 294 |
+
# self._initialize_crossattn_cache(
|
| 295 |
+
# batch_size=batch_size,
|
| 296 |
+
# dtype=noise.dtype,
|
| 297 |
+
# device=noise.device
|
| 298 |
+
# )
|
| 299 |
+
# else:
|
| 300 |
+
# # reset cross attn cache
|
| 301 |
+
# for block_index in range(self.num_transformer_blocks):
|
| 302 |
+
# self.crossattn_cache[block_index]["is_init"] = False
|
| 303 |
+
# # reset kv cache
|
| 304 |
+
# for block_index in range(len(self.kv_cache_clean)):
|
| 305 |
+
# self.kv_cache_clean[block_index]["global_end_index"] = torch.tensor(
|
| 306 |
+
# [0], dtype=torch.long, device=noise.device)
|
| 307 |
+
# self.kv_cache_clean[block_index]["local_end_index"] = torch.tensor(
|
| 308 |
+
# [0], dtype=torch.long, device=noise.device)
|
| 309 |
+
|
| 310 |
+
# Step 2: Cache context feature
|
| 311 |
+
current_start_frame = 0
|
| 312 |
+
if initial_latent is not None:
|
| 313 |
+
timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
|
| 314 |
+
# Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
|
| 315 |
+
output[:, :1] = initial_latent
|
| 316 |
+
with torch.no_grad():
|
| 317 |
+
self.generator(
|
| 318 |
+
noisy_image_or_video=initial_latent,
|
| 319 |
+
conditional_dict=conditional_dict,
|
| 320 |
+
timestep=timestep * 0,
|
| 321 |
+
kv_cache=self.kv_cache_clean,
|
| 322 |
+
crossattn_cache=self.crossattn_cache,
|
| 323 |
+
current_start=current_start_frame * self.frame_seq_length
|
| 324 |
+
)
|
| 325 |
+
current_start_frame += 1
|
| 326 |
+
|
| 327 |
+
# Step 3: Temporal denoising loop
|
| 328 |
+
all_num_frames = [self.num_frame_per_block] * num_blocks
|
| 329 |
+
if self.independent_first_frame and initial_latent is None:
|
| 330 |
+
all_num_frames = [1] + all_num_frames
|
| 331 |
+
num_denoising_steps = len(self.denoising_step_list)
|
| 332 |
+
exit_flags = self.generate_and_sync_list(len(all_num_frames), num_denoising_steps, device=noise.device)
|
| 333 |
+
start_gradient_frame_index = num_output_frames - 21
|
| 334 |
+
|
| 335 |
+
# for block_index in range(num_blocks):
|
| 336 |
+
for block_index, current_num_frames in enumerate(all_num_frames):
|
| 337 |
+
noisy_input = noise[
|
| 338 |
+
:, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]
|
| 339 |
+
|
| 340 |
+
# Step 3.1: Spatial denoising loop
|
| 341 |
+
for index, current_timestep in enumerate(self.denoising_step_list):
|
| 342 |
+
if self.same_step_across_blocks:
|
| 343 |
+
exit_flag = (index == exit_flags[0])
|
| 344 |
+
else:
|
| 345 |
+
exit_flag = (index == exit_flags[block_index]) # Only backprop at the randomly selected timestep (consistent across all ranks)
|
| 346 |
+
timestep = torch.ones(
|
| 347 |
+
[batch_size, current_num_frames],
|
| 348 |
+
device=noise.device,
|
| 349 |
+
dtype=torch.int64) * current_timestep
|
| 350 |
+
|
| 351 |
+
if not exit_flag:
|
| 352 |
+
with torch.no_grad():
|
| 353 |
+
_, denoised_pred = self.generator(
|
| 354 |
+
noisy_image_or_video=noisy_input,
|
| 355 |
+
conditional_dict=conditional_dict,
|
| 356 |
+
timestep=timestep,
|
| 357 |
+
kv_cache=self.kv_cache_clean,
|
| 358 |
+
crossattn_cache=self.crossattn_cache,
|
| 359 |
+
current_start=current_start_frame * self.frame_seq_length
|
| 360 |
+
)
|
| 361 |
+
next_timestep = self.denoising_step_list[index + 1]
|
| 362 |
+
noisy_input = self.scheduler.add_noise(
|
| 363 |
+
denoised_pred.flatten(0, 1),
|
| 364 |
+
torch.randn_like(denoised_pred.flatten(0, 1)),
|
| 365 |
+
next_timestep * torch.ones(
|
| 366 |
+
[batch_size * current_num_frames], device=noise.device, dtype=torch.long)
|
| 367 |
+
).unflatten(0, denoised_pred.shape[:2])
|
| 368 |
+
else:
|
| 369 |
+
# for getting real output
|
| 370 |
+
# with torch.set_grad_enabled(current_start_frame >= start_gradient_frame_index):
|
| 371 |
+
if current_start_frame < start_gradient_frame_index:
|
| 372 |
+
with torch.no_grad():
|
| 373 |
+
_, denoised_pred = self.generator(
|
| 374 |
+
noisy_image_or_video=noisy_input,
|
| 375 |
+
conditional_dict=conditional_dict,
|
| 376 |
+
timestep=timestep,
|
| 377 |
+
kv_cache=self.kv_cache_clean,
|
| 378 |
+
crossattn_cache=self.crossattn_cache,
|
| 379 |
+
current_start=current_start_frame * self.frame_seq_length
|
| 380 |
+
)
|
| 381 |
+
else:
|
| 382 |
+
_, denoised_pred = self.generator(
|
| 383 |
+
noisy_image_or_video=noisy_input,
|
| 384 |
+
conditional_dict=conditional_dict,
|
| 385 |
+
timestep=timestep,
|
| 386 |
+
kv_cache=self.kv_cache_clean,
|
| 387 |
+
crossattn_cache=self.crossattn_cache,
|
| 388 |
+
current_start=current_start_frame * self.frame_seq_length
|
| 389 |
+
)
|
| 390 |
+
break
|
| 391 |
+
|
| 392 |
+
# Step 3.2: record the model's output
|
| 393 |
+
output[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
|
| 394 |
+
|
| 395 |
+
# Step 3.3: rerun with timestep zero to update the cache
|
| 396 |
+
context_timestep = torch.ones_like(timestep) * self.context_noise
|
| 397 |
+
# add context noise
|
| 398 |
+
denoised_pred = self.scheduler.add_noise(
|
| 399 |
+
denoised_pred.flatten(0, 1),
|
| 400 |
+
torch.randn_like(denoised_pred.flatten(0, 1)),
|
| 401 |
+
context_timestep * torch.ones(
|
| 402 |
+
[batch_size * current_num_frames], device=noise.device, dtype=torch.long)
|
| 403 |
+
).unflatten(0, denoised_pred.shape[:2])
|
| 404 |
+
with torch.no_grad():
|
| 405 |
+
self.generator(
|
| 406 |
+
noisy_image_or_video=denoised_pred,
|
| 407 |
+
conditional_dict=conditional_dict,
|
| 408 |
+
timestep=context_timestep,
|
| 409 |
+
kv_cache=self.kv_cache_clean,
|
| 410 |
+
crossattn_cache=self.crossattn_cache,
|
| 411 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 412 |
+
updating_cache=True,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
# Step 3.4: update the start and end frame indices
|
| 416 |
+
current_start_frame += current_num_frames
|
| 417 |
+
|
| 418 |
+
# Step 3.5: Return the denoised timestep
|
| 419 |
+
if not self.same_step_across_blocks:
|
| 420 |
+
denoised_timestep_from, denoised_timestep_to = None, None
|
| 421 |
+
elif exit_flags[0] == len(self.denoising_step_list) - 1:
|
| 422 |
+
denoised_timestep_to = 0
|
| 423 |
+
denoised_timestep_from = 1000 - torch.argmin(
|
| 424 |
+
(self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()
|
| 425 |
+
else:
|
| 426 |
+
denoised_timestep_to = 1000 - torch.argmin(
|
| 427 |
+
(self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0] + 1].cuda()).abs(), dim=0).item()
|
| 428 |
+
denoised_timestep_from = 1000 - torch.argmin(
|
| 429 |
+
(self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()
|
| 430 |
+
|
| 431 |
+
if return_sim_step:
|
| 432 |
+
return output, denoised_timestep_from, denoised_timestep_to, exit_flags[0] + 1
|
| 433 |
+
|
| 434 |
+
return output, denoised_timestep_from, denoised_timestep_to
|
| 435 |
+
|
| 436 |
+
def _initialize_kv_cache(self, batch_size, dtype, device):
|
| 437 |
+
"""
|
| 438 |
+
Initialize a Per-GPU KV cache for the Wan model.
|
| 439 |
+
"""
|
| 440 |
+
kv_cache_clean = []
|
| 441 |
+
|
| 442 |
+
for _ in range(self.num_transformer_blocks):
|
| 443 |
+
kv_cache_clean.append({
|
| 444 |
+
"k": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
|
| 445 |
+
"v": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
|
| 446 |
+
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
|
| 447 |
+
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
|
| 448 |
+
})
|
| 449 |
+
|
| 450 |
+
self.kv_cache_clean = kv_cache_clean # always store the clean cache
|
| 451 |
+
|
| 452 |
+
def _initialize_crossattn_cache(self, batch_size, dtype, device):
|
| 453 |
+
"""
|
| 454 |
+
Initialize a Per-GPU cross-attention cache for the Wan model.
|
| 455 |
+
"""
|
| 456 |
+
crossattn_cache = []
|
| 457 |
+
|
| 458 |
+
for _ in range(self.num_transformer_blocks):
|
| 459 |
+
crossattn_cache.append({
|
| 460 |
+
"k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
| 461 |
+
"v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
| 462 |
+
"is_init": False
|
| 463 |
+
})
|
| 464 |
+
self.crossattn_cache = crossattn_cache
|
prompts/example_prompts.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
A cinematic scene from a classic western movie, featuring a rugged man riding a powerful horse through the vast Gobi Desert at sunset. The man, dressed in a dusty cowboy hat and a worn leather jacket, reins tightly on the horse's neck as he gallops across the golden sands. The sun sets dramatically behind them, casting long shadows and warm hues across the landscape. The background is filled with rolling dunes and sparse, rocky outcrops, emphasizing the harsh beauty of the desert. A dynamic wide shot from a low angle, capturing both the man and the expansive desert vista.
|
| 2 |
+
A classic black-and-white photograph style image of an older man playing the piano. The man, with a weathered face and kind eyes, sits at an antique piano with his fingers gracefully moving over the keys. The lighting comes from the side, casting dramatic shadows on his face and emphasizing the texture of his hands. His posture is upright and focused, conveying a sense of deep concentration and passion for music. The background is blurred, revealing only hints of a cozy room with wooden floors and old furniture. A close-up shot from a slightly elevated angle, capturing both the man and the piano in detail.
|
| 3 |
+
A dramatic post-apocalyptic scene in the style of a horror film, featuring a skeleton wearing a colorful flower hat and oversized sunglasses dancing wildly in a sunlit meadow at sunset. The skeleton has a weathered and somewhat decayed appearance, with bones visible through tattered remnants of clothing. The dance is energetic and almost comical, with exaggerated movements. The background is a vivid blend of warm oranges and pinks, with tall grasses and wildflowers swaying in the breeze. The sky is painted with rich hues of orange and pink, casting long shadows across the landscape. A dynamic medium shot from a low angle, capturing the skeleton's animated dance.
|
| 4 |
+
A dynamic action scene in a modern gym, featuring a kangaroo wearing boxing gloves, engaged in an intense sparring session with a punching bag. The kangaroo has a muscular build and is positioned mid-punch, its front legs wrapped in red boxing gloves, eyes focused intently on the target. The background showcases a cluttered gym with heavy equipment and mats, creating a vivid and realistic setting. The kangaroo's movements are fluid and powerful, conveying both agility and strength. The scene captures a split-second moment of mid-action, with the kangaroo's tail swaying behind it. A high-angle shot emphasizing the kangaroo's dynamic pose and the surrounding gym environment.
|
| 5 |
+
A dynamic action shot in the style of a high-energy sports magazine spread, featuring a golden retriever sprinting with all its might after a red sports car speeding down the road. The dog's fur glistens in the sunlight, and its eyes are filled with determination and excitement. It leaps forward, its tail wagging wildly, while the car speeds away in the background, leaving a trail of dust. The background shows a busy city street with blurred cars and pedestrians, adding to the sense of urgency. The photo has a crisp, vibrant color palette and a high-resolution quality. A medium-long shot capturing the dog's full run.
|
| 6 |
+
A dynamic action shot in the style of a professional skateboard magazine, featuring a young male longboarder accelerating downhill. He is fully focused, his expression intense and determined, carving through tight turns with precision. His longboard glides smoothly over the pavement, creating a blur of motion. He wears a black longboard shirt, blue jeans, and white sneakers, with a backpack slung over one shoulder. His hair flows behind him as he moves, and he grips the board tightly with both hands. The background shows a scenic urban street with blurred buildings and trees, hinting at a lively cityscape. The photo captures the moment just after he exits a turn, with a slight bounce in the board and a sense of speed and agility. A medium shot with a slightly elevated camera angle.
|
| 7 |
+
A dynamic hip-hop dance scene in a vibrant urban style, featuring an Asian girl in a bright yellow T-shirt and white pants. She is mid-dance move, arms stretched out and feet rhythmically stepping, exuding energy and confidence. Her hair is tied up in a ponytail, and she has a mischievous smile on her face. The background shows a bustling city street with blurred reflections of tall buildings and passing cars. The scene captures the lively and energetic atmosphere of a hip-hop performance, with a slightly grainy texture. A medium shot from a low-angle perspective.
|
| 8 |
+
A dynamic tracking shot following a skateboarder performing a series of fluid tricks down a bustling city street. The skateboarder, wearing a black helmet and a colorful shirt, moves with grace and confidence, executing flips, grinds, and spins. The camera captures the skateboarder's fluid movements, capturing the essence of each trick with precision. The background showcases the urban environment, with tall buildings, busy traffic, and passersby in the distance. The lighting highlights the skateboarder's movements, creating a sense of speed and energy. The overall style is reminiscent of a skateboarding documentary, emphasizing the natural and dynamic nature of the tricks.
|
| 9 |
+
A handheld camera captures a dog running through a park with a joyful exploration, the camera following the dog closely and bouncing and tilting with its movements. The dog bounds through the grass, tail wagging excitedly, sniffing at flowers and chasing after butterflies. Its fur glistens in the sunlight, and its eyes sparkle with enthusiasm. The park is filled with trees and colorful blooms, and the background shows a blurred path leading into the distance. The camera angle changes dynamically, providing a sense of the dog's lively energy and the vibrant environment around it.
|
| 10 |
+
A handheld shot following a young child running through a field of tall grass, capturing the spontaneity and playfulness of their movements. The child has curly brown hair and a mischievous smile, arms swinging freely as they sprint across the green expanse. Their small feet kick up bits of grass and dirt, creating a trail behind them. The background features a blurred landscape with rolling hills and scattered wildflowers, bathed in warm sunlight. The photo has a natural, documentary-style quality, emphasizing the dynamic motion and joy of the moment. A dynamic handheld shot from a slightly elevated angle, following the child's energetic run.
|
| 11 |
+
A high-speed action shot of a cheetah in its natural habitat, sprinting at full speed while chasing its prey across the savanna. The cheetah's golden fur glistens under the bright African sun, and its muscular body is stretched out in a powerful run. Its sharp eyes focus intently on the fleeing antelope, and its distinctive black tear marks streak down its face. The background is a blurred landscape with tall grass swaying in the wind, and distant acacia trees. The cheetah's tail is raised high, and its paws leave deep prints in the soft earth. A dynamic mid-shot capturing the intense moment of pursuit.
|
| 12 |
+
A photograph in a soft, warm lighting style, capturing a young woman with a bright smile and a playful wink. She has long curly brown hair and warm hazel eyes, with a slightly flushed cheeks from laughter. She is dressed in a casual yet stylish outfit: a floral printed sundress with a flowy skirt and a fitted top. Her hands are on her hips, giving a casual pose. The background features a blurred outdoor garden setting with blooming flowers and greenery. A medium shot from a slightly above-the-shoulder angle, emphasizing her joyful expression and the natural movement of her face.
|
| 13 |
+
A poignant moment captured in a realistic photographic style, showing a middle-aged man with a rugged face and slightly tousled hair, his chin quivering with emotion as he says a heartfelt goodbye to a loved one. He wears a simple grey sweater and jeans, standing on a dewy grassy field under a clear blue sky, with fluffy white clouds in the background. The camera angle is slightly from below, emphasizing his sorrowful expression and the depth of his feelings. A medium shot with a soft focus on the man's face and a blurred background.
|
| 14 |
+
A realistic photo of a llama wearing colorful pajamas dancing energetically on a stage under vibrant disco lighting. The llama has large floppy ears and a playful expression, moving its legs in a lively dance. It wears a red and yellow striped pajama top and matching pajama pants, with a fluffy tail swaying behind it. The stage is adorned with glittering disco balls and colorful lights, casting a lively and joyful atmosphere. The background features blurred audience members and a backdrop with disco-themed decorations. A dynamic shot capturing the llama mid-dance from a slightly elevated angle.
|
| 15 |
+
An adorable kangaroo, dressed in a cute green dress with polka dots, is wearing a small sun hat perched on its head. The kangaroo takes a pleasant stroll through the bustling streets of Mumbai during a vibrant and colorful festival. The background is filled with lively festival-goers in traditional Indian attire, adorned with intricate henna designs and bright jewelry. The scene is filled with colorful decorations, vendors selling various items, and people dancing and singing. The kangaroo moves gracefully, hopping along the cobblestone streets, its tail swinging behind it. The camera angle captures the kangaroo from a slight overhead perspective, highlighting its joyful expression and the festive atmosphere. A medium shot with dynamic movement.
|
| 16 |
+
An atmospheric and dramatic arc shot around a lone tree standing in a vast, foggy field at dawn. The early morning light filters through the mist, casting a soft, warm glow on the tree and the surrounding landscape. The tree's branches stretch out against the backdrop of a gradually lightening sky, with the shadows shifting and changing as the sun rises. The field is dotted with tall grasses and scattered wildflowers, their silhouettes softened by the fog. The overall scene has a moody, ethereal quality, emphasizing the natural movement of the fog and the subtle changes in light and shadow. A dynamic arc shot capturing the transition from night to day.
|
requirements.txt
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.5.1
|
| 2 |
+
torchvision==0.20.1
|
| 3 |
+
torchaudio==2.5.1
|
| 4 |
+
opencv-python>=4.9.0.80
|
| 5 |
+
diffusers==0.31.0
|
| 6 |
+
transformers>=4.49.0
|
| 7 |
+
tokenizers>=0.20.3
|
| 8 |
+
accelerate>=1.1.1
|
| 9 |
+
tqdm
|
| 10 |
+
imageio
|
| 11 |
+
easydict
|
| 12 |
+
ftfy
|
| 13 |
+
dashscope
|
| 14 |
+
imageio-ffmpeg
|
| 15 |
+
numpy==1.24.4
|
| 16 |
+
wandb
|
| 17 |
+
omegaconf
|
| 18 |
+
einops
|
| 19 |
+
av==13.1.0
|
| 20 |
+
opencv-python
|
| 21 |
+
open_clip_torch
|
| 22 |
+
starlette
|
| 23 |
+
pycocotools
|
| 24 |
+
lmdb
|
| 25 |
+
matplotlib
|
| 26 |
+
sentencepiece
|
| 27 |
+
pydantic==2.10.6
|
| 28 |
+
scikit-image
|
| 29 |
+
huggingface_hub
|
| 30 |
+
dominate
|
| 31 |
+
nvidia-pyindex
|
| 32 |
+
nvidia-tensorrt
|
| 33 |
+
pycuda
|
| 34 |
+
onnx
|
| 35 |
+
onnxruntime
|
| 36 |
+
onnxscript
|
| 37 |
+
onnxconverter_common
|
| 38 |
+
flask
|
| 39 |
+
flask-socketio
|
| 40 |
+
torchao
|
| 41 |
+
tensorboard
|
| 42 |
+
ninja
|
| 43 |
+
packaging
|
| 44 |
+
--no-build-isolation
|
| 45 |
+
flash-attn
|
train.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from omegaconf import OmegaConf
|
| 4 |
+
|
| 5 |
+
from trainer import DiffusionTrainer, GANTrainer, ODETrainer, ScoreDistillationTrainer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
parser = argparse.ArgumentParser()
|
| 10 |
+
parser.add_argument("--config_path", type=str, required=True)
|
| 11 |
+
parser.add_argument("--no_save", action="store_true")
|
| 12 |
+
parser.add_argument("--no_visualize", action="store_true")
|
| 13 |
+
parser.add_argument("--logdir", type=str, default="", help="Path to the directory to save logs")
|
| 14 |
+
parser.add_argument("--wandb-save-dir", type=str, default="", help="Path to the directory to save wandb logs")
|
| 15 |
+
parser.add_argument("--disable-wandb", default=False, action="store_true")
|
| 16 |
+
|
| 17 |
+
args = parser.parse_args()
|
| 18 |
+
|
| 19 |
+
config = OmegaConf.load(args.config_path)
|
| 20 |
+
default_config = OmegaConf.load("configs/default_config.yaml")
|
| 21 |
+
config = OmegaConf.merge(default_config, config)
|
| 22 |
+
config.no_save = args.no_save
|
| 23 |
+
config.no_visualize = args.no_visualize
|
| 24 |
+
|
| 25 |
+
# get the filename of config_path
|
| 26 |
+
config_name = os.path.basename(args.config_path).split(".")[0]
|
| 27 |
+
config.config_name = config_name
|
| 28 |
+
config.logdir = args.logdir
|
| 29 |
+
config.wandb_save_dir = args.wandb_save_dir
|
| 30 |
+
config.disable_wandb = args.disable_wandb
|
| 31 |
+
|
| 32 |
+
if config.trainer == "diffusion":
|
| 33 |
+
trainer = DiffusionTrainer(config)
|
| 34 |
+
elif config.trainer == "gan":
|
| 35 |
+
trainer = GANTrainer(config)
|
| 36 |
+
elif config.trainer == "ode":
|
| 37 |
+
trainer = ODETrainer(config)
|
| 38 |
+
elif config.trainer == "score_distillation":
|
| 39 |
+
trainer = ScoreDistillationTrainer(config)
|
| 40 |
+
trainer.train()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if __name__ == "__main__":
|
| 45 |
+
main()
|
trainer/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .diffusion import Trainer as DiffusionTrainer
|
| 2 |
+
from .gan import Trainer as GANTrainer
|
| 3 |
+
from .ode import Trainer as ODETrainer
|
| 4 |
+
from .distillation import Trainer as ScoreDistillationTrainer
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"DiffusionTrainer",
|
| 8 |
+
"GANTrainer",
|
| 9 |
+
"ODETrainer",
|
| 10 |
+
"ScoreDistillationTrainer"
|
| 11 |
+
]
|
trainer/diffusion.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
from model import CausalDiffusion
|
| 5 |
+
from utils.dataset import ShardingLMDBDataset, cycle
|
| 6 |
+
from utils.misc import set_seed
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
from omegaconf import OmegaConf
|
| 9 |
+
import torch
|
| 10 |
+
import wandb
|
| 11 |
+
import time
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
from utils.distributed import EMA_FSDP, barrier, fsdp_wrap, fsdp_state_dict, launch_distributed_job
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Trainer:
|
| 18 |
+
def __init__(self, config):
|
| 19 |
+
self.config = config
|
| 20 |
+
self.step = 0
|
| 21 |
+
|
| 22 |
+
# Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
|
| 23 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 24 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 25 |
+
|
| 26 |
+
launch_distributed_job()
|
| 27 |
+
global_rank = dist.get_rank()
|
| 28 |
+
|
| 29 |
+
self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
|
| 30 |
+
self.device = torch.cuda.current_device()
|
| 31 |
+
self.is_main_process = global_rank == 0
|
| 32 |
+
self.causal = config.causal
|
| 33 |
+
self.disable_wandb = config.disable_wandb
|
| 34 |
+
|
| 35 |
+
# use a random seed for the training
|
| 36 |
+
if config.seed == 0:
|
| 37 |
+
random_seed = torch.randint(0, 10000000, (1,), device=self.device)
|
| 38 |
+
dist.broadcast(random_seed, src=0)
|
| 39 |
+
config.seed = random_seed.item()
|
| 40 |
+
|
| 41 |
+
set_seed(config.seed + global_rank)
|
| 42 |
+
|
| 43 |
+
if self.is_main_process and not self.disable_wandb:
|
| 44 |
+
wandb.login(host=config.wandb_host, key=config.wandb_key)
|
| 45 |
+
wandb.init(
|
| 46 |
+
config=OmegaConf.to_container(config, resolve=True),
|
| 47 |
+
name=config.config_name,
|
| 48 |
+
mode="online",
|
| 49 |
+
entity=config.wandb_entity,
|
| 50 |
+
project=config.wandb_project,
|
| 51 |
+
dir=config.wandb_save_dir
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
self.output_path = config.logdir
|
| 55 |
+
|
| 56 |
+
# Step 2: Initialize the model and optimizer
|
| 57 |
+
self.model = CausalDiffusion(config, device=self.device)
|
| 58 |
+
self.model.generator = fsdp_wrap(
|
| 59 |
+
self.model.generator,
|
| 60 |
+
sharding_strategy=config.sharding_strategy,
|
| 61 |
+
mixed_precision=config.mixed_precision,
|
| 62 |
+
wrap_strategy=config.generator_fsdp_wrap_strategy
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
self.model.text_encoder = fsdp_wrap(
|
| 66 |
+
self.model.text_encoder,
|
| 67 |
+
sharding_strategy=config.sharding_strategy,
|
| 68 |
+
mixed_precision=config.mixed_precision,
|
| 69 |
+
wrap_strategy=config.text_encoder_fsdp_wrap_strategy
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
if not config.no_visualize or config.load_raw_video:
|
| 73 |
+
self.model.vae = self.model.vae.to(
|
| 74 |
+
device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
|
| 75 |
+
|
| 76 |
+
self.generator_optimizer = torch.optim.AdamW(
|
| 77 |
+
[param for param in self.model.generator.parameters()
|
| 78 |
+
if param.requires_grad],
|
| 79 |
+
lr=config.lr,
|
| 80 |
+
betas=(config.beta1, config.beta2),
|
| 81 |
+
weight_decay=config.weight_decay
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# Step 3: Initialize the dataloader
|
| 85 |
+
dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
|
| 86 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
| 87 |
+
dataset, shuffle=True, drop_last=True)
|
| 88 |
+
dataloader = torch.utils.data.DataLoader(
|
| 89 |
+
dataset,
|
| 90 |
+
batch_size=config.batch_size,
|
| 91 |
+
sampler=sampler,
|
| 92 |
+
num_workers=8)
|
| 93 |
+
|
| 94 |
+
if dist.get_rank() == 0:
|
| 95 |
+
print("DATASET SIZE %d" % len(dataset))
|
| 96 |
+
self.dataloader = cycle(dataloader)
|
| 97 |
+
|
| 98 |
+
##############################################################################################################
|
| 99 |
+
# 6. Set up EMA parameter containers
|
| 100 |
+
rename_param = (
|
| 101 |
+
lambda name: name.replace("_fsdp_wrapped_module.", "")
|
| 102 |
+
.replace("_checkpoint_wrapped_module.", "")
|
| 103 |
+
.replace("_orig_mod.", "")
|
| 104 |
+
)
|
| 105 |
+
self.name_to_trainable_params = {}
|
| 106 |
+
for n, p in self.model.generator.named_parameters():
|
| 107 |
+
if not p.requires_grad:
|
| 108 |
+
continue
|
| 109 |
+
|
| 110 |
+
renamed_n = rename_param(n)
|
| 111 |
+
self.name_to_trainable_params[renamed_n] = p
|
| 112 |
+
ema_weight = config.ema_weight
|
| 113 |
+
self.generator_ema = None
|
| 114 |
+
if (ema_weight is not None) and (ema_weight > 0.0):
|
| 115 |
+
print(f"Setting up EMA with weight {ema_weight}")
|
| 116 |
+
self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
|
| 117 |
+
|
| 118 |
+
##############################################################################################################
|
| 119 |
+
# 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
|
| 120 |
+
if getattr(config, "generator_ckpt", False):
|
| 121 |
+
print(f"Loading pretrained generator from {config.generator_ckpt}")
|
| 122 |
+
state_dict = torch.load(config.generator_ckpt, map_location="cpu")
|
| 123 |
+
if "generator" in state_dict:
|
| 124 |
+
state_dict = state_dict["generator"]
|
| 125 |
+
elif "model" in state_dict:
|
| 126 |
+
state_dict = state_dict["model"]
|
| 127 |
+
self.model.generator.load_state_dict(
|
| 128 |
+
state_dict, strict=True
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
##############################################################################################################
|
| 132 |
+
|
| 133 |
+
# Let's delete EMA params for early steps to save some computes at training and inference
|
| 134 |
+
if self.step < config.ema_start_step:
|
| 135 |
+
self.generator_ema = None
|
| 136 |
+
|
| 137 |
+
self.max_grad_norm = 10.0
|
| 138 |
+
self.previous_time = None
|
| 139 |
+
|
| 140 |
+
def save(self):
|
| 141 |
+
print("Start gathering distributed model states...")
|
| 142 |
+
generator_state_dict = fsdp_state_dict(
|
| 143 |
+
self.model.generator)
|
| 144 |
+
|
| 145 |
+
if self.config.ema_start_step < self.step:
|
| 146 |
+
state_dict = {
|
| 147 |
+
"generator": generator_state_dict,
|
| 148 |
+
"generator_ema": self.generator_ema.state_dict(),
|
| 149 |
+
}
|
| 150 |
+
else:
|
| 151 |
+
state_dict = {
|
| 152 |
+
"generator": generator_state_dict,
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
if self.is_main_process:
|
| 156 |
+
os.makedirs(os.path.join(self.output_path,
|
| 157 |
+
f"checkpoint_model_{self.step:06d}"), exist_ok=True)
|
| 158 |
+
torch.save(state_dict, os.path.join(self.output_path,
|
| 159 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
| 160 |
+
print("Model saved to", os.path.join(self.output_path,
|
| 161 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
| 162 |
+
|
| 163 |
+
def train_one_step(self, batch):
|
| 164 |
+
self.log_iters = 1
|
| 165 |
+
|
| 166 |
+
if self.step % 20 == 0:
|
| 167 |
+
torch.cuda.empty_cache()
|
| 168 |
+
|
| 169 |
+
# Step 1: Get the next batch of text prompts
|
| 170 |
+
text_prompts = batch["prompts"]
|
| 171 |
+
if not self.config.load_raw_video: # precomputed latent
|
| 172 |
+
clean_latent = batch["ode_latent"][:, -1].to(
|
| 173 |
+
device=self.device, dtype=self.dtype)
|
| 174 |
+
else: # encode raw video to latent
|
| 175 |
+
frames = batch["frames"].to(
|
| 176 |
+
device=self.device, dtype=self.dtype)
|
| 177 |
+
with torch.no_grad():
|
| 178 |
+
clean_latent = self.model.vae.encode_to_latent(
|
| 179 |
+
frames).to(device=self.device, dtype=self.dtype)
|
| 180 |
+
image_latent = clean_latent[:, 0:1, ]
|
| 181 |
+
|
| 182 |
+
batch_size = len(text_prompts)
|
| 183 |
+
image_or_video_shape = list(self.config.image_or_video_shape)
|
| 184 |
+
image_or_video_shape[0] = batch_size
|
| 185 |
+
|
| 186 |
+
# Step 2: Extract the conditional infos
|
| 187 |
+
with torch.no_grad():
|
| 188 |
+
conditional_dict = self.model.text_encoder(
|
| 189 |
+
text_prompts=text_prompts)
|
| 190 |
+
|
| 191 |
+
if not getattr(self, "unconditional_dict", None):
|
| 192 |
+
unconditional_dict = self.model.text_encoder(
|
| 193 |
+
text_prompts=[self.config.negative_prompt] * batch_size)
|
| 194 |
+
unconditional_dict = {k: v.detach()
|
| 195 |
+
for k, v in unconditional_dict.items()}
|
| 196 |
+
self.unconditional_dict = unconditional_dict # cache the unconditional_dict
|
| 197 |
+
else:
|
| 198 |
+
unconditional_dict = self.unconditional_dict
|
| 199 |
+
|
| 200 |
+
# Step 3: Train the generator
|
| 201 |
+
generator_loss, log_dict = self.model.generator_loss(
|
| 202 |
+
image_or_video_shape=image_or_video_shape,
|
| 203 |
+
conditional_dict=conditional_dict,
|
| 204 |
+
unconditional_dict=unconditional_dict,
|
| 205 |
+
clean_latent=clean_latent,
|
| 206 |
+
initial_latent=image_latent
|
| 207 |
+
)
|
| 208 |
+
self.generator_optimizer.zero_grad()
|
| 209 |
+
generator_loss.backward()
|
| 210 |
+
generator_grad_norm = self.model.generator.clip_grad_norm_(
|
| 211 |
+
self.max_grad_norm)
|
| 212 |
+
self.generator_optimizer.step()
|
| 213 |
+
|
| 214 |
+
# Increment the step since we finished gradient update
|
| 215 |
+
self.step += 1
|
| 216 |
+
|
| 217 |
+
wandb_loss_dict = {
|
| 218 |
+
"generator_loss": generator_loss.item(),
|
| 219 |
+
"generator_grad_norm": generator_grad_norm.item(),
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
# Step 4: Logging
|
| 223 |
+
if self.is_main_process:
|
| 224 |
+
if not self.disable_wandb:
|
| 225 |
+
wandb.log(wandb_loss_dict, step=self.step)
|
| 226 |
+
|
| 227 |
+
if self.step % self.config.gc_interval == 0:
|
| 228 |
+
if dist.get_rank() == 0:
|
| 229 |
+
logging.info("DistGarbageCollector: Running GC.")
|
| 230 |
+
gc.collect()
|
| 231 |
+
|
| 232 |
+
# Step 5. Create EMA params
|
| 233 |
+
# TODO: Implement EMA
|
| 234 |
+
|
| 235 |
+
def generate_video(self, pipeline, prompts, image=None):
|
| 236 |
+
batch_size = len(prompts)
|
| 237 |
+
sampled_noise = torch.randn(
|
| 238 |
+
[batch_size, 21, 16, 60, 104], device="cuda", dtype=self.dtype
|
| 239 |
+
)
|
| 240 |
+
video, _ = pipeline.inference(
|
| 241 |
+
noise=sampled_noise,
|
| 242 |
+
text_prompts=prompts,
|
| 243 |
+
return_latents=True
|
| 244 |
+
)
|
| 245 |
+
current_video = video.permute(0, 1, 3, 4, 2).cpu().numpy() * 255.0
|
| 246 |
+
return current_video
|
| 247 |
+
|
| 248 |
+
def train(self):
|
| 249 |
+
while True:
|
| 250 |
+
batch = next(self.dataloader)
|
| 251 |
+
self.train_one_step(batch)
|
| 252 |
+
if (not self.config.no_save) and self.step % self.config.log_iters == 0:
|
| 253 |
+
torch.cuda.empty_cache()
|
| 254 |
+
self.save()
|
| 255 |
+
torch.cuda.empty_cache()
|
| 256 |
+
|
| 257 |
+
barrier()
|
| 258 |
+
if self.is_main_process:
|
| 259 |
+
current_time = time.time()
|
| 260 |
+
if self.previous_time is None:
|
| 261 |
+
self.previous_time = current_time
|
| 262 |
+
else:
|
| 263 |
+
if not self.disable_wandb:
|
| 264 |
+
wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
|
| 265 |
+
self.previous_time = current_time
|
trainer/distillation.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
from utils.dataset import ShardingLMDBDataset, cycle
|
| 5 |
+
from utils.dataset import TextDataset
|
| 6 |
+
from utils.distributed import EMA_FSDP, fsdp_wrap, fsdp_state_dict, launch_distributed_job
|
| 7 |
+
from utils.misc import (
|
| 8 |
+
set_seed,
|
| 9 |
+
merge_dict_list
|
| 10 |
+
)
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
from omegaconf import OmegaConf
|
| 13 |
+
from model import CausVid, DMD, SiD
|
| 14 |
+
import torch
|
| 15 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 16 |
+
import time
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Trainer:
|
| 21 |
+
def __init__(self, config):
|
| 22 |
+
self.config = config
|
| 23 |
+
self.step = 0
|
| 24 |
+
|
| 25 |
+
# Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
|
| 26 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 27 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 28 |
+
|
| 29 |
+
launch_distributed_job()
|
| 30 |
+
global_rank = dist.get_rank()
|
| 31 |
+
self.world_size = dist.get_world_size()
|
| 32 |
+
|
| 33 |
+
self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
|
| 34 |
+
self.device = torch.cuda.current_device()
|
| 35 |
+
self.is_main_process = global_rank == 0
|
| 36 |
+
self.causal = config.causal
|
| 37 |
+
|
| 38 |
+
# use a random seed for the training
|
| 39 |
+
if config.seed == 0:
|
| 40 |
+
random_seed = torch.randint(0, 10000000, (1,), device=self.device)
|
| 41 |
+
dist.broadcast(random_seed, src=0)
|
| 42 |
+
config.seed = random_seed.item()
|
| 43 |
+
|
| 44 |
+
set_seed(config.seed + global_rank)
|
| 45 |
+
|
| 46 |
+
if self.is_main_process:
|
| 47 |
+
self.writer = SummaryWriter(
|
| 48 |
+
log_dir=os.path.join(config.logdir, "tensorboard"),
|
| 49 |
+
flush_secs=10
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
self.output_path = config.logdir
|
| 53 |
+
|
| 54 |
+
# Step 2: Initialize the model and optimizer
|
| 55 |
+
if config.distribution_loss == "causvid":
|
| 56 |
+
self.model = CausVid(config, device=self.device)
|
| 57 |
+
elif config.distribution_loss == "dmd":
|
| 58 |
+
self.model = DMD(config, device=self.device)
|
| 59 |
+
elif config.distribution_loss == "sid":
|
| 60 |
+
self.model = SiD(config, device=self.device)
|
| 61 |
+
else:
|
| 62 |
+
raise ValueError("Invalid distribution matching loss")
|
| 63 |
+
|
| 64 |
+
# Save pretrained model state_dicts to CPU
|
| 65 |
+
self.fake_score_state_dict_cpu = self.model.fake_score.state_dict()
|
| 66 |
+
|
| 67 |
+
self.model.generator = fsdp_wrap(
|
| 68 |
+
self.model.generator,
|
| 69 |
+
sharding_strategy=config.sharding_strategy,
|
| 70 |
+
mixed_precision=config.mixed_precision,
|
| 71 |
+
wrap_strategy=config.generator_fsdp_wrap_strategy
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
self.model.real_score = fsdp_wrap(
|
| 75 |
+
self.model.real_score,
|
| 76 |
+
sharding_strategy=config.sharding_strategy,
|
| 77 |
+
mixed_precision=config.mixed_precision,
|
| 78 |
+
wrap_strategy=config.real_score_fsdp_wrap_strategy
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
self.model.fake_score = fsdp_wrap(
|
| 82 |
+
self.model.fake_score,
|
| 83 |
+
sharding_strategy=config.sharding_strategy,
|
| 84 |
+
mixed_precision=config.mixed_precision,
|
| 85 |
+
wrap_strategy=config.fake_score_fsdp_wrap_strategy
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
self.model.text_encoder = fsdp_wrap(
|
| 89 |
+
self.model.text_encoder,
|
| 90 |
+
sharding_strategy=config.sharding_strategy,
|
| 91 |
+
mixed_precision=config.mixed_precision,
|
| 92 |
+
wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
|
| 93 |
+
cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
if not config.no_visualize or config.load_raw_video:
|
| 97 |
+
self.model.vae = self.model.vae.to(
|
| 98 |
+
device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
|
| 99 |
+
|
| 100 |
+
self.generator_optimizer = torch.optim.AdamW(
|
| 101 |
+
[param for param in self.model.generator.parameters()
|
| 102 |
+
if param.requires_grad],
|
| 103 |
+
lr=config.lr,
|
| 104 |
+
betas=(config.beta1, config.beta2),
|
| 105 |
+
weight_decay=config.weight_decay
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
self.critic_optimizer = torch.optim.AdamW(
|
| 109 |
+
[param for param in self.model.fake_score.parameters()
|
| 110 |
+
if param.requires_grad],
|
| 111 |
+
lr=config.lr_critic if hasattr(config, "lr_critic") else config.lr,
|
| 112 |
+
betas=(config.beta1_critic, config.beta2_critic),
|
| 113 |
+
weight_decay=config.weight_decay
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# Step 3: Initialize the dataloader
|
| 117 |
+
if self.config.i2v:
|
| 118 |
+
dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
|
| 119 |
+
else:
|
| 120 |
+
dataset = TextDataset(config.data_path)
|
| 121 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
| 122 |
+
dataset, shuffle=True, drop_last=True)
|
| 123 |
+
dataloader = torch.utils.data.DataLoader(
|
| 124 |
+
dataset,
|
| 125 |
+
batch_size=config.batch_size,
|
| 126 |
+
sampler=sampler,
|
| 127 |
+
num_workers=8)
|
| 128 |
+
|
| 129 |
+
if dist.get_rank() == 0:
|
| 130 |
+
print("DATASET SIZE %d" % len(dataset))
|
| 131 |
+
self.dataloader = cycle(dataloader)
|
| 132 |
+
|
| 133 |
+
##############################################################################################################
|
| 134 |
+
# 6. Set up EMA parameter containers
|
| 135 |
+
rename_param = (
|
| 136 |
+
lambda name: name.replace("_fsdp_wrapped_module.", "")
|
| 137 |
+
.replace("_checkpoint_wrapped_module.", "")
|
| 138 |
+
.replace("_orig_mod.", "")
|
| 139 |
+
)
|
| 140 |
+
self.name_to_trainable_params = {}
|
| 141 |
+
for n, p in self.model.generator.named_parameters():
|
| 142 |
+
if not p.requires_grad:
|
| 143 |
+
continue
|
| 144 |
+
|
| 145 |
+
renamed_n = rename_param(n)
|
| 146 |
+
self.name_to_trainable_params[renamed_n] = p
|
| 147 |
+
ema_weight = config.ema_weight
|
| 148 |
+
self.generator_ema = None
|
| 149 |
+
if (ema_weight is not None) and (ema_weight > 0.0):
|
| 150 |
+
print(f"Setting up EMA with weight {ema_weight}")
|
| 151 |
+
self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
|
| 152 |
+
|
| 153 |
+
##############################################################################################################
|
| 154 |
+
# 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
|
| 155 |
+
if getattr(config, "generator_ckpt", False):
|
| 156 |
+
print(f"Loading pretrained generator from {config.generator_ckpt}")
|
| 157 |
+
state_dict = torch.load(config.generator_ckpt, map_location="cpu")
|
| 158 |
+
if "generator" in state_dict:
|
| 159 |
+
state_dict = state_dict["generator"]
|
| 160 |
+
elif "model" in state_dict:
|
| 161 |
+
state_dict = state_dict["model"]
|
| 162 |
+
self.model.generator.load_state_dict(
|
| 163 |
+
state_dict, strict=True
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
##############################################################################################################
|
| 167 |
+
|
| 168 |
+
# Let's delete EMA params for early steps to save some computes at training and inference
|
| 169 |
+
if self.step < config.ema_start_step:
|
| 170 |
+
self.generator_ema = None
|
| 171 |
+
|
| 172 |
+
self.max_grad_norm_generator = getattr(config, "max_grad_norm_generator", 10.0)
|
| 173 |
+
self.max_grad_norm_critic = getattr(config, "max_grad_norm_critic", 10.0)
|
| 174 |
+
self.previous_time = None
|
| 175 |
+
|
| 176 |
+
def save(self):
|
| 177 |
+
print("Start gathering distributed model states...")
|
| 178 |
+
generator_state_dict = fsdp_state_dict(
|
| 179 |
+
self.model.generator)
|
| 180 |
+
critic_state_dict = fsdp_state_dict(
|
| 181 |
+
self.model.fake_score)
|
| 182 |
+
|
| 183 |
+
if self.config.ema_start_step < self.step:
|
| 184 |
+
state_dict = {
|
| 185 |
+
"generator": generator_state_dict,
|
| 186 |
+
"critic": critic_state_dict,
|
| 187 |
+
"generator_ema": self.generator_ema.state_dict(),
|
| 188 |
+
}
|
| 189 |
+
else:
|
| 190 |
+
state_dict = {
|
| 191 |
+
"generator": generator_state_dict,
|
| 192 |
+
"critic": critic_state_dict,
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
if self.is_main_process:
|
| 196 |
+
os.makedirs(os.path.join(self.output_path,
|
| 197 |
+
f"checkpoint_model_{self.step:06d}"), exist_ok=True)
|
| 198 |
+
torch.save(state_dict, os.path.join(self.output_path,
|
| 199 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
| 200 |
+
print("Model saved to", os.path.join(self.output_path,
|
| 201 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
| 202 |
+
|
| 203 |
+
def fwdbwd_one_step(self, batch, train_generator):
|
| 204 |
+
self.model.eval() # prevent any randomness (e.g. dropout)
|
| 205 |
+
|
| 206 |
+
if self.step % 20 == 0:
|
| 207 |
+
torch.cuda.empty_cache()
|
| 208 |
+
|
| 209 |
+
# Step 1: Get the next batch of text prompts
|
| 210 |
+
text_prompts = batch["prompts"]
|
| 211 |
+
if self.config.i2v:
|
| 212 |
+
clean_latent = None
|
| 213 |
+
image_latent = batch["ode_latent"][:, -1][:, 0:1, ].to(
|
| 214 |
+
device=self.device, dtype=self.dtype)
|
| 215 |
+
else:
|
| 216 |
+
clean_latent = None
|
| 217 |
+
image_latent = None
|
| 218 |
+
|
| 219 |
+
batch_size = len(text_prompts)
|
| 220 |
+
image_or_video_shape = list(self.config.image_or_video_shape)
|
| 221 |
+
image_or_video_shape[0] = batch_size
|
| 222 |
+
|
| 223 |
+
# Step 2: Extract the conditional infos
|
| 224 |
+
with torch.no_grad():
|
| 225 |
+
conditional_dict = self.model.text_encoder(
|
| 226 |
+
text_prompts=text_prompts)
|
| 227 |
+
|
| 228 |
+
if not getattr(self, "unconditional_dict", None):
|
| 229 |
+
unconditional_dict = self.model.text_encoder(
|
| 230 |
+
text_prompts=[self.config.negative_prompt] * batch_size)
|
| 231 |
+
unconditional_dict = {k: v.detach()
|
| 232 |
+
for k, v in unconditional_dict.items()}
|
| 233 |
+
self.unconditional_dict = unconditional_dict # cache the unconditional_dict
|
| 234 |
+
else:
|
| 235 |
+
unconditional_dict = self.unconditional_dict
|
| 236 |
+
|
| 237 |
+
# Step 3: Store gradients for the generator (if training the generator)
|
| 238 |
+
if train_generator:
|
| 239 |
+
generator_loss, generator_log_dict = self.model.generator_loss(
|
| 240 |
+
image_or_video_shape=image_or_video_shape,
|
| 241 |
+
conditional_dict=conditional_dict,
|
| 242 |
+
unconditional_dict=unconditional_dict,
|
| 243 |
+
clean_latent=clean_latent,
|
| 244 |
+
initial_latent=image_latent if self.config.i2v else None
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
generator_loss.backward()
|
| 248 |
+
generator_grad_norm = self.model.generator.clip_grad_norm_(
|
| 249 |
+
self.max_grad_norm_generator)
|
| 250 |
+
|
| 251 |
+
generator_log_dict.update({"generator_loss": generator_loss,
|
| 252 |
+
"generator_grad_norm": generator_grad_norm})
|
| 253 |
+
|
| 254 |
+
return generator_log_dict
|
| 255 |
+
else:
|
| 256 |
+
generator_log_dict = {}
|
| 257 |
+
|
| 258 |
+
# Step 4: Store gradients for the critic (if training the critic)
|
| 259 |
+
critic_loss, critic_log_dict = self.model.critic_loss(
|
| 260 |
+
image_or_video_shape=image_or_video_shape,
|
| 261 |
+
conditional_dict=conditional_dict,
|
| 262 |
+
unconditional_dict=unconditional_dict,
|
| 263 |
+
clean_latent=clean_latent,
|
| 264 |
+
initial_latent=image_latent if self.config.i2v else None
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
critic_loss.backward()
|
| 268 |
+
critic_grad_norm = self.model.fake_score.clip_grad_norm_(
|
| 269 |
+
self.max_grad_norm_critic)
|
| 270 |
+
|
| 271 |
+
critic_log_dict.update({"critic_loss": critic_loss,
|
| 272 |
+
"critic_grad_norm": critic_grad_norm})
|
| 273 |
+
|
| 274 |
+
return critic_log_dict
|
| 275 |
+
|
| 276 |
+
def generate_video(self, pipeline, prompts, image=None):
|
| 277 |
+
batch_size = len(prompts)
|
| 278 |
+
if image is not None:
|
| 279 |
+
image = image.squeeze(0).unsqueeze(0).unsqueeze(2).to(device="cuda", dtype=torch.bfloat16)
|
| 280 |
+
|
| 281 |
+
# Encode the input image as the first latent
|
| 282 |
+
initial_latent = pipeline.vae.encode_to_latent(image).to(device="cuda", dtype=torch.bfloat16)
|
| 283 |
+
initial_latent = initial_latent.repeat(batch_size, 1, 1, 1, 1)
|
| 284 |
+
sampled_noise = torch.randn(
|
| 285 |
+
[batch_size, self.model.num_training_frames - 1, 16, 60, 104],
|
| 286 |
+
device="cuda",
|
| 287 |
+
dtype=self.dtype
|
| 288 |
+
)
|
| 289 |
+
else:
|
| 290 |
+
initial_latent = None
|
| 291 |
+
sampled_noise = torch.randn(
|
| 292 |
+
[batch_size, self.model.num_training_frames, 16, 60, 104],
|
| 293 |
+
device="cuda",
|
| 294 |
+
dtype=self.dtype
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
video, _ = pipeline.inference(
|
| 298 |
+
noise=sampled_noise,
|
| 299 |
+
text_prompts=prompts,
|
| 300 |
+
return_latents=True,
|
| 301 |
+
initial_latent=initial_latent
|
| 302 |
+
)
|
| 303 |
+
current_video = video.permute(0, 1, 3, 4, 2).cpu().numpy() * 255.0
|
| 304 |
+
return current_video
|
| 305 |
+
|
| 306 |
+
def train(self):
|
| 307 |
+
start_step = self.step
|
| 308 |
+
|
| 309 |
+
while True:
|
| 310 |
+
TRAIN_GENERATOR = self.step % self.config.dfake_gen_update_ratio == 0
|
| 311 |
+
|
| 312 |
+
# Train the generator
|
| 313 |
+
if TRAIN_GENERATOR:
|
| 314 |
+
self.generator_optimizer.zero_grad(set_to_none=True)
|
| 315 |
+
extras_list = []
|
| 316 |
+
batch = next(self.dataloader)
|
| 317 |
+
extra = self.fwdbwd_one_step(batch, True)
|
| 318 |
+
extras_list.append(extra)
|
| 319 |
+
generator_log_dict = merge_dict_list(extras_list)
|
| 320 |
+
self.generator_optimizer.step()
|
| 321 |
+
if self.generator_ema is not None:
|
| 322 |
+
self.generator_ema.update(self.model.generator)
|
| 323 |
+
|
| 324 |
+
# Train the critic
|
| 325 |
+
self.critic_optimizer.zero_grad(set_to_none=True)
|
| 326 |
+
extras_list = []
|
| 327 |
+
batch = next(self.dataloader)
|
| 328 |
+
extra = self.fwdbwd_one_step(batch, False)
|
| 329 |
+
extras_list.append(extra)
|
| 330 |
+
critic_log_dict = merge_dict_list(extras_list)
|
| 331 |
+
self.critic_optimizer.step()
|
| 332 |
+
|
| 333 |
+
# Increment the step since we finished gradient update
|
| 334 |
+
self.step += 1
|
| 335 |
+
|
| 336 |
+
# Create EMA params (if not already created)
|
| 337 |
+
if (self.step >= self.config.ema_start_step) and \
|
| 338 |
+
(self.generator_ema is None) and (self.config.ema_weight > 0):
|
| 339 |
+
self.generator_ema = EMA_FSDP(self.model.generator, decay=self.config.ema_weight)
|
| 340 |
+
|
| 341 |
+
# Save the model
|
| 342 |
+
if (not self.config.no_save) and (self.step - start_step) > 0 and self.step % self.config.log_iters == 0:
|
| 343 |
+
torch.cuda.empty_cache()
|
| 344 |
+
self.save()
|
| 345 |
+
torch.cuda.empty_cache()
|
| 346 |
+
|
| 347 |
+
# Logging
|
| 348 |
+
if self.is_main_process:
|
| 349 |
+
|
| 350 |
+
if TRAIN_GENERATOR:
|
| 351 |
+
self.writer.add_scalar(
|
| 352 |
+
"generator_loss",
|
| 353 |
+
generator_log_dict["generator_loss"].mean().item(),
|
| 354 |
+
self.step
|
| 355 |
+
)
|
| 356 |
+
self.writer.add_scalar(
|
| 357 |
+
"generator_grad_norm",
|
| 358 |
+
generator_log_dict["generator_grad_norm"].mean().item(),
|
| 359 |
+
self.step
|
| 360 |
+
)
|
| 361 |
+
self.writer.add_scalar(
|
| 362 |
+
"dmdtrain_gradient_norm",
|
| 363 |
+
generator_log_dict["dmdtrain_gradient_norm"].mean().item(),
|
| 364 |
+
self.step
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
self.writer.add_scalar(
|
| 368 |
+
"critic_loss",
|
| 369 |
+
critic_log_dict["critic_loss"].mean().item(),
|
| 370 |
+
self.step
|
| 371 |
+
)
|
| 372 |
+
self.writer.add_scalar(
|
| 373 |
+
"critic_grad_norm",
|
| 374 |
+
critic_log_dict["critic_grad_norm"].mean().item(),
|
| 375 |
+
self.step
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
if self.step % self.config.gc_interval == 0:
|
| 379 |
+
if dist.get_rank() == 0:
|
| 380 |
+
logging.info("DistGarbageCollector: Running GC.")
|
| 381 |
+
gc.collect()
|
| 382 |
+
torch.cuda.empty_cache()
|
| 383 |
+
|
| 384 |
+
if self.is_main_process:
|
| 385 |
+
current_time = time.time()
|
| 386 |
+
if self.previous_time is None:
|
| 387 |
+
self.previous_time = current_time
|
| 388 |
+
else:
|
| 389 |
+
self.writer.add_scalar(
|
| 390 |
+
"per iteration time",
|
| 391 |
+
current_time - self.previous_time,
|
| 392 |
+
self.step
|
| 393 |
+
)
|
| 394 |
+
print(
|
| 395 |
+
f"Step {self.step} | "
|
| 396 |
+
f"Iteration time: {current_time - self.previous_time:.2f} seconds | "
|
| 397 |
+
)
|
| 398 |
+
self.previous_time = current_time
|
trainer/gan.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
from utils.dataset import ShardingLMDBDataset, cycle
|
| 5 |
+
from utils.distributed import EMA_FSDP, fsdp_wrap, fsdp_state_dict, launch_distributed_job
|
| 6 |
+
from utils.misc import (
|
| 7 |
+
set_seed,
|
| 8 |
+
merge_dict_list
|
| 9 |
+
)
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
from omegaconf import OmegaConf
|
| 12 |
+
from model import GAN
|
| 13 |
+
import torch
|
| 14 |
+
import wandb
|
| 15 |
+
import time
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Trainer:
|
| 20 |
+
def __init__(self, config):
|
| 21 |
+
self.config = config
|
| 22 |
+
self.step = 0
|
| 23 |
+
|
| 24 |
+
# Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
|
| 25 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 26 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 27 |
+
|
| 28 |
+
launch_distributed_job()
|
| 29 |
+
global_rank = dist.get_rank()
|
| 30 |
+
self.world_size = dist.get_world_size()
|
| 31 |
+
|
| 32 |
+
self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
|
| 33 |
+
self.device = torch.cuda.current_device()
|
| 34 |
+
self.is_main_process = global_rank == 0
|
| 35 |
+
self.causal = config.causal
|
| 36 |
+
self.disable_wandb = config.disable_wandb
|
| 37 |
+
|
| 38 |
+
# Configuration for discriminator warmup
|
| 39 |
+
self.discriminator_warmup_steps = getattr(config, "discriminator_warmup_steps", 0)
|
| 40 |
+
self.in_discriminator_warmup = self.step < self.discriminator_warmup_steps
|
| 41 |
+
if self.in_discriminator_warmup and self.is_main_process:
|
| 42 |
+
print(f"Starting with discriminator warmup for {self.discriminator_warmup_steps} steps")
|
| 43 |
+
self.loss_scale = getattr(config, "loss_scale", 1.0)
|
| 44 |
+
|
| 45 |
+
# use a random seed for the training
|
| 46 |
+
if config.seed == 0:
|
| 47 |
+
random_seed = torch.randint(0, 10000000, (1,), device=self.device)
|
| 48 |
+
dist.broadcast(random_seed, src=0)
|
| 49 |
+
config.seed = random_seed.item()
|
| 50 |
+
|
| 51 |
+
set_seed(config.seed + global_rank)
|
| 52 |
+
|
| 53 |
+
if self.is_main_process and not self.disable_wandb:
|
| 54 |
+
wandb.login(host=config.wandb_host, key=config.wandb_key)
|
| 55 |
+
wandb.init(
|
| 56 |
+
config=OmegaConf.to_container(config, resolve=True),
|
| 57 |
+
name=config.config_name,
|
| 58 |
+
mode="online",
|
| 59 |
+
entity=config.wandb_entity,
|
| 60 |
+
project=config.wandb_project,
|
| 61 |
+
dir=config.wandb_save_dir
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
self.output_path = config.logdir
|
| 65 |
+
|
| 66 |
+
# Step 2: Initialize the model and optimizer
|
| 67 |
+
self.model = GAN(config, device=self.device)
|
| 68 |
+
|
| 69 |
+
self.model.generator = fsdp_wrap(
|
| 70 |
+
self.model.generator,
|
| 71 |
+
sharding_strategy=config.sharding_strategy,
|
| 72 |
+
mixed_precision=config.mixed_precision,
|
| 73 |
+
wrap_strategy=config.generator_fsdp_wrap_strategy
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
self.model.fake_score = fsdp_wrap(
|
| 77 |
+
self.model.fake_score,
|
| 78 |
+
sharding_strategy=config.sharding_strategy,
|
| 79 |
+
mixed_precision=config.mixed_precision,
|
| 80 |
+
wrap_strategy=config.fake_score_fsdp_wrap_strategy
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
self.model.text_encoder = fsdp_wrap(
|
| 84 |
+
self.model.text_encoder,
|
| 85 |
+
sharding_strategy=config.sharding_strategy,
|
| 86 |
+
mixed_precision=config.mixed_precision,
|
| 87 |
+
wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
|
| 88 |
+
cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
if not config.no_visualize or config.load_raw_video:
|
| 92 |
+
self.model.vae = self.model.vae.to(
|
| 93 |
+
device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
|
| 94 |
+
|
| 95 |
+
self.generator_optimizer = torch.optim.AdamW(
|
| 96 |
+
[param for param in self.model.generator.parameters()
|
| 97 |
+
if param.requires_grad],
|
| 98 |
+
lr=config.gen_lr,
|
| 99 |
+
betas=(config.beta1, config.beta2)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Create separate parameter groups for the fake_score network
|
| 103 |
+
# One group for parameters with "_cls_pred_branch" or "_gan_ca_blocks" in the name
|
| 104 |
+
# and another group for all other parameters
|
| 105 |
+
fake_score_params = []
|
| 106 |
+
discriminator_params = []
|
| 107 |
+
|
| 108 |
+
for name, param in self.model.fake_score.named_parameters():
|
| 109 |
+
if param.requires_grad:
|
| 110 |
+
if "_cls_pred_branch" in name or "_gan_ca_blocks" in name:
|
| 111 |
+
discriminator_params.append(param)
|
| 112 |
+
else:
|
| 113 |
+
fake_score_params.append(param)
|
| 114 |
+
|
| 115 |
+
# Use the special learning rate for the special parameter group
|
| 116 |
+
# and the default critic learning rate for other parameters
|
| 117 |
+
self.critic_param_groups = [
|
| 118 |
+
{'params': fake_score_params, 'lr': config.critic_lr},
|
| 119 |
+
{'params': discriminator_params, 'lr': config.critic_lr * config.discriminator_lr_multiplier}
|
| 120 |
+
]
|
| 121 |
+
if self.in_discriminator_warmup:
|
| 122 |
+
self.critic_optimizer = torch.optim.AdamW(
|
| 123 |
+
self.critic_param_groups,
|
| 124 |
+
betas=(0.9, config.beta2_critic)
|
| 125 |
+
)
|
| 126 |
+
else:
|
| 127 |
+
self.critic_optimizer = torch.optim.AdamW(
|
| 128 |
+
self.critic_param_groups,
|
| 129 |
+
betas=(config.beta1_critic, config.beta2_critic)
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Step 3: Initialize the dataloader
|
| 133 |
+
self.data_path = config.data_path
|
| 134 |
+
dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
|
| 135 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
| 136 |
+
dataset, shuffle=True, drop_last=True)
|
| 137 |
+
dataloader = torch.utils.data.DataLoader(
|
| 138 |
+
dataset,
|
| 139 |
+
batch_size=config.batch_size,
|
| 140 |
+
sampler=sampler,
|
| 141 |
+
num_workers=8)
|
| 142 |
+
|
| 143 |
+
if dist.get_rank() == 0:
|
| 144 |
+
print("DATASET SIZE %d" % len(dataset))
|
| 145 |
+
|
| 146 |
+
self.dataloader = cycle(dataloader)
|
| 147 |
+
|
| 148 |
+
##############################################################################################################
|
| 149 |
+
# 6. Set up EMA parameter containers
|
| 150 |
+
rename_param = (
|
| 151 |
+
lambda name: name.replace("_fsdp_wrapped_module.", "")
|
| 152 |
+
.replace("_checkpoint_wrapped_module.", "")
|
| 153 |
+
.replace("_orig_mod.", "")
|
| 154 |
+
)
|
| 155 |
+
self.name_to_trainable_params = {}
|
| 156 |
+
for n, p in self.model.generator.named_parameters():
|
| 157 |
+
if not p.requires_grad:
|
| 158 |
+
continue
|
| 159 |
+
|
| 160 |
+
renamed_n = rename_param(n)
|
| 161 |
+
self.name_to_trainable_params[renamed_n] = p
|
| 162 |
+
ema_weight = config.ema_weight
|
| 163 |
+
self.generator_ema = None
|
| 164 |
+
if (ema_weight is not None) and (ema_weight > 0.0):
|
| 165 |
+
print(f"Setting up EMA with weight {ema_weight}")
|
| 166 |
+
self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
|
| 167 |
+
|
| 168 |
+
##############################################################################################################
|
| 169 |
+
# 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
|
| 170 |
+
if getattr(config, "generator_ckpt", False):
|
| 171 |
+
print(f"Loading pretrained generator from {config.generator_ckpt}")
|
| 172 |
+
state_dict = torch.load(config.generator_ckpt, map_location="cpu")
|
| 173 |
+
if "generator" in state_dict:
|
| 174 |
+
state_dict = state_dict["generator"]
|
| 175 |
+
elif "model" in state_dict:
|
| 176 |
+
state_dict = state_dict["model"]
|
| 177 |
+
self.model.generator.load_state_dict(
|
| 178 |
+
state_dict, strict=True
|
| 179 |
+
)
|
| 180 |
+
if hasattr(config, "load"):
|
| 181 |
+
resume_ckpt_path_critic = os.path.join(config.load, "critic")
|
| 182 |
+
resume_ckpt_path_generator = os.path.join(config.load, "generator")
|
| 183 |
+
else:
|
| 184 |
+
resume_ckpt_path_critic = "none"
|
| 185 |
+
resume_ckpt_path_generator = "none"
|
| 186 |
+
|
| 187 |
+
_, _ = self.checkpointer_critic.try_best_load(
|
| 188 |
+
resume_ckpt_path=resume_ckpt_path_critic,
|
| 189 |
+
)
|
| 190 |
+
self.step, _ = self.checkpointer_generator.try_best_load(
|
| 191 |
+
resume_ckpt_path=resume_ckpt_path_generator,
|
| 192 |
+
force_start_w_ema=config.force_start_w_ema,
|
| 193 |
+
force_reset_zero_step=config.force_reset_zero_step,
|
| 194 |
+
force_reinit_ema=config.force_reinit_ema,
|
| 195 |
+
skip_optimizer_scheduler=config.skip_optimizer_scheduler,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
##############################################################################################################
|
| 199 |
+
|
| 200 |
+
# Let's delete EMA params for early steps to save some computes at training and inference
|
| 201 |
+
if self.step < config.ema_start_step:
|
| 202 |
+
self.generator_ema = None
|
| 203 |
+
|
| 204 |
+
self.max_grad_norm_generator = getattr(config, "max_grad_norm_generator", 10.0)
|
| 205 |
+
self.max_grad_norm_critic = getattr(config, "max_grad_norm_critic", 10.0)
|
| 206 |
+
self.previous_time = None
|
| 207 |
+
|
| 208 |
+
def save(self):
|
| 209 |
+
print("Start gathering distributed model states...")
|
| 210 |
+
generator_state_dict = fsdp_state_dict(
|
| 211 |
+
self.model.generator)
|
| 212 |
+
critic_state_dict = fsdp_state_dict(
|
| 213 |
+
self.model.fake_score)
|
| 214 |
+
|
| 215 |
+
if self.config.ema_start_step < self.step:
|
| 216 |
+
state_dict = {
|
| 217 |
+
"generator": generator_state_dict,
|
| 218 |
+
"critic": critic_state_dict,
|
| 219 |
+
"generator_ema": self.generator_ema.state_dict(),
|
| 220 |
+
}
|
| 221 |
+
else:
|
| 222 |
+
state_dict = {
|
| 223 |
+
"generator": generator_state_dict,
|
| 224 |
+
"critic": critic_state_dict,
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
if self.is_main_process:
|
| 228 |
+
os.makedirs(os.path.join(self.output_path,
|
| 229 |
+
f"checkpoint_model_{self.step:06d}"), exist_ok=True)
|
| 230 |
+
torch.save(state_dict, os.path.join(self.output_path,
|
| 231 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
| 232 |
+
print("Model saved to", os.path.join(self.output_path,
|
| 233 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
| 234 |
+
|
| 235 |
+
def fwdbwd_one_step(self, batch, train_generator):
|
| 236 |
+
self.model.eval() # prevent any randomness (e.g. dropout)
|
| 237 |
+
|
| 238 |
+
if self.step % 20 == 0:
|
| 239 |
+
torch.cuda.empty_cache()
|
| 240 |
+
|
| 241 |
+
# Step 1: Get the next batch of text prompts
|
| 242 |
+
text_prompts = batch["prompts"] # next(self.dataloader)
|
| 243 |
+
if "ode_latent" in batch:
|
| 244 |
+
clean_latent = batch["ode_latent"][:, -1].to(device=self.device, dtype=self.dtype)
|
| 245 |
+
else:
|
| 246 |
+
frames = batch["frames"].to(device=self.device, dtype=self.dtype)
|
| 247 |
+
with torch.no_grad():
|
| 248 |
+
clean_latent = self.model.vae.encode_to_latent(
|
| 249 |
+
frames).to(device=self.device, dtype=self.dtype)
|
| 250 |
+
|
| 251 |
+
image_latent = clean_latent[:, 0:1, ]
|
| 252 |
+
|
| 253 |
+
batch_size = len(text_prompts)
|
| 254 |
+
image_or_video_shape = list(self.config.image_or_video_shape)
|
| 255 |
+
image_or_video_shape[0] = batch_size
|
| 256 |
+
|
| 257 |
+
# Step 2: Extract the conditional infos
|
| 258 |
+
with torch.no_grad():
|
| 259 |
+
conditional_dict = self.model.text_encoder(
|
| 260 |
+
text_prompts=text_prompts)
|
| 261 |
+
|
| 262 |
+
if not getattr(self, "unconditional_dict", None):
|
| 263 |
+
unconditional_dict = self.model.text_encoder(
|
| 264 |
+
text_prompts=[self.config.negative_prompt] * batch_size)
|
| 265 |
+
unconditional_dict = {k: v.detach()
|
| 266 |
+
for k, v in unconditional_dict.items()}
|
| 267 |
+
self.unconditional_dict = unconditional_dict # cache the unconditional_dict
|
| 268 |
+
else:
|
| 269 |
+
unconditional_dict = self.unconditional_dict
|
| 270 |
+
|
| 271 |
+
mini_bs, full_bs = (
|
| 272 |
+
batch["mini_bs"],
|
| 273 |
+
batch["full_bs"],
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
# Step 3: Store gradients for the generator (if training the generator)
|
| 277 |
+
if train_generator:
|
| 278 |
+
gan_G_loss = self.model.generator_loss(
|
| 279 |
+
image_or_video_shape=image_or_video_shape,
|
| 280 |
+
conditional_dict=conditional_dict,
|
| 281 |
+
unconditional_dict=unconditional_dict,
|
| 282 |
+
clean_latent=clean_latent,
|
| 283 |
+
initial_latent=image_latent if self.config.i2v else None
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
loss_ratio = mini_bs * self.world_size / full_bs
|
| 287 |
+
total_loss = gan_G_loss * loss_ratio * self.loss_scale
|
| 288 |
+
|
| 289 |
+
total_loss.backward()
|
| 290 |
+
generator_grad_norm = self.model.generator.clip_grad_norm_(
|
| 291 |
+
self.max_grad_norm_generator)
|
| 292 |
+
|
| 293 |
+
generator_log_dict = {"generator_grad_norm": generator_grad_norm,
|
| 294 |
+
"gan_G_loss": gan_G_loss}
|
| 295 |
+
|
| 296 |
+
return generator_log_dict
|
| 297 |
+
else:
|
| 298 |
+
generator_log_dict = {}
|
| 299 |
+
|
| 300 |
+
# Step 4: Store gradients for the critic (if training the critic)
|
| 301 |
+
(gan_D_loss, r1_loss, r2_loss), critic_log_dict = self.model.critic_loss(
|
| 302 |
+
image_or_video_shape=image_or_video_shape,
|
| 303 |
+
conditional_dict=conditional_dict,
|
| 304 |
+
unconditional_dict=unconditional_dict,
|
| 305 |
+
clean_latent=clean_latent,
|
| 306 |
+
real_image_or_video=clean_latent,
|
| 307 |
+
initial_latent=image_latent if self.config.i2v else None
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
loss_ratio = mini_bs * dist.get_world_size() / full_bs
|
| 311 |
+
total_loss = (gan_D_loss + 0.5 * (r1_loss + r2_loss)) * loss_ratio * self.loss_scale
|
| 312 |
+
|
| 313 |
+
total_loss.backward()
|
| 314 |
+
critic_grad_norm = self.model.fake_score.clip_grad_norm_(
|
| 315 |
+
self.max_grad_norm_critic)
|
| 316 |
+
|
| 317 |
+
critic_log_dict.update({"critic_grad_norm": critic_grad_norm,
|
| 318 |
+
"gan_D_loss": gan_D_loss,
|
| 319 |
+
"r1_loss": r1_loss,
|
| 320 |
+
"r2_loss": r2_loss})
|
| 321 |
+
|
| 322 |
+
return critic_log_dict
|
| 323 |
+
|
| 324 |
+
def generate_video(self, pipeline, prompts, image=None):
|
| 325 |
+
batch_size = len(prompts)
|
| 326 |
+
sampled_noise = torch.randn(
|
| 327 |
+
[batch_size, 21, 16, 60, 104], device="cuda", dtype=self.dtype
|
| 328 |
+
)
|
| 329 |
+
video, _ = pipeline.inference(
|
| 330 |
+
noise=sampled_noise,
|
| 331 |
+
text_prompts=prompts,
|
| 332 |
+
return_latents=True
|
| 333 |
+
)
|
| 334 |
+
current_video = video.permute(0, 1, 3, 4, 2).cpu().numpy() * 255.0
|
| 335 |
+
return current_video
|
| 336 |
+
|
| 337 |
+
def train(self):
|
| 338 |
+
start_step = self.step
|
| 339 |
+
|
| 340 |
+
while True:
|
| 341 |
+
if self.step == self.discriminator_warmup_steps and self.discriminator_warmup_steps != 0:
|
| 342 |
+
print("Resetting critic optimizer")
|
| 343 |
+
del self.critic_optimizer
|
| 344 |
+
torch.cuda.empty_cache()
|
| 345 |
+
# Create new optimizers
|
| 346 |
+
self.critic_optimizer = torch.optim.AdamW(
|
| 347 |
+
self.critic_param_groups,
|
| 348 |
+
betas=(self.config.beta1_critic, self.config.beta2_critic)
|
| 349 |
+
)
|
| 350 |
+
# Update checkpointer references
|
| 351 |
+
self.checkpointer_critic.optimizer = self.critic_optimizer
|
| 352 |
+
# Check if we're in the discriminator warmup phase
|
| 353 |
+
self.in_discriminator_warmup = self.step < self.discriminator_warmup_steps
|
| 354 |
+
|
| 355 |
+
# Only update generator and critic outside the warmup phase
|
| 356 |
+
TRAIN_GENERATOR = not self.in_discriminator_warmup and self.step % self.config.dfake_gen_update_ratio == 0
|
| 357 |
+
|
| 358 |
+
# Train the generator (only outside warmup phase)
|
| 359 |
+
if TRAIN_GENERATOR:
|
| 360 |
+
self.model.fake_score.requires_grad_(False)
|
| 361 |
+
self.model.generator.requires_grad_(True)
|
| 362 |
+
self.generator_optimizer.zero_grad(set_to_none=True)
|
| 363 |
+
extras_list = []
|
| 364 |
+
for ii, mini_batch in enumerate(self.dataloader.next()):
|
| 365 |
+
extra = self.fwdbwd_one_step(mini_batch, True)
|
| 366 |
+
extras_list.append(extra)
|
| 367 |
+
generator_log_dict = merge_dict_list(extras_list)
|
| 368 |
+
self.generator_optimizer.step()
|
| 369 |
+
if self.generator_ema is not None:
|
| 370 |
+
self.generator_ema.update(self.model.generator)
|
| 371 |
+
else:
|
| 372 |
+
generator_log_dict = {}
|
| 373 |
+
|
| 374 |
+
# Train the critic/discriminator
|
| 375 |
+
if self.in_discriminator_warmup:
|
| 376 |
+
# During warmup, only allow gradient for discriminator params
|
| 377 |
+
self.model.generator.requires_grad_(False)
|
| 378 |
+
self.model.fake_score.requires_grad_(False)
|
| 379 |
+
|
| 380 |
+
# Enable gradient only for discriminator params
|
| 381 |
+
for name, param in self.model.fake_score.named_parameters():
|
| 382 |
+
if "_cls_pred_branch" in name or "_gan_ca_blocks" in name:
|
| 383 |
+
param.requires_grad_(True)
|
| 384 |
+
else:
|
| 385 |
+
# Normal training mode
|
| 386 |
+
self.model.generator.requires_grad_(False)
|
| 387 |
+
self.model.fake_score.requires_grad_(True)
|
| 388 |
+
|
| 389 |
+
self.critic_optimizer.zero_grad(set_to_none=True)
|
| 390 |
+
extras_list = []
|
| 391 |
+
batch = next(self.dataloader)
|
| 392 |
+
extra = self.fwdbwd_one_step(batch, False)
|
| 393 |
+
extras_list.append(extra)
|
| 394 |
+
critic_log_dict = merge_dict_list(extras_list)
|
| 395 |
+
self.critic_optimizer.step()
|
| 396 |
+
|
| 397 |
+
# Increment the step since we finished gradient update
|
| 398 |
+
self.step += 1
|
| 399 |
+
|
| 400 |
+
# If we just finished warmup, print a message
|
| 401 |
+
if self.is_main_process and self.step == self.discriminator_warmup_steps:
|
| 402 |
+
print(f"Finished discriminator warmup after {self.discriminator_warmup_steps} steps")
|
| 403 |
+
|
| 404 |
+
# Create EMA params (if not already created)
|
| 405 |
+
if (self.step >= self.config.ema_start_step) and \
|
| 406 |
+
(self.generator_ema is None) and (self.config.ema_weight > 0):
|
| 407 |
+
self.generator_ema = EMA_FSDP(self.model.generator, decay=self.config.ema_weight)
|
| 408 |
+
|
| 409 |
+
# Save the model
|
| 410 |
+
if (not self.config.no_save) and (self.step - start_step) > 0 and self.step % self.config.log_iters == 0:
|
| 411 |
+
torch.cuda.empty_cache()
|
| 412 |
+
self.save()
|
| 413 |
+
torch.cuda.empty_cache()
|
| 414 |
+
|
| 415 |
+
# Logging
|
| 416 |
+
wandb_loss_dict = {
|
| 417 |
+
"generator_grad_norm": generator_log_dict["generator_grad_norm"],
|
| 418 |
+
"critic_grad_norm": critic_log_dict["critic_grad_norm"],
|
| 419 |
+
"real_logit": critic_log_dict["noisy_real_logit"],
|
| 420 |
+
"fake_logit": critic_log_dict["noisy_fake_logit"],
|
| 421 |
+
"r1_loss": critic_log_dict["r1_loss"],
|
| 422 |
+
"r2_loss": critic_log_dict["r2_loss"],
|
| 423 |
+
}
|
| 424 |
+
if TRAIN_GENERATOR:
|
| 425 |
+
wandb_loss_dict.update({
|
| 426 |
+
"generator_grad_norm": generator_log_dict["generator_grad_norm"],
|
| 427 |
+
})
|
| 428 |
+
self.all_gather_dict(wandb_loss_dict)
|
| 429 |
+
wandb_loss_dict["diff_logit"] = wandb_loss_dict["real_logit"] - wandb_loss_dict["fake_logit"]
|
| 430 |
+
wandb_loss_dict["reg_loss"] = 0.5 * (wandb_loss_dict["r1_loss"] + wandb_loss_dict["r2_loss"])
|
| 431 |
+
|
| 432 |
+
if self.is_main_process:
|
| 433 |
+
if self.in_discriminator_warmup:
|
| 434 |
+
warmup_status = f"[WARMUP {self.step}/{self.discriminator_warmup_steps}] Training only discriminator params"
|
| 435 |
+
print(warmup_status)
|
| 436 |
+
if not self.disable_wandb:
|
| 437 |
+
wandb_loss_dict.update({"warmup_status": 1.0})
|
| 438 |
+
|
| 439 |
+
if not self.disable_wandb:
|
| 440 |
+
wandb.log(wandb_loss_dict, step=self.step)
|
| 441 |
+
|
| 442 |
+
if self.step % self.config.gc_interval == 0:
|
| 443 |
+
if dist.get_rank() == 0:
|
| 444 |
+
logging.info("DistGarbageCollector: Running GC.")
|
| 445 |
+
gc.collect()
|
| 446 |
+
torch.cuda.empty_cache()
|
| 447 |
+
|
| 448 |
+
if self.is_main_process:
|
| 449 |
+
current_time = time.time()
|
| 450 |
+
if self.previous_time is None:
|
| 451 |
+
self.previous_time = current_time
|
| 452 |
+
else:
|
| 453 |
+
if not self.disable_wandb:
|
| 454 |
+
wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
|
| 455 |
+
self.previous_time = current_time
|
| 456 |
+
|
| 457 |
+
def all_gather_dict(self, target_dict):
|
| 458 |
+
for key, value in target_dict.items():
|
| 459 |
+
gathered_value = torch.zeros(
|
| 460 |
+
[self.world_size, *value.shape],
|
| 461 |
+
dtype=value.dtype, device=self.device)
|
| 462 |
+
dist.all_gather_into_tensor(gathered_value, value)
|
| 463 |
+
avg_value = gathered_value.mean().item()
|
| 464 |
+
target_dict[key] = avg_value
|
trainer/ode.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import logging
|
| 3 |
+
from utils.dataset import ODERegressionLMDBDataset, cycle
|
| 4 |
+
from model import ODERegression
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from utils.misc import (
|
| 7 |
+
set_seed
|
| 8 |
+
)
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
from omegaconf import OmegaConf
|
| 11 |
+
import torch
|
| 12 |
+
import wandb
|
| 13 |
+
import time
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
from utils.distributed import barrier, fsdp_wrap, fsdp_state_dict, launch_distributed_job
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Trainer:
|
| 20 |
+
def __init__(self, config):
|
| 21 |
+
self.config = config
|
| 22 |
+
self.step = 0
|
| 23 |
+
|
| 24 |
+
# Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
|
| 25 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 26 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 27 |
+
|
| 28 |
+
launch_distributed_job()
|
| 29 |
+
global_rank = dist.get_rank()
|
| 30 |
+
self.world_size = dist.get_world_size()
|
| 31 |
+
|
| 32 |
+
self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
|
| 33 |
+
self.device = torch.cuda.current_device()
|
| 34 |
+
self.is_main_process = global_rank == 0
|
| 35 |
+
self.disable_wandb = config.disable_wandb
|
| 36 |
+
|
| 37 |
+
# use a random seed for the training
|
| 38 |
+
if config.seed == 0:
|
| 39 |
+
random_seed = torch.randint(0, 10000000, (1,), device=self.device)
|
| 40 |
+
dist.broadcast(random_seed, src=0)
|
| 41 |
+
config.seed = random_seed.item()
|
| 42 |
+
|
| 43 |
+
set_seed(config.seed + global_rank)
|
| 44 |
+
|
| 45 |
+
if self.is_main_process and not self.disable_wandb:
|
| 46 |
+
wandb.login(host=config.wandb_host, key=config.wandb_key)
|
| 47 |
+
wandb.init(
|
| 48 |
+
config=OmegaConf.to_container(config, resolve=True),
|
| 49 |
+
name=config.config_name,
|
| 50 |
+
mode="online",
|
| 51 |
+
entity=config.wandb_entity,
|
| 52 |
+
project=config.wandb_project,
|
| 53 |
+
dir=config.wandb_save_dir
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
self.output_path = config.logdir
|
| 57 |
+
|
| 58 |
+
# Step 2: Initialize the model and optimizer
|
| 59 |
+
|
| 60 |
+
assert config.distribution_loss == "ode", "Only ODE loss is supported for ODE training"
|
| 61 |
+
self.model = ODERegression(config, device=self.device)
|
| 62 |
+
|
| 63 |
+
self.model.generator = fsdp_wrap(
|
| 64 |
+
self.model.generator,
|
| 65 |
+
sharding_strategy=config.sharding_strategy,
|
| 66 |
+
mixed_precision=config.mixed_precision,
|
| 67 |
+
wrap_strategy=config.generator_fsdp_wrap_strategy
|
| 68 |
+
)
|
| 69 |
+
self.model.text_encoder = fsdp_wrap(
|
| 70 |
+
self.model.text_encoder,
|
| 71 |
+
sharding_strategy=config.sharding_strategy,
|
| 72 |
+
mixed_precision=config.mixed_precision,
|
| 73 |
+
wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
|
| 74 |
+
cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
if not config.no_visualize or config.load_raw_video:
|
| 78 |
+
self.model.vae = self.model.vae.to(
|
| 79 |
+
device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
|
| 80 |
+
|
| 81 |
+
self.generator_optimizer = torch.optim.AdamW(
|
| 82 |
+
[param for param in self.model.generator.parameters()
|
| 83 |
+
if param.requires_grad],
|
| 84 |
+
lr=config.lr,
|
| 85 |
+
betas=(config.beta1, config.beta2),
|
| 86 |
+
weight_decay=config.weight_decay
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Step 3: Initialize the dataloader
|
| 90 |
+
dataset = ODERegressionLMDBDataset(
|
| 91 |
+
config.data_path, max_pair=getattr(config, "max_pair", int(1e8)))
|
| 92 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
| 93 |
+
dataset, shuffle=True, drop_last=True)
|
| 94 |
+
dataloader = torch.utils.data.DataLoader(
|
| 95 |
+
dataset, batch_size=config.batch_size, sampler=sampler, num_workers=8)
|
| 96 |
+
total_batch_size = getattr(config, "total_batch_size", None)
|
| 97 |
+
if total_batch_size is not None:
|
| 98 |
+
assert total_batch_size == config.batch_size * self.world_size, "Gradient accumulation is not supported for ODE training"
|
| 99 |
+
self.dataloader = cycle(dataloader)
|
| 100 |
+
|
| 101 |
+
self.step = 0
|
| 102 |
+
|
| 103 |
+
##############################################################################################################
|
| 104 |
+
# 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
|
| 105 |
+
if getattr(config, "generator_ckpt", False):
|
| 106 |
+
print(f"Loading pretrained generator from {config.generator_ckpt}")
|
| 107 |
+
state_dict = torch.load(config.generator_ckpt, map_location="cpu")[
|
| 108 |
+
'generator']
|
| 109 |
+
self.model.generator.load_state_dict(
|
| 110 |
+
state_dict, strict=True
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
##############################################################################################################
|
| 114 |
+
|
| 115 |
+
self.max_grad_norm = 10.0
|
| 116 |
+
self.previous_time = None
|
| 117 |
+
|
| 118 |
+
def save(self):
|
| 119 |
+
print("Start gathering distributed model states...")
|
| 120 |
+
generator_state_dict = fsdp_state_dict(
|
| 121 |
+
self.model.generator)
|
| 122 |
+
state_dict = {
|
| 123 |
+
"generator": generator_state_dict
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
if self.is_main_process:
|
| 127 |
+
os.makedirs(os.path.join(self.output_path,
|
| 128 |
+
f"checkpoint_model_{self.step:06d}"), exist_ok=True)
|
| 129 |
+
torch.save(state_dict, os.path.join(self.output_path,
|
| 130 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
| 131 |
+
print("Model saved to", os.path.join(self.output_path,
|
| 132 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
| 133 |
+
|
| 134 |
+
def train_one_step(self):
|
| 135 |
+
VISUALIZE = self.step % 100 == 0
|
| 136 |
+
self.model.eval() # prevent any randomness (e.g. dropout)
|
| 137 |
+
|
| 138 |
+
# Step 1: Get the next batch of text prompts
|
| 139 |
+
batch = next(self.dataloader)
|
| 140 |
+
text_prompts = batch["prompts"]
|
| 141 |
+
ode_latent = batch["ode_latent"].to(
|
| 142 |
+
device=self.device, dtype=self.dtype)
|
| 143 |
+
|
| 144 |
+
# Step 2: Extract the conditional infos
|
| 145 |
+
with torch.no_grad():
|
| 146 |
+
conditional_dict = self.model.text_encoder(
|
| 147 |
+
text_prompts=text_prompts)
|
| 148 |
+
|
| 149 |
+
# Step 3: Train the generator
|
| 150 |
+
generator_loss, log_dict = self.model.generator_loss(
|
| 151 |
+
ode_latent=ode_latent,
|
| 152 |
+
conditional_dict=conditional_dict
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
unnormalized_loss = log_dict["unnormalized_loss"]
|
| 156 |
+
timestep = log_dict["timestep"]
|
| 157 |
+
|
| 158 |
+
if self.world_size > 1:
|
| 159 |
+
gathered_unnormalized_loss = torch.zeros(
|
| 160 |
+
[self.world_size, *unnormalized_loss.shape],
|
| 161 |
+
dtype=unnormalized_loss.dtype, device=self.device)
|
| 162 |
+
gathered_timestep = torch.zeros(
|
| 163 |
+
[self.world_size, *timestep.shape],
|
| 164 |
+
dtype=timestep.dtype, device=self.device)
|
| 165 |
+
|
| 166 |
+
dist.all_gather_into_tensor(
|
| 167 |
+
gathered_unnormalized_loss, unnormalized_loss)
|
| 168 |
+
dist.all_gather_into_tensor(gathered_timestep, timestep)
|
| 169 |
+
else:
|
| 170 |
+
gathered_unnormalized_loss = unnormalized_loss
|
| 171 |
+
gathered_timestep = timestep
|
| 172 |
+
|
| 173 |
+
loss_breakdown = defaultdict(list)
|
| 174 |
+
stats = {}
|
| 175 |
+
|
| 176 |
+
for index, t in enumerate(timestep):
|
| 177 |
+
loss_breakdown[str(int(t.item()) // 250 * 250)].append(
|
| 178 |
+
unnormalized_loss[index].item())
|
| 179 |
+
|
| 180 |
+
for key_t in loss_breakdown.keys():
|
| 181 |
+
stats["loss_at_time_" + key_t] = sum(loss_breakdown[key_t]) / \
|
| 182 |
+
len(loss_breakdown[key_t])
|
| 183 |
+
|
| 184 |
+
self.generator_optimizer.zero_grad()
|
| 185 |
+
generator_loss.backward()
|
| 186 |
+
generator_grad_norm = self.model.generator.clip_grad_norm_(
|
| 187 |
+
self.max_grad_norm)
|
| 188 |
+
self.generator_optimizer.step()
|
| 189 |
+
|
| 190 |
+
# Step 4: Visualization
|
| 191 |
+
if VISUALIZE and not self.config.no_visualize and not self.config.disable_wandb and self.is_main_process:
|
| 192 |
+
# Visualize the input, output, and ground truth
|
| 193 |
+
input = log_dict["input"]
|
| 194 |
+
output = log_dict["output"]
|
| 195 |
+
ground_truth = ode_latent[:, -1]
|
| 196 |
+
|
| 197 |
+
input_video = self.model.vae.decode_to_pixel(input)
|
| 198 |
+
output_video = self.model.vae.decode_to_pixel(output)
|
| 199 |
+
ground_truth_video = self.model.vae.decode_to_pixel(ground_truth)
|
| 200 |
+
input_video = 255.0 * (input_video.cpu().numpy() * 0.5 + 0.5)
|
| 201 |
+
output_video = 255.0 * (output_video.cpu().numpy() * 0.5 + 0.5)
|
| 202 |
+
ground_truth_video = 255.0 * (ground_truth_video.cpu().numpy() * 0.5 + 0.5)
|
| 203 |
+
|
| 204 |
+
# Visualize the input, output, and ground truth
|
| 205 |
+
wandb.log({
|
| 206 |
+
"input": wandb.Video(input_video, caption="Input", fps=16, format="mp4"),
|
| 207 |
+
"output": wandb.Video(output_video, caption="Output", fps=16, format="mp4"),
|
| 208 |
+
"ground_truth": wandb.Video(ground_truth_video, caption="Ground Truth", fps=16, format="mp4"),
|
| 209 |
+
}, step=self.step)
|
| 210 |
+
|
| 211 |
+
# Step 5: Logging
|
| 212 |
+
if self.is_main_process and not self.disable_wandb:
|
| 213 |
+
wandb_loss_dict = {
|
| 214 |
+
"generator_loss": generator_loss.item(),
|
| 215 |
+
"generator_grad_norm": generator_grad_norm.item(),
|
| 216 |
+
**stats
|
| 217 |
+
}
|
| 218 |
+
wandb.log(wandb_loss_dict, step=self.step)
|
| 219 |
+
|
| 220 |
+
if self.step % self.config.gc_interval == 0:
|
| 221 |
+
if dist.get_rank() == 0:
|
| 222 |
+
logging.info("DistGarbageCollector: Running GC.")
|
| 223 |
+
gc.collect()
|
| 224 |
+
|
| 225 |
+
def train(self):
|
| 226 |
+
while True:
|
| 227 |
+
self.train_one_step()
|
| 228 |
+
if (not self.config.no_save) and self.step % self.config.log_iters == 0:
|
| 229 |
+
self.save()
|
| 230 |
+
torch.cuda.empty_cache()
|
| 231 |
+
|
| 232 |
+
barrier()
|
| 233 |
+
if self.is_main_process:
|
| 234 |
+
current_time = time.time()
|
| 235 |
+
if self.previous_time is None:
|
| 236 |
+
self.previous_time = current_time
|
| 237 |
+
else:
|
| 238 |
+
if not self.disable_wandb:
|
| 239 |
+
wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
|
| 240 |
+
self.previous_time = current_time
|
| 241 |
+
|
| 242 |
+
self.step += 1
|
utils/dataset.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils.lmdb import get_array_shape_from_lmdb, retrieve_row_from_lmdb
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import lmdb
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TextDataset(Dataset):
|
| 13 |
+
def __init__(self, prompt_path, extended_prompt_path=None):
|
| 14 |
+
with open(prompt_path, encoding="utf-8") as f:
|
| 15 |
+
self.prompt_list = [line.rstrip() for line in f]
|
| 16 |
+
|
| 17 |
+
if extended_prompt_path is not None:
|
| 18 |
+
with open(extended_prompt_path, encoding="utf-8") as f:
|
| 19 |
+
self.extended_prompt_list = [line.rstrip() for line in f]
|
| 20 |
+
assert len(self.extended_prompt_list) == len(self.prompt_list)
|
| 21 |
+
else:
|
| 22 |
+
self.extended_prompt_list = None
|
| 23 |
+
|
| 24 |
+
def __len__(self):
|
| 25 |
+
return len(self.prompt_list)
|
| 26 |
+
|
| 27 |
+
def __getitem__(self, idx):
|
| 28 |
+
batch = {
|
| 29 |
+
"prompts": self.prompt_list[idx],
|
| 30 |
+
"idx": idx,
|
| 31 |
+
}
|
| 32 |
+
if self.extended_prompt_list is not None:
|
| 33 |
+
batch["extended_prompts"] = self.extended_prompt_list[idx]
|
| 34 |
+
return batch
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ODERegressionLMDBDataset(Dataset):
|
| 38 |
+
def __init__(self, data_path: str, max_pair: int = int(1e8)):
|
| 39 |
+
self.env = lmdb.open(data_path, readonly=True,
|
| 40 |
+
lock=False, readahead=False, meminit=False)
|
| 41 |
+
|
| 42 |
+
self.latents_shape = get_array_shape_from_lmdb(self.env, 'latents')
|
| 43 |
+
self.max_pair = max_pair
|
| 44 |
+
|
| 45 |
+
def __len__(self):
|
| 46 |
+
return min(self.latents_shape[0], self.max_pair)
|
| 47 |
+
|
| 48 |
+
def __getitem__(self, idx):
|
| 49 |
+
"""
|
| 50 |
+
Outputs:
|
| 51 |
+
- prompts: List of Strings
|
| 52 |
+
- latents: Tensor of shape (num_denoising_steps, num_frames, num_channels, height, width). It is ordered from pure noise to clean image.
|
| 53 |
+
"""
|
| 54 |
+
latents = retrieve_row_from_lmdb(
|
| 55 |
+
self.env,
|
| 56 |
+
"latents", np.float16, idx, shape=self.latents_shape[1:]
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
if len(latents.shape) == 4:
|
| 60 |
+
latents = latents[None, ...]
|
| 61 |
+
|
| 62 |
+
prompts = retrieve_row_from_lmdb(
|
| 63 |
+
self.env,
|
| 64 |
+
"prompts", str, idx
|
| 65 |
+
)
|
| 66 |
+
return {
|
| 67 |
+
"prompts": prompts,
|
| 68 |
+
"ode_latent": torch.tensor(latents, dtype=torch.float32)
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class ShardingLMDBDataset(Dataset):
|
| 73 |
+
def __init__(self, data_path: str, max_pair: int = int(1e8)):
|
| 74 |
+
self.envs = []
|
| 75 |
+
self.index = []
|
| 76 |
+
|
| 77 |
+
for fname in sorted(os.listdir(data_path)):
|
| 78 |
+
path = os.path.join(data_path, fname)
|
| 79 |
+
env = lmdb.open(path,
|
| 80 |
+
readonly=True,
|
| 81 |
+
lock=False,
|
| 82 |
+
readahead=False,
|
| 83 |
+
meminit=False)
|
| 84 |
+
self.envs.append(env)
|
| 85 |
+
|
| 86 |
+
self.latents_shape = [None] * len(self.envs)
|
| 87 |
+
for shard_id, env in enumerate(self.envs):
|
| 88 |
+
self.latents_shape[shard_id] = get_array_shape_from_lmdb(env, 'latents')
|
| 89 |
+
for local_i in range(self.latents_shape[shard_id][0]):
|
| 90 |
+
self.index.append((shard_id, local_i))
|
| 91 |
+
|
| 92 |
+
# print("shard_id ", shard_id, " local_i ", local_i)
|
| 93 |
+
|
| 94 |
+
self.max_pair = max_pair
|
| 95 |
+
|
| 96 |
+
def __len__(self):
|
| 97 |
+
return len(self.index)
|
| 98 |
+
|
| 99 |
+
def __getitem__(self, idx):
|
| 100 |
+
"""
|
| 101 |
+
Outputs:
|
| 102 |
+
- prompts: List of Strings
|
| 103 |
+
- latents: Tensor of shape (num_denoising_steps, num_frames, num_channels, height, width). It is ordered from pure noise to clean image.
|
| 104 |
+
"""
|
| 105 |
+
shard_id, local_idx = self.index[idx]
|
| 106 |
+
|
| 107 |
+
latents = retrieve_row_from_lmdb(
|
| 108 |
+
self.envs[shard_id],
|
| 109 |
+
"latents", np.float16, local_idx,
|
| 110 |
+
shape=self.latents_shape[shard_id][1:]
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
if len(latents.shape) == 4:
|
| 114 |
+
latents = latents[None, ...]
|
| 115 |
+
|
| 116 |
+
prompts = retrieve_row_from_lmdb(
|
| 117 |
+
self.envs[shard_id],
|
| 118 |
+
"prompts", str, local_idx
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
return {
|
| 122 |
+
"prompts": prompts,
|
| 123 |
+
"ode_latent": torch.tensor(latents, dtype=torch.float32)
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class TextImagePairDataset(Dataset):
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
data_dir,
|
| 131 |
+
transform=None,
|
| 132 |
+
eval_first_n=-1,
|
| 133 |
+
pad_to_multiple_of=None
|
| 134 |
+
):
|
| 135 |
+
"""
|
| 136 |
+
Args:
|
| 137 |
+
data_dir (str): Path to the directory containing:
|
| 138 |
+
- target_crop_info_*.json (metadata file)
|
| 139 |
+
- */ (subdirectory containing images with matching aspect ratio)
|
| 140 |
+
transform (callable, optional): Optional transform to be applied on the image
|
| 141 |
+
"""
|
| 142 |
+
self.transform = transform
|
| 143 |
+
data_dir = Path(data_dir)
|
| 144 |
+
|
| 145 |
+
# Find the metadata JSON file
|
| 146 |
+
metadata_files = list(data_dir.glob('target_crop_info_*.json'))
|
| 147 |
+
if not metadata_files:
|
| 148 |
+
raise FileNotFoundError(f"No metadata file found in {data_dir}")
|
| 149 |
+
if len(metadata_files) > 1:
|
| 150 |
+
raise ValueError(f"Multiple metadata files found in {data_dir}")
|
| 151 |
+
|
| 152 |
+
metadata_path = metadata_files[0]
|
| 153 |
+
# Extract aspect ratio from metadata filename (e.g. target_crop_info_26-15.json -> 26-15)
|
| 154 |
+
aspect_ratio = metadata_path.stem.split('_')[-1]
|
| 155 |
+
|
| 156 |
+
# Use aspect ratio subfolder for images
|
| 157 |
+
self.image_dir = data_dir / aspect_ratio
|
| 158 |
+
if not self.image_dir.exists():
|
| 159 |
+
raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
|
| 160 |
+
|
| 161 |
+
# Load metadata
|
| 162 |
+
with open(metadata_path, 'r') as f:
|
| 163 |
+
self.metadata = json.load(f)
|
| 164 |
+
|
| 165 |
+
eval_first_n = eval_first_n if eval_first_n != -1 else len(self.metadata)
|
| 166 |
+
self.metadata = self.metadata[:eval_first_n]
|
| 167 |
+
|
| 168 |
+
# Verify all images exist
|
| 169 |
+
for item in self.metadata:
|
| 170 |
+
image_path = self.image_dir / item['file_name']
|
| 171 |
+
if not image_path.exists():
|
| 172 |
+
raise FileNotFoundError(f"Image not found: {image_path}")
|
| 173 |
+
|
| 174 |
+
self.dummy_prompt = "DUMMY PROMPT"
|
| 175 |
+
self.pre_pad_len = len(self.metadata)
|
| 176 |
+
if pad_to_multiple_of is not None and len(self.metadata) % pad_to_multiple_of != 0:
|
| 177 |
+
# Duplicate the last entry
|
| 178 |
+
self.metadata += [self.metadata[-1]] * (
|
| 179 |
+
pad_to_multiple_of - len(self.metadata) % pad_to_multiple_of
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
def __len__(self):
|
| 183 |
+
return len(self.metadata)
|
| 184 |
+
|
| 185 |
+
def __getitem__(self, idx):
|
| 186 |
+
"""
|
| 187 |
+
Returns:
|
| 188 |
+
dict: A dictionary containing:
|
| 189 |
+
- image: PIL Image
|
| 190 |
+
- caption: str
|
| 191 |
+
- target_bbox: list of int [x1, y1, x2, y2]
|
| 192 |
+
- target_ratio: str
|
| 193 |
+
- type: str
|
| 194 |
+
- origin_size: tuple of int (width, height)
|
| 195 |
+
"""
|
| 196 |
+
item = self.metadata[idx]
|
| 197 |
+
|
| 198 |
+
# Load image
|
| 199 |
+
image_path = self.image_dir / item['file_name']
|
| 200 |
+
image = Image.open(image_path).convert('RGB')
|
| 201 |
+
|
| 202 |
+
# Apply transform if specified
|
| 203 |
+
if self.transform:
|
| 204 |
+
image = self.transform(image)
|
| 205 |
+
|
| 206 |
+
return {
|
| 207 |
+
'image': image,
|
| 208 |
+
'prompts': item['caption'],
|
| 209 |
+
'target_bbox': item['target_crop']['target_bbox'],
|
| 210 |
+
'target_ratio': item['target_crop']['target_ratio'],
|
| 211 |
+
'type': item['type'],
|
| 212 |
+
'origin_size': (item['origin_width'], item['origin_height']),
|
| 213 |
+
'idx': idx
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def cycle(dl):
|
| 218 |
+
while True:
|
| 219 |
+
for data in dl:
|
| 220 |
+
yield data
|
utils/distributed.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import timedelta
|
| 2 |
+
from functools import partial
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, StateDictType
|
| 7 |
+
from torch.distributed.fsdp.api import CPUOffload
|
| 8 |
+
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def fsdp_state_dict(model):
|
| 12 |
+
fsdp_fullstate_save_policy = FullStateDictConfig(
|
| 13 |
+
offload_to_cpu=True, rank0_only=True
|
| 14 |
+
)
|
| 15 |
+
with FSDP.state_dict_type(
|
| 16 |
+
model, StateDictType.FULL_STATE_DICT, fsdp_fullstate_save_policy
|
| 17 |
+
):
|
| 18 |
+
checkpoint = model.state_dict()
|
| 19 |
+
|
| 20 |
+
return checkpoint
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def fsdp_wrap(module, sharding_strategy="full", mixed_precision=False, wrap_strategy="size", min_num_params=int(5e7), transformer_module=None, cpu_offload=False):
|
| 24 |
+
if mixed_precision:
|
| 25 |
+
mixed_precision_policy = MixedPrecision(
|
| 26 |
+
param_dtype=torch.bfloat16,
|
| 27 |
+
reduce_dtype=torch.float32,
|
| 28 |
+
buffer_dtype=torch.float32,
|
| 29 |
+
cast_forward_inputs=False
|
| 30 |
+
)
|
| 31 |
+
else:
|
| 32 |
+
mixed_precision_policy = None
|
| 33 |
+
|
| 34 |
+
if wrap_strategy == "transformer":
|
| 35 |
+
auto_wrap_policy = partial(
|
| 36 |
+
transformer_auto_wrap_policy,
|
| 37 |
+
transformer_layer_cls=transformer_module
|
| 38 |
+
)
|
| 39 |
+
elif wrap_strategy == "size":
|
| 40 |
+
auto_wrap_policy = partial(
|
| 41 |
+
size_based_auto_wrap_policy,
|
| 42 |
+
min_num_params=min_num_params
|
| 43 |
+
)
|
| 44 |
+
else:
|
| 45 |
+
raise ValueError(f"Invalid wrap strategy: {wrap_strategy}")
|
| 46 |
+
|
| 47 |
+
os.environ["NCCL_CROSS_NIC"] = "1"
|
| 48 |
+
|
| 49 |
+
sharding_strategy = {
|
| 50 |
+
"full": ShardingStrategy.FULL_SHARD,
|
| 51 |
+
"hybrid_full": ShardingStrategy.HYBRID_SHARD,
|
| 52 |
+
"hybrid_zero2": ShardingStrategy._HYBRID_SHARD_ZERO2,
|
| 53 |
+
"no_shard": ShardingStrategy.NO_SHARD,
|
| 54 |
+
}[sharding_strategy]
|
| 55 |
+
|
| 56 |
+
module = FSDP(
|
| 57 |
+
module,
|
| 58 |
+
auto_wrap_policy=auto_wrap_policy,
|
| 59 |
+
sharding_strategy=sharding_strategy,
|
| 60 |
+
mixed_precision=mixed_precision_policy,
|
| 61 |
+
device_id=torch.cuda.current_device(),
|
| 62 |
+
limit_all_gathers=True,
|
| 63 |
+
use_orig_params=True,
|
| 64 |
+
cpu_offload=CPUOffload(offload_params=cpu_offload),
|
| 65 |
+
sync_module_states=False # Load ckpt on rank 0 and sync to other ranks
|
| 66 |
+
)
|
| 67 |
+
return module
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def barrier():
|
| 71 |
+
if dist.is_initialized():
|
| 72 |
+
dist.barrier()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def launch_distributed_job(backend: str = "nccl"):
|
| 76 |
+
rank = int(os.environ["RANK"])
|
| 77 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
| 78 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
| 79 |
+
host = os.environ["MASTER_ADDR"]
|
| 80 |
+
port = int(os.environ["MASTER_PORT"])
|
| 81 |
+
|
| 82 |
+
if ":" in host: # IPv6
|
| 83 |
+
init_method = f"tcp://[{host}]:{port}"
|
| 84 |
+
else: # IPv4
|
| 85 |
+
init_method = f"tcp://{host}:{port}"
|
| 86 |
+
dist.init_process_group(rank=rank, world_size=world_size, backend=backend,
|
| 87 |
+
init_method=init_method, timeout=timedelta(minutes=30))
|
| 88 |
+
torch.cuda.set_device(local_rank)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class EMA_FSDP:
|
| 92 |
+
def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
|
| 93 |
+
self.decay = decay
|
| 94 |
+
self.shadow = {}
|
| 95 |
+
self._init_shadow(fsdp_module)
|
| 96 |
+
|
| 97 |
+
@torch.no_grad()
|
| 98 |
+
def _init_shadow(self, fsdp_module):
|
| 99 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 100 |
+
with FSDP.summon_full_params(fsdp_module, writeback=False):
|
| 101 |
+
for n, p in fsdp_module.module.named_parameters():
|
| 102 |
+
self.shadow[n] = p.detach().clone().float().cpu()
|
| 103 |
+
|
| 104 |
+
@torch.no_grad()
|
| 105 |
+
def update(self, fsdp_module):
|
| 106 |
+
d = self.decay
|
| 107 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 108 |
+
with FSDP.summon_full_params(fsdp_module, writeback=False):
|
| 109 |
+
for n, p in fsdp_module.module.named_parameters():
|
| 110 |
+
self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
|
| 111 |
+
|
| 112 |
+
# Optional helpers ---------------------------------------------------
|
| 113 |
+
def state_dict(self):
|
| 114 |
+
return self.shadow # picklable
|
| 115 |
+
|
| 116 |
+
def load_state_dict(self, sd):
|
| 117 |
+
self.shadow = {k: v.clone() for k, v in sd.items()}
|
| 118 |
+
|
| 119 |
+
def copy_to(self, fsdp_module):
|
| 120 |
+
# load EMA weights into an (unwrapped) copy of the generator
|
| 121 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 122 |
+
with FSDP.summon_full_params(fsdp_module, writeback=True):
|
| 123 |
+
for n, p in fsdp_module.module.named_parameters():
|
| 124 |
+
if n in self.shadow:
|
| 125 |
+
p.data.copy_(self.shadow[n].to(p.dtype, device=p.device))
|
utils/lmdb.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_array_shape_from_lmdb(env, array_name):
|
| 5 |
+
with env.begin() as txn:
|
| 6 |
+
image_shape = txn.get(f"{array_name}_shape".encode()).decode()
|
| 7 |
+
image_shape = tuple(map(int, image_shape.split()))
|
| 8 |
+
return image_shape
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def store_arrays_to_lmdb(env, arrays_dict, start_index=0):
|
| 12 |
+
"""
|
| 13 |
+
Store rows of multiple numpy arrays in a single LMDB.
|
| 14 |
+
Each row is stored separately with a naming convention.
|
| 15 |
+
"""
|
| 16 |
+
with env.begin(write=True) as txn:
|
| 17 |
+
for array_name, array in arrays_dict.items():
|
| 18 |
+
for i, row in enumerate(array):
|
| 19 |
+
# Convert row to bytes
|
| 20 |
+
if isinstance(row, str):
|
| 21 |
+
row_bytes = row.encode()
|
| 22 |
+
else:
|
| 23 |
+
row_bytes = row.tobytes()
|
| 24 |
+
|
| 25 |
+
data_key = f'{array_name}_{start_index + i}_data'.encode()
|
| 26 |
+
|
| 27 |
+
txn.put(data_key, row_bytes)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def process_data_dict(data_dict, seen_prompts):
|
| 31 |
+
output_dict = {}
|
| 32 |
+
|
| 33 |
+
all_videos = []
|
| 34 |
+
all_prompts = []
|
| 35 |
+
for prompt, video in data_dict.items():
|
| 36 |
+
if prompt in seen_prompts:
|
| 37 |
+
continue
|
| 38 |
+
else:
|
| 39 |
+
seen_prompts.add(prompt)
|
| 40 |
+
|
| 41 |
+
video = video.half().numpy()
|
| 42 |
+
all_videos.append(video)
|
| 43 |
+
all_prompts.append(prompt)
|
| 44 |
+
|
| 45 |
+
if len(all_videos) == 0:
|
| 46 |
+
return {"latents": np.array([]), "prompts": np.array([])}
|
| 47 |
+
|
| 48 |
+
all_videos = np.concatenate(all_videos, axis=0)
|
| 49 |
+
|
| 50 |
+
output_dict['latents'] = all_videos
|
| 51 |
+
output_dict['prompts'] = np.array(all_prompts)
|
| 52 |
+
|
| 53 |
+
return output_dict
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def retrieve_row_from_lmdb(lmdb_env, array_name, dtype, row_index, shape=None):
|
| 57 |
+
"""
|
| 58 |
+
Retrieve a specific row from a specific array in the LMDB.
|
| 59 |
+
"""
|
| 60 |
+
data_key = f'{array_name}_{row_index}_data'.encode()
|
| 61 |
+
|
| 62 |
+
with lmdb_env.begin() as txn:
|
| 63 |
+
row_bytes = txn.get(data_key)
|
| 64 |
+
|
| 65 |
+
if dtype == str:
|
| 66 |
+
array = row_bytes.decode()
|
| 67 |
+
else:
|
| 68 |
+
array = np.frombuffer(row_bytes, dtype=dtype)
|
| 69 |
+
|
| 70 |
+
if shape is not None and len(shape) > 0:
|
| 71 |
+
array = array.reshape(shape)
|
| 72 |
+
return array
|
utils/loss.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class DenoisingLoss(ABC):
|
| 6 |
+
@abstractmethod
|
| 7 |
+
def __call__(
|
| 8 |
+
self, x: torch.Tensor, x_pred: torch.Tensor,
|
| 9 |
+
noise: torch.Tensor, noise_pred: torch.Tensor,
|
| 10 |
+
alphas_cumprod: torch.Tensor,
|
| 11 |
+
timestep: torch.Tensor,
|
| 12 |
+
**kwargs
|
| 13 |
+
) -> torch.Tensor:
|
| 14 |
+
"""
|
| 15 |
+
Base class for denoising loss.
|
| 16 |
+
Input:
|
| 17 |
+
- x: the clean data with shape [B, F, C, H, W]
|
| 18 |
+
- x_pred: the predicted clean data with shape [B, F, C, H, W]
|
| 19 |
+
- noise: the noise with shape [B, F, C, H, W]
|
| 20 |
+
- noise_pred: the predicted noise with shape [B, F, C, H, W]
|
| 21 |
+
- alphas_cumprod: the cumulative product of alphas (defining the noise schedule) with shape [T]
|
| 22 |
+
- timestep: the current timestep with shape [B, F]
|
| 23 |
+
"""
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class X0PredLoss(DenoisingLoss):
|
| 28 |
+
def __call__(
|
| 29 |
+
self, x: torch.Tensor, x_pred: torch.Tensor,
|
| 30 |
+
noise: torch.Tensor, noise_pred: torch.Tensor,
|
| 31 |
+
alphas_cumprod: torch.Tensor,
|
| 32 |
+
timestep: torch.Tensor,
|
| 33 |
+
**kwargs
|
| 34 |
+
) -> torch.Tensor:
|
| 35 |
+
return torch.mean((x - x_pred) ** 2)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class VPredLoss(DenoisingLoss):
|
| 39 |
+
def __call__(
|
| 40 |
+
self, x: torch.Tensor, x_pred: torch.Tensor,
|
| 41 |
+
noise: torch.Tensor, noise_pred: torch.Tensor,
|
| 42 |
+
alphas_cumprod: torch.Tensor,
|
| 43 |
+
timestep: torch.Tensor,
|
| 44 |
+
**kwargs
|
| 45 |
+
) -> torch.Tensor:
|
| 46 |
+
weights = 1 / (1 - alphas_cumprod[timestep].reshape(*timestep.shape, 1, 1, 1))
|
| 47 |
+
return torch.mean(weights * (x - x_pred) ** 2)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class NoisePredLoss(DenoisingLoss):
|
| 51 |
+
def __call__(
|
| 52 |
+
self, x: torch.Tensor, x_pred: torch.Tensor,
|
| 53 |
+
noise: torch.Tensor, noise_pred: torch.Tensor,
|
| 54 |
+
alphas_cumprod: torch.Tensor,
|
| 55 |
+
timestep: torch.Tensor,
|
| 56 |
+
**kwargs
|
| 57 |
+
) -> torch.Tensor:
|
| 58 |
+
return torch.mean((noise - noise_pred) ** 2)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class FlowPredLoss(DenoisingLoss):
|
| 62 |
+
def __call__(
|
| 63 |
+
self, x: torch.Tensor, x_pred: torch.Tensor,
|
| 64 |
+
noise: torch.Tensor, noise_pred: torch.Tensor,
|
| 65 |
+
alphas_cumprod: torch.Tensor,
|
| 66 |
+
timestep: torch.Tensor,
|
| 67 |
+
**kwargs
|
| 68 |
+
) -> torch.Tensor:
|
| 69 |
+
return torch.mean((kwargs["flow_pred"] - (noise - x)) ** 2)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
NAME_TO_CLASS = {
|
| 73 |
+
"x0": X0PredLoss,
|
| 74 |
+
"v": VPredLoss,
|
| 75 |
+
"noise": NoisePredLoss,
|
| 76 |
+
"flow": FlowPredLoss
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_denoising_loss(loss_type: str) -> DenoisingLoss:
|
| 81 |
+
return NAME_TO_CLASS[loss_type]
|
utils/misc.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def set_seed(seed: int, deterministic: bool = False):
|
| 7 |
+
"""
|
| 8 |
+
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
seed (`int`):
|
| 12 |
+
The seed to set.
|
| 13 |
+
deterministic (`bool`, *optional*, defaults to `False`):
|
| 14 |
+
Whether to use deterministic algorithms where available. Can slow down training.
|
| 15 |
+
"""
|
| 16 |
+
random.seed(seed)
|
| 17 |
+
np.random.seed(seed)
|
| 18 |
+
torch.manual_seed(seed)
|
| 19 |
+
torch.cuda.manual_seed_all(seed)
|
| 20 |
+
|
| 21 |
+
if deterministic:
|
| 22 |
+
torch.use_deterministic_algorithms(True)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def merge_dict_list(dict_list):
|
| 26 |
+
if len(dict_list) == 1:
|
| 27 |
+
return dict_list[0]
|
| 28 |
+
|
| 29 |
+
merged_dict = {}
|
| 30 |
+
for k, v in dict_list[0].items():
|
| 31 |
+
if isinstance(v, torch.Tensor):
|
| 32 |
+
if v.ndim == 0:
|
| 33 |
+
merged_dict[k] = torch.stack([d[k] for d in dict_list], dim=0)
|
| 34 |
+
else:
|
| 35 |
+
merged_dict[k] = torch.cat([d[k] for d in dict_list], dim=0)
|
| 36 |
+
else:
|
| 37 |
+
# for non-tensor values, we just copy the value from the first item
|
| 38 |
+
merged_dict[k] = v
|
| 39 |
+
return merged_dict
|
utils/scheduler.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod, ABC
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class SchedulerInterface(ABC):
|
| 6 |
+
"""
|
| 7 |
+
Base class for diffusion noise schedule.
|
| 8 |
+
"""
|
| 9 |
+
alphas_cumprod: torch.Tensor # [T], alphas for defining the noise schedule
|
| 10 |
+
|
| 11 |
+
@abstractmethod
|
| 12 |
+
def add_noise(
|
| 13 |
+
self, clean_latent: torch.Tensor,
|
| 14 |
+
noise: torch.Tensor, timestep: torch.Tensor
|
| 15 |
+
):
|
| 16 |
+
"""
|
| 17 |
+
Diffusion forward corruption process.
|
| 18 |
+
Input:
|
| 19 |
+
- clean_latent: the clean latent with shape [B, C, H, W]
|
| 20 |
+
- noise: the noise with shape [B, C, H, W]
|
| 21 |
+
- timestep: the timestep with shape [B]
|
| 22 |
+
Output: the corrupted latent with shape [B, C, H, W]
|
| 23 |
+
"""
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
def convert_x0_to_noise(
|
| 27 |
+
self, x0: torch.Tensor, xt: torch.Tensor,
|
| 28 |
+
timestep: torch.Tensor
|
| 29 |
+
) -> torch.Tensor:
|
| 30 |
+
"""
|
| 31 |
+
Convert the diffusion network's x0 prediction to noise predidction.
|
| 32 |
+
x0: the predicted clean data with shape [B, C, H, W]
|
| 33 |
+
xt: the input noisy data with shape [B, C, H, W]
|
| 34 |
+
timestep: the timestep with shape [B]
|
| 35 |
+
|
| 36 |
+
noise = (xt-sqrt(alpha_t)*x0) / sqrt(beta_t) (eq 11 in https://arxiv.org/abs/2311.18828)
|
| 37 |
+
"""
|
| 38 |
+
# use higher precision for calculations
|
| 39 |
+
original_dtype = x0.dtype
|
| 40 |
+
x0, xt, alphas_cumprod = map(
|
| 41 |
+
lambda x: x.double().to(x0.device), [x0, xt,
|
| 42 |
+
self.alphas_cumprod]
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
|
| 46 |
+
beta_prod_t = 1 - alpha_prod_t
|
| 47 |
+
|
| 48 |
+
noise_pred = (xt - alpha_prod_t **
|
| 49 |
+
(0.5) * x0) / beta_prod_t ** (0.5)
|
| 50 |
+
return noise_pred.to(original_dtype)
|
| 51 |
+
|
| 52 |
+
def convert_noise_to_x0(
|
| 53 |
+
self, noise: torch.Tensor, xt: torch.Tensor,
|
| 54 |
+
timestep: torch.Tensor
|
| 55 |
+
) -> torch.Tensor:
|
| 56 |
+
"""
|
| 57 |
+
Convert the diffusion network's noise prediction to x0 predidction.
|
| 58 |
+
noise: the predicted noise with shape [B, C, H, W]
|
| 59 |
+
xt: the input noisy data with shape [B, C, H, W]
|
| 60 |
+
timestep: the timestep with shape [B]
|
| 61 |
+
|
| 62 |
+
x0 = (x_t - sqrt(beta_t) * noise) / sqrt(alpha_t) (eq 11 in https://arxiv.org/abs/2311.18828)
|
| 63 |
+
"""
|
| 64 |
+
# use higher precision for calculations
|
| 65 |
+
original_dtype = noise.dtype
|
| 66 |
+
noise, xt, alphas_cumprod = map(
|
| 67 |
+
lambda x: x.double().to(noise.device), [noise, xt,
|
| 68 |
+
self.alphas_cumprod]
|
| 69 |
+
)
|
| 70 |
+
alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
|
| 71 |
+
beta_prod_t = 1 - alpha_prod_t
|
| 72 |
+
|
| 73 |
+
x0_pred = (xt - beta_prod_t **
|
| 74 |
+
(0.5) * noise) / alpha_prod_t ** (0.5)
|
| 75 |
+
return x0_pred.to(original_dtype)
|
| 76 |
+
|
| 77 |
+
def convert_velocity_to_x0(
|
| 78 |
+
self, velocity: torch.Tensor, xt: torch.Tensor,
|
| 79 |
+
timestep: torch.Tensor
|
| 80 |
+
) -> torch.Tensor:
|
| 81 |
+
"""
|
| 82 |
+
Convert the diffusion network's velocity prediction to x0 predidction.
|
| 83 |
+
velocity: the predicted noise with shape [B, C, H, W]
|
| 84 |
+
xt: the input noisy data with shape [B, C, H, W]
|
| 85 |
+
timestep: the timestep with shape [B]
|
| 86 |
+
|
| 87 |
+
v = sqrt(alpha_t) * noise - sqrt(beta_t) x0
|
| 88 |
+
noise = (xt-sqrt(alpha_t)*x0) / sqrt(beta_t)
|
| 89 |
+
given v, x_t, we have
|
| 90 |
+
x0 = sqrt(alpha_t) * x_t - sqrt(beta_t) * v
|
| 91 |
+
see derivations https://chatgpt.com/share/679fb6c8-3a30-8008-9b0e-d1ae892dac56
|
| 92 |
+
"""
|
| 93 |
+
# use higher precision for calculations
|
| 94 |
+
original_dtype = velocity.dtype
|
| 95 |
+
velocity, xt, alphas_cumprod = map(
|
| 96 |
+
lambda x: x.double().to(velocity.device), [velocity, xt,
|
| 97 |
+
self.alphas_cumprod]
|
| 98 |
+
)
|
| 99 |
+
alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
|
| 100 |
+
beta_prod_t = 1 - alpha_prod_t
|
| 101 |
+
|
| 102 |
+
x0_pred = (alpha_prod_t ** 0.5) * xt - (beta_prod_t ** 0.5) * velocity
|
| 103 |
+
return x0_pred.to(original_dtype)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class FlowMatchScheduler():
|
| 107 |
+
|
| 108 |
+
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003 / 1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
|
| 109 |
+
self.num_train_timesteps = num_train_timesteps
|
| 110 |
+
self.shift = shift
|
| 111 |
+
self.sigma_max = sigma_max
|
| 112 |
+
self.sigma_min = sigma_min
|
| 113 |
+
self.inverse_timesteps = inverse_timesteps
|
| 114 |
+
self.extra_one_step = extra_one_step
|
| 115 |
+
self.reverse_sigmas = reverse_sigmas
|
| 116 |
+
self.set_timesteps(num_inference_steps)
|
| 117 |
+
|
| 118 |
+
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
|
| 119 |
+
sigma_start = self.sigma_min + \
|
| 120 |
+
(self.sigma_max - self.sigma_min) * denoising_strength
|
| 121 |
+
if self.extra_one_step:
|
| 122 |
+
self.sigmas = torch.linspace(
|
| 123 |
+
sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
| 124 |
+
else:
|
| 125 |
+
self.sigmas = torch.linspace(
|
| 126 |
+
sigma_start, self.sigma_min, num_inference_steps)
|
| 127 |
+
if self.inverse_timesteps:
|
| 128 |
+
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
| 129 |
+
self.sigmas = self.shift * self.sigmas / \
|
| 130 |
+
(1 + (self.shift - 1) * self.sigmas)
|
| 131 |
+
if self.reverse_sigmas:
|
| 132 |
+
self.sigmas = 1 - self.sigmas
|
| 133 |
+
self.timesteps = self.sigmas * self.num_train_timesteps
|
| 134 |
+
if training:
|
| 135 |
+
x = self.timesteps
|
| 136 |
+
y = torch.exp(-2 * ((x - num_inference_steps / 2) /
|
| 137 |
+
num_inference_steps) ** 2)
|
| 138 |
+
y_shifted = y - y.min()
|
| 139 |
+
bsmntw_weighing = y_shifted * \
|
| 140 |
+
(num_inference_steps / y_shifted.sum())
|
| 141 |
+
self.linear_timesteps_weights = bsmntw_weighing
|
| 142 |
+
|
| 143 |
+
def step(self, model_output, timestep, sample, to_final=False):
|
| 144 |
+
if timestep.ndim == 2:
|
| 145 |
+
timestep = timestep.flatten(0, 1)
|
| 146 |
+
self.sigmas = self.sigmas.to(model_output.device)
|
| 147 |
+
self.timesteps = self.timesteps.to(model_output.device)
|
| 148 |
+
timestep_id = torch.argmin(
|
| 149 |
+
(self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
|
| 150 |
+
sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
|
| 151 |
+
if to_final or (timestep_id + 1 >= len(self.timesteps)).any():
|
| 152 |
+
sigma_ = 1 if (
|
| 153 |
+
self.inverse_timesteps or self.reverse_sigmas) else 0
|
| 154 |
+
else:
|
| 155 |
+
sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1)
|
| 156 |
+
prev_sample = sample + model_output * (sigma_ - sigma)
|
| 157 |
+
return prev_sample
|
| 158 |
+
|
| 159 |
+
def add_noise(self, original_samples, noise, timestep):
|
| 160 |
+
"""
|
| 161 |
+
Diffusion forward corruption process.
|
| 162 |
+
Input:
|
| 163 |
+
- clean_latent: the clean latent with shape [B*T, C, H, W]
|
| 164 |
+
- noise: the noise with shape [B*T, C, H, W]
|
| 165 |
+
- timestep: the timestep with shape [B*T]
|
| 166 |
+
Output: the corrupted latent with shape [B*T, C, H, W]
|
| 167 |
+
"""
|
| 168 |
+
if timestep.ndim == 2:
|
| 169 |
+
timestep = timestep.flatten(0, 1)
|
| 170 |
+
self.sigmas = self.sigmas.to(noise.device)
|
| 171 |
+
self.timesteps = self.timesteps.to(noise.device)
|
| 172 |
+
timestep_id = torch.argmin(
|
| 173 |
+
(self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
|
| 174 |
+
sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
|
| 175 |
+
sample = (1 - sigma) * original_samples + sigma * noise
|
| 176 |
+
return sample.type_as(noise)
|
| 177 |
+
|
| 178 |
+
def training_target(self, sample, noise, timestep):
|
| 179 |
+
target = noise - sample
|
| 180 |
+
return target
|
| 181 |
+
|
| 182 |
+
def training_weight(self, timestep):
|
| 183 |
+
"""
|
| 184 |
+
Input:
|
| 185 |
+
- timestep: the timestep with shape [B*T]
|
| 186 |
+
Output: the corresponding weighting [B*T]
|
| 187 |
+
"""
|
| 188 |
+
if timestep.ndim == 2:
|
| 189 |
+
timestep = timestep.flatten(0, 1)
|
| 190 |
+
self.linear_timesteps_weights = self.linear_timesteps_weights.to(timestep.device)
|
| 191 |
+
timestep_id = torch.argmin(
|
| 192 |
+
(self.timesteps.unsqueeze(1) - timestep.unsqueeze(0)).abs(), dim=0)
|
| 193 |
+
weights = self.linear_timesteps_weights[timestep_id]
|
| 194 |
+
return weights
|
utils/wan_wrapper.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import types
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from utils.scheduler import SchedulerInterface, FlowMatchScheduler
|
| 7 |
+
from wan.modules.tokenizers import HuggingfaceTokenizer
|
| 8 |
+
from wan.modules.model import WanModel, RegisterTokens, GanAttentionBlock
|
| 9 |
+
from wan.modules.vae import _video_vae
|
| 10 |
+
from wan.modules.t5 import umt5_xxl
|
| 11 |
+
from wan.modules.causal_model import CausalWanModel
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class WanTextEncoder(torch.nn.Module):
|
| 15 |
+
def __init__(self) -> None:
|
| 16 |
+
super().__init__()
|
| 17 |
+
|
| 18 |
+
self.text_encoder = umt5_xxl(
|
| 19 |
+
encoder_only=True,
|
| 20 |
+
return_tokenizer=False,
|
| 21 |
+
dtype=torch.float32,
|
| 22 |
+
device=torch.device('cpu')
|
| 23 |
+
).eval().requires_grad_(False)
|
| 24 |
+
self.text_encoder.load_state_dict(
|
| 25 |
+
torch.load("wan_models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
|
| 26 |
+
map_location='cpu', weights_only=False)
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
self.tokenizer = HuggingfaceTokenizer(
|
| 30 |
+
name="wan_models/Wan2.1-T2V-1.3B/google/umt5-xxl/", seq_len=512, clean='whitespace')
|
| 31 |
+
|
| 32 |
+
@property
|
| 33 |
+
def device(self):
|
| 34 |
+
# Assume we are always on GPU
|
| 35 |
+
return torch.cuda.current_device()
|
| 36 |
+
|
| 37 |
+
def forward(self, text_prompts: List[str]) -> dict:
|
| 38 |
+
ids, mask = self.tokenizer(
|
| 39 |
+
text_prompts, return_mask=True, add_special_tokens=True)
|
| 40 |
+
ids = ids.to(self.device)
|
| 41 |
+
mask = mask.to(self.device)
|
| 42 |
+
seq_lens = mask.gt(0).sum(dim=1).long()
|
| 43 |
+
context = self.text_encoder(ids, mask)
|
| 44 |
+
|
| 45 |
+
for u, v in zip(context, seq_lens):
|
| 46 |
+
u[v:] = 0.0 # set padding to 0.0
|
| 47 |
+
|
| 48 |
+
return {
|
| 49 |
+
"prompt_embeds": context
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class WanVAEWrapper(torch.nn.Module):
|
| 54 |
+
def __init__(self):
|
| 55 |
+
super().__init__()
|
| 56 |
+
mean = [
|
| 57 |
+
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
| 58 |
+
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
| 59 |
+
]
|
| 60 |
+
std = [
|
| 61 |
+
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
| 62 |
+
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
| 63 |
+
]
|
| 64 |
+
self.mean = torch.tensor(mean, dtype=torch.float32)
|
| 65 |
+
self.std = torch.tensor(std, dtype=torch.float32)
|
| 66 |
+
|
| 67 |
+
# init model
|
| 68 |
+
self.model = _video_vae(
|
| 69 |
+
pretrained_path="wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
|
| 70 |
+
z_dim=16,
|
| 71 |
+
).eval().requires_grad_(False)
|
| 72 |
+
|
| 73 |
+
def encode_to_latent(self, pixel: torch.Tensor) -> torch.Tensor:
|
| 74 |
+
# pixel: [batch_size, num_channels, num_frames, height, width]
|
| 75 |
+
device, dtype = pixel.device, pixel.dtype
|
| 76 |
+
scale = [self.mean.to(device=device, dtype=dtype),
|
| 77 |
+
1.0 / self.std.to(device=device, dtype=dtype)]
|
| 78 |
+
|
| 79 |
+
output = [
|
| 80 |
+
self.model.encode(u.unsqueeze(0), scale).float().squeeze(0)
|
| 81 |
+
for u in pixel
|
| 82 |
+
]
|
| 83 |
+
output = torch.stack(output, dim=0)
|
| 84 |
+
# from [batch_size, num_channels, num_frames, height, width]
|
| 85 |
+
# to [batch_size, num_frames, num_channels, height, width]
|
| 86 |
+
output = output.permute(0, 2, 1, 3, 4)
|
| 87 |
+
return output
|
| 88 |
+
|
| 89 |
+
def decode_to_pixel(self, latent: torch.Tensor, use_cache: bool = False) -> torch.Tensor:
|
| 90 |
+
# from [batch_size, num_frames, num_channels, height, width]
|
| 91 |
+
# to [batch_size, num_channels, num_frames, height, width]
|
| 92 |
+
zs = latent.permute(0, 2, 1, 3, 4)
|
| 93 |
+
if use_cache:
|
| 94 |
+
assert latent.shape[0] == 1, "Batch size must be 1 when using cache"
|
| 95 |
+
|
| 96 |
+
device, dtype = latent.device, latent.dtype
|
| 97 |
+
scale = [self.mean.to(device=device, dtype=dtype),
|
| 98 |
+
1.0 / self.std.to(device=device, dtype=dtype)]
|
| 99 |
+
|
| 100 |
+
if use_cache:
|
| 101 |
+
decode_function = self.model.cached_decode
|
| 102 |
+
else:
|
| 103 |
+
decode_function = self.model.decode
|
| 104 |
+
|
| 105 |
+
output = []
|
| 106 |
+
for u in zs:
|
| 107 |
+
output.append(decode_function(u.unsqueeze(0), scale).float().clamp_(-1, 1).squeeze(0))
|
| 108 |
+
output = torch.stack(output, dim=0)
|
| 109 |
+
# from [batch_size, num_channels, num_frames, height, width]
|
| 110 |
+
# to [batch_size, num_frames, num_channels, height, width]
|
| 111 |
+
output = output.permute(0, 2, 1, 3, 4)
|
| 112 |
+
return output
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class WanDiffusionWrapper(torch.nn.Module):
|
| 116 |
+
def __init__(
|
| 117 |
+
self,
|
| 118 |
+
model_name="Wan2.1-T2V-1.3B",
|
| 119 |
+
timestep_shift=8.0,
|
| 120 |
+
is_causal=False,
|
| 121 |
+
local_attn_size=-1,
|
| 122 |
+
sink_size=0
|
| 123 |
+
):
|
| 124 |
+
super().__init__()
|
| 125 |
+
|
| 126 |
+
if is_causal:
|
| 127 |
+
self.model = CausalWanModel.from_pretrained(
|
| 128 |
+
f"wan_models/{model_name}/", local_attn_size=local_attn_size, sink_size=sink_size)
|
| 129 |
+
else:
|
| 130 |
+
self.model = WanModel.from_pretrained(f"wan_models/{model_name}/")
|
| 131 |
+
self.model.eval()
|
| 132 |
+
|
| 133 |
+
# For non-causal diffusion, all frames share the same timestep
|
| 134 |
+
self.uniform_timestep = not is_causal
|
| 135 |
+
|
| 136 |
+
self.scheduler = FlowMatchScheduler(
|
| 137 |
+
shift=timestep_shift, sigma_min=0.0, extra_one_step=True
|
| 138 |
+
)
|
| 139 |
+
self.scheduler.set_timesteps(1000, training=True)
|
| 140 |
+
|
| 141 |
+
self.seq_len = 32760 # [1, 21, 16, 60, 104]
|
| 142 |
+
self.post_init()
|
| 143 |
+
|
| 144 |
+
def enable_gradient_checkpointing(self) -> None:
|
| 145 |
+
self.model.enable_gradient_checkpointing()
|
| 146 |
+
|
| 147 |
+
def adding_cls_branch(self, atten_dim=1536, num_class=4, time_embed_dim=0) -> None:
|
| 148 |
+
# NOTE: This is hard coded for WAN2.1-T2V-1.3B for now!!!!!!!!!!!!!!!!!!!!
|
| 149 |
+
self._cls_pred_branch = nn.Sequential(
|
| 150 |
+
# Input: [B, 384, 21, 60, 104]
|
| 151 |
+
nn.LayerNorm(atten_dim * 3 + time_embed_dim),
|
| 152 |
+
nn.Linear(atten_dim * 3 + time_embed_dim, 1536),
|
| 153 |
+
nn.SiLU(),
|
| 154 |
+
nn.Linear(atten_dim, num_class)
|
| 155 |
+
)
|
| 156 |
+
self._cls_pred_branch.requires_grad_(True)
|
| 157 |
+
num_registers = 3
|
| 158 |
+
self._register_tokens = RegisterTokens(num_registers=num_registers, dim=atten_dim)
|
| 159 |
+
self._register_tokens.requires_grad_(True)
|
| 160 |
+
|
| 161 |
+
gan_ca_blocks = []
|
| 162 |
+
for _ in range(num_registers):
|
| 163 |
+
block = GanAttentionBlock()
|
| 164 |
+
gan_ca_blocks.append(block)
|
| 165 |
+
self._gan_ca_blocks = nn.ModuleList(gan_ca_blocks)
|
| 166 |
+
self._gan_ca_blocks.requires_grad_(True)
|
| 167 |
+
# self.has_cls_branch = True
|
| 168 |
+
|
| 169 |
+
def _convert_flow_pred_to_x0(self, flow_pred: torch.Tensor, xt: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
|
| 170 |
+
"""
|
| 171 |
+
Convert flow matching's prediction to x0 prediction.
|
| 172 |
+
flow_pred: the prediction with shape [B, C, H, W]
|
| 173 |
+
xt: the input noisy data with shape [B, C, H, W]
|
| 174 |
+
timestep: the timestep with shape [B]
|
| 175 |
+
|
| 176 |
+
pred = noise - x0
|
| 177 |
+
x_t = (1-sigma_t) * x0 + sigma_t * noise
|
| 178 |
+
we have x0 = x_t - sigma_t * pred
|
| 179 |
+
see derivations https://chatgpt.com/share/67bf8589-3d04-8008-bc6e-4cf1a24e2d0e
|
| 180 |
+
"""
|
| 181 |
+
# use higher precision for calculations
|
| 182 |
+
original_dtype = flow_pred.dtype
|
| 183 |
+
flow_pred, xt, sigmas, timesteps = map(
|
| 184 |
+
lambda x: x.double().to(flow_pred.device), [flow_pred, xt,
|
| 185 |
+
self.scheduler.sigmas,
|
| 186 |
+
self.scheduler.timesteps]
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
timestep_id = torch.argmin(
|
| 190 |
+
(timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
|
| 191 |
+
sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)
|
| 192 |
+
x0_pred = xt - sigma_t * flow_pred
|
| 193 |
+
return x0_pred.to(original_dtype)
|
| 194 |
+
|
| 195 |
+
@staticmethod
|
| 196 |
+
def _convert_x0_to_flow_pred(scheduler, x0_pred: torch.Tensor, xt: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
|
| 197 |
+
"""
|
| 198 |
+
Convert x0 prediction to flow matching's prediction.
|
| 199 |
+
x0_pred: the x0 prediction with shape [B, C, H, W]
|
| 200 |
+
xt: the input noisy data with shape [B, C, H, W]
|
| 201 |
+
timestep: the timestep with shape [B]
|
| 202 |
+
|
| 203 |
+
pred = (x_t - x_0) / sigma_t
|
| 204 |
+
"""
|
| 205 |
+
# use higher precision for calculations
|
| 206 |
+
original_dtype = x0_pred.dtype
|
| 207 |
+
x0_pred, xt, sigmas, timesteps = map(
|
| 208 |
+
lambda x: x.double().to(x0_pred.device), [x0_pred, xt,
|
| 209 |
+
scheduler.sigmas,
|
| 210 |
+
scheduler.timesteps]
|
| 211 |
+
)
|
| 212 |
+
timestep_id = torch.argmin(
|
| 213 |
+
(timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
|
| 214 |
+
sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)
|
| 215 |
+
flow_pred = (xt - x0_pred) / sigma_t
|
| 216 |
+
return flow_pred.to(original_dtype)
|
| 217 |
+
|
| 218 |
+
def forward(
|
| 219 |
+
self,
|
| 220 |
+
noisy_image_or_video: torch.Tensor, conditional_dict: dict,
|
| 221 |
+
timestep: torch.Tensor, kv_cache: Optional[List[dict]] = None,
|
| 222 |
+
crossattn_cache: Optional[List[dict]] = None,
|
| 223 |
+
current_start: Optional[int] = None,
|
| 224 |
+
classify_mode: Optional[bool] = False,
|
| 225 |
+
concat_time_embeddings: Optional[bool] = False,
|
| 226 |
+
clean_x: Optional[torch.Tensor] = None,
|
| 227 |
+
aug_t: Optional[torch.Tensor] = None,
|
| 228 |
+
cache_start: Optional[int] = None,
|
| 229 |
+
updating_cache: Optional[bool] = False
|
| 230 |
+
) -> torch.Tensor:
|
| 231 |
+
prompt_embeds = conditional_dict["prompt_embeds"]
|
| 232 |
+
|
| 233 |
+
# [B, F] -> [B]
|
| 234 |
+
if self.uniform_timestep:
|
| 235 |
+
input_timestep = timestep[:, 0]
|
| 236 |
+
else:
|
| 237 |
+
input_timestep = timestep
|
| 238 |
+
|
| 239 |
+
logits = None
|
| 240 |
+
# X0 prediction
|
| 241 |
+
if kv_cache is not None:
|
| 242 |
+
flow_pred = self.model(
|
| 243 |
+
noisy_image_or_video.permute(0, 2, 1, 3, 4),
|
| 244 |
+
t=input_timestep, context=prompt_embeds,
|
| 245 |
+
seq_len=self.seq_len,
|
| 246 |
+
kv_cache=kv_cache,
|
| 247 |
+
crossattn_cache=crossattn_cache,
|
| 248 |
+
current_start=current_start,
|
| 249 |
+
cache_start=cache_start,
|
| 250 |
+
updating_cache=updating_cache
|
| 251 |
+
).permute(0, 2, 1, 3, 4)
|
| 252 |
+
else:
|
| 253 |
+
if clean_x is not None:
|
| 254 |
+
# teacher forcing
|
| 255 |
+
flow_pred = self.model(
|
| 256 |
+
noisy_image_or_video.permute(0, 2, 1, 3, 4),
|
| 257 |
+
t=input_timestep, context=prompt_embeds,
|
| 258 |
+
seq_len=self.seq_len,
|
| 259 |
+
clean_x=clean_x.permute(0, 2, 1, 3, 4),
|
| 260 |
+
aug_t=aug_t,
|
| 261 |
+
).permute(0, 2, 1, 3, 4)
|
| 262 |
+
else:
|
| 263 |
+
if classify_mode:
|
| 264 |
+
flow_pred, logits = self.model(
|
| 265 |
+
noisy_image_or_video.permute(0, 2, 1, 3, 4),
|
| 266 |
+
t=input_timestep, context=prompt_embeds,
|
| 267 |
+
seq_len=self.seq_len,
|
| 268 |
+
classify_mode=True,
|
| 269 |
+
register_tokens=self._register_tokens,
|
| 270 |
+
cls_pred_branch=self._cls_pred_branch,
|
| 271 |
+
gan_ca_blocks=self._gan_ca_blocks,
|
| 272 |
+
concat_time_embeddings=concat_time_embeddings
|
| 273 |
+
)
|
| 274 |
+
flow_pred = flow_pred.permute(0, 2, 1, 3, 4)
|
| 275 |
+
else:
|
| 276 |
+
flow_pred = self.model(
|
| 277 |
+
noisy_image_or_video.permute(0, 2, 1, 3, 4),
|
| 278 |
+
t=input_timestep, context=prompt_embeds,
|
| 279 |
+
seq_len=self.seq_len
|
| 280 |
+
).permute(0, 2, 1, 3, 4)
|
| 281 |
+
|
| 282 |
+
pred_x0 = self._convert_flow_pred_to_x0(
|
| 283 |
+
flow_pred=flow_pred.flatten(0, 1),
|
| 284 |
+
xt=noisy_image_or_video.flatten(0, 1),
|
| 285 |
+
timestep=timestep.flatten(0, 1)
|
| 286 |
+
).unflatten(0, flow_pred.shape[:2])
|
| 287 |
+
|
| 288 |
+
if logits is not None:
|
| 289 |
+
return flow_pred, pred_x0, logits
|
| 290 |
+
|
| 291 |
+
return flow_pred, pred_x0
|
| 292 |
+
|
| 293 |
+
def get_scheduler(self) -> SchedulerInterface:
|
| 294 |
+
"""
|
| 295 |
+
Update the current scheduler with the interface's static method
|
| 296 |
+
"""
|
| 297 |
+
scheduler = self.scheduler
|
| 298 |
+
scheduler.convert_x0_to_noise = types.MethodType(
|
| 299 |
+
SchedulerInterface.convert_x0_to_noise, scheduler)
|
| 300 |
+
scheduler.convert_noise_to_x0 = types.MethodType(
|
| 301 |
+
SchedulerInterface.convert_noise_to_x0, scheduler)
|
| 302 |
+
scheduler.convert_velocity_to_x0 = types.MethodType(
|
| 303 |
+
SchedulerInterface.convert_velocity_to_x0, scheduler)
|
| 304 |
+
self.scheduler = scheduler
|
| 305 |
+
return scheduler
|
| 306 |
+
|
| 307 |
+
def post_init(self):
|
| 308 |
+
"""
|
| 309 |
+
A few custom initialization steps that should be called after the object is created.
|
| 310 |
+
Currently, the only one we have is to bind a few methods to scheduler.
|
| 311 |
+
We can gradually add more methods here if needed.
|
| 312 |
+
"""
|
| 313 |
+
self.get_scheduler()
|
wan/README.md
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Code in this folder is modified from https://github.com/Wan-Video/Wan2.1
|
| 2 |
+
Apache-2.0 License
|
wan/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import configs, distributed, modules
|
| 2 |
+
from .image2video import WanI2V
|
| 3 |
+
from .text2video import WanT2V
|
wan/configs/__init__.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
from .wan_t2v_14B import t2v_14B
|
| 3 |
+
from .wan_t2v_1_3B import t2v_1_3B
|
| 4 |
+
from .wan_i2v_14B import i2v_14B
|
| 5 |
+
import copy
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# the config of t2i_14B is the same as t2v_14B
|
| 12 |
+
t2i_14B = copy.deepcopy(t2v_14B)
|
| 13 |
+
t2i_14B.__name__ = 'Config: Wan T2I 14B'
|
| 14 |
+
|
| 15 |
+
WAN_CONFIGS = {
|
| 16 |
+
't2v-14B': t2v_14B,
|
| 17 |
+
't2v-1.3B': t2v_1_3B,
|
| 18 |
+
'i2v-14B': i2v_14B,
|
| 19 |
+
't2i-14B': t2i_14B,
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
SIZE_CONFIGS = {
|
| 23 |
+
'720*1280': (720, 1280),
|
| 24 |
+
'1280*720': (1280, 720),
|
| 25 |
+
'480*832': (480, 832),
|
| 26 |
+
'832*480': (832, 480),
|
| 27 |
+
'1024*1024': (1024, 1024),
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
MAX_AREA_CONFIGS = {
|
| 31 |
+
'720*1280': 720 * 1280,
|
| 32 |
+
'1280*720': 1280 * 720,
|
| 33 |
+
'480*832': 480 * 832,
|
| 34 |
+
'832*480': 832 * 480,
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
SUPPORTED_SIZES = {
|
| 38 |
+
't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
| 39 |
+
't2v-1.3B': ('480*832', '832*480'),
|
| 40 |
+
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
| 41 |
+
't2i-14B': tuple(SIZE_CONFIGS.keys()),
|
| 42 |
+
}
|
wan/configs/shared_config.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
from easydict import EasyDict
|
| 4 |
+
|
| 5 |
+
# ------------------------ Wan shared config ------------------------#
|
| 6 |
+
wan_shared_cfg = EasyDict()
|
| 7 |
+
|
| 8 |
+
# t5
|
| 9 |
+
wan_shared_cfg.t5_model = 'umt5_xxl'
|
| 10 |
+
wan_shared_cfg.t5_dtype = torch.bfloat16
|
| 11 |
+
wan_shared_cfg.text_len = 512
|
| 12 |
+
|
| 13 |
+
# transformer
|
| 14 |
+
wan_shared_cfg.param_dtype = torch.bfloat16
|
| 15 |
+
|
| 16 |
+
# inference
|
| 17 |
+
wan_shared_cfg.num_train_timesteps = 1000
|
| 18 |
+
wan_shared_cfg.sample_fps = 16
|
| 19 |
+
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
|
wan/configs/wan_i2v_14B.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
from easydict import EasyDict
|
| 4 |
+
|
| 5 |
+
from .shared_config import wan_shared_cfg
|
| 6 |
+
|
| 7 |
+
# ------------------------ Wan I2V 14B ------------------------#
|
| 8 |
+
|
| 9 |
+
i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
|
| 10 |
+
i2v_14B.update(wan_shared_cfg)
|
| 11 |
+
|
| 12 |
+
i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
|
| 13 |
+
i2v_14B.t5_tokenizer = 'google/umt5-xxl'
|
| 14 |
+
|
| 15 |
+
# clip
|
| 16 |
+
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
|
| 17 |
+
i2v_14B.clip_dtype = torch.float16
|
| 18 |
+
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
|
| 19 |
+
i2v_14B.clip_tokenizer = 'xlm-roberta-large'
|
| 20 |
+
|
| 21 |
+
# vae
|
| 22 |
+
i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
|
| 23 |
+
i2v_14B.vae_stride = (4, 8, 8)
|
| 24 |
+
|
| 25 |
+
# transformer
|
| 26 |
+
i2v_14B.patch_size = (1, 2, 2)
|
| 27 |
+
i2v_14B.dim = 5120
|
| 28 |
+
i2v_14B.ffn_dim = 13824
|
| 29 |
+
i2v_14B.freq_dim = 256
|
| 30 |
+
i2v_14B.num_heads = 40
|
| 31 |
+
i2v_14B.num_layers = 40
|
| 32 |
+
i2v_14B.window_size = (-1, -1)
|
| 33 |
+
i2v_14B.qk_norm = True
|
| 34 |
+
i2v_14B.cross_attn_norm = True
|
| 35 |
+
i2v_14B.eps = 1e-6
|
wan/configs/wan_t2v_14B.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
from easydict import EasyDict
|
| 3 |
+
|
| 4 |
+
from .shared_config import wan_shared_cfg
|
| 5 |
+
|
| 6 |
+
# ------------------------ Wan T2V 14B ------------------------#
|
| 7 |
+
|
| 8 |
+
t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
|
| 9 |
+
t2v_14B.update(wan_shared_cfg)
|
| 10 |
+
|
| 11 |
+
# t5
|
| 12 |
+
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
|
| 13 |
+
t2v_14B.t5_tokenizer = 'google/umt5-xxl'
|
| 14 |
+
|
| 15 |
+
# vae
|
| 16 |
+
t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
|
| 17 |
+
t2v_14B.vae_stride = (4, 8, 8)
|
| 18 |
+
|
| 19 |
+
# transformer
|
| 20 |
+
t2v_14B.patch_size = (1, 2, 2)
|
| 21 |
+
t2v_14B.dim = 5120
|
| 22 |
+
t2v_14B.ffn_dim = 13824
|
| 23 |
+
t2v_14B.freq_dim = 256
|
| 24 |
+
t2v_14B.num_heads = 40
|
| 25 |
+
t2v_14B.num_layers = 40
|
| 26 |
+
t2v_14B.window_size = (-1, -1)
|
| 27 |
+
t2v_14B.qk_norm = True
|
| 28 |
+
t2v_14B.cross_attn_norm = True
|
| 29 |
+
t2v_14B.eps = 1e-6
|
wan/configs/wan_t2v_1_3B.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
from easydict import EasyDict
|
| 3 |
+
|
| 4 |
+
from .shared_config import wan_shared_cfg
|
| 5 |
+
|
| 6 |
+
# ------------------------ Wan T2V 1.3B ------------------------#
|
| 7 |
+
|
| 8 |
+
t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
|
| 9 |
+
t2v_1_3B.update(wan_shared_cfg)
|
| 10 |
+
|
| 11 |
+
# t5
|
| 12 |
+
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
|
| 13 |
+
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
|
| 14 |
+
|
| 15 |
+
# vae
|
| 16 |
+
t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
|
| 17 |
+
t2v_1_3B.vae_stride = (4, 8, 8)
|
| 18 |
+
|
| 19 |
+
# transformer
|
| 20 |
+
t2v_1_3B.patch_size = (1, 2, 2)
|
| 21 |
+
t2v_1_3B.dim = 1536
|
| 22 |
+
t2v_1_3B.ffn_dim = 8960
|
| 23 |
+
t2v_1_3B.freq_dim = 256
|
| 24 |
+
t2v_1_3B.num_heads = 12
|
| 25 |
+
t2v_1_3B.num_layers = 30
|
| 26 |
+
t2v_1_3B.window_size = (-1, -1)
|
| 27 |
+
t2v_1_3B.qk_norm = True
|
| 28 |
+
t2v_1_3B.cross_attn_norm = True
|
| 29 |
+
t2v_1_3B.eps = 1e-6
|
wan/distributed/__init__.py
ADDED
|
File without changes
|
wan/distributed/fsdp.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
from functools import partial
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 6 |
+
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
|
| 7 |
+
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def shard_model(
|
| 11 |
+
model,
|
| 12 |
+
device_id,
|
| 13 |
+
param_dtype=torch.bfloat16,
|
| 14 |
+
reduce_dtype=torch.float32,
|
| 15 |
+
buffer_dtype=torch.float32,
|
| 16 |
+
process_group=None,
|
| 17 |
+
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
| 18 |
+
sync_module_states=True,
|
| 19 |
+
):
|
| 20 |
+
model = FSDP(
|
| 21 |
+
module=model,
|
| 22 |
+
process_group=process_group,
|
| 23 |
+
sharding_strategy=sharding_strategy,
|
| 24 |
+
auto_wrap_policy=partial(
|
| 25 |
+
lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
|
| 26 |
+
mixed_precision=MixedPrecision(
|
| 27 |
+
param_dtype=param_dtype,
|
| 28 |
+
reduce_dtype=reduce_dtype,
|
| 29 |
+
buffer_dtype=buffer_dtype),
|
| 30 |
+
device_id=device_id,
|
| 31 |
+
use_orig_params=True,
|
| 32 |
+
sync_module_states=sync_module_states)
|
| 33 |
+
return model
|
wan/distributed/xdit_context_parallel.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
import torch.cuda.amp as amp
|
| 4 |
+
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
| 5 |
+
get_sequence_parallel_world_size,
|
| 6 |
+
get_sp_group)
|
| 7 |
+
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
| 8 |
+
|
| 9 |
+
from ..modules.model import sinusoidal_embedding_1d
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def pad_freqs(original_tensor, target_len):
|
| 13 |
+
seq_len, s1, s2 = original_tensor.shape
|
| 14 |
+
pad_size = target_len - seq_len
|
| 15 |
+
padding_tensor = torch.ones(
|
| 16 |
+
pad_size,
|
| 17 |
+
s1,
|
| 18 |
+
s2,
|
| 19 |
+
dtype=original_tensor.dtype,
|
| 20 |
+
device=original_tensor.device)
|
| 21 |
+
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
| 22 |
+
return padded_tensor
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@amp.autocast(enabled=False)
|
| 26 |
+
def rope_apply(x, grid_sizes, freqs):
|
| 27 |
+
"""
|
| 28 |
+
x: [B, L, N, C].
|
| 29 |
+
grid_sizes: [B, 3].
|
| 30 |
+
freqs: [M, C // 2].
|
| 31 |
+
"""
|
| 32 |
+
s, n, c = x.size(1), x.size(2), x.size(3) // 2
|
| 33 |
+
# split freqs
|
| 34 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
| 35 |
+
|
| 36 |
+
# loop over samples
|
| 37 |
+
output = []
|
| 38 |
+
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
| 39 |
+
seq_len = f * h * w
|
| 40 |
+
|
| 41 |
+
# precompute multipliers
|
| 42 |
+
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
|
| 43 |
+
s, n, -1, 2))
|
| 44 |
+
freqs_i = torch.cat([
|
| 45 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 46 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 47 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 48 |
+
],
|
| 49 |
+
dim=-1).reshape(seq_len, 1, -1)
|
| 50 |
+
|
| 51 |
+
# apply rotary embedding
|
| 52 |
+
sp_size = get_sequence_parallel_world_size()
|
| 53 |
+
sp_rank = get_sequence_parallel_rank()
|
| 54 |
+
freqs_i = pad_freqs(freqs_i, s * sp_size)
|
| 55 |
+
s_per_rank = s
|
| 56 |
+
freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
|
| 57 |
+
s_per_rank), :, :]
|
| 58 |
+
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
|
| 59 |
+
x_i = torch.cat([x_i, x[i, s:]])
|
| 60 |
+
|
| 61 |
+
# append to collection
|
| 62 |
+
output.append(x_i)
|
| 63 |
+
return torch.stack(output).float()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def usp_dit_forward(
|
| 67 |
+
self,
|
| 68 |
+
x,
|
| 69 |
+
t,
|
| 70 |
+
context,
|
| 71 |
+
seq_len,
|
| 72 |
+
clip_fea=None,
|
| 73 |
+
y=None,
|
| 74 |
+
):
|
| 75 |
+
"""
|
| 76 |
+
x: A list of videos each with shape [C, T, H, W].
|
| 77 |
+
t: [B].
|
| 78 |
+
context: A list of text embeddings each with shape [L, C].
|
| 79 |
+
"""
|
| 80 |
+
if self.model_type == 'i2v':
|
| 81 |
+
assert clip_fea is not None and y is not None
|
| 82 |
+
# params
|
| 83 |
+
device = self.patch_embedding.weight.device
|
| 84 |
+
if self.freqs.device != device:
|
| 85 |
+
self.freqs = self.freqs.to(device)
|
| 86 |
+
|
| 87 |
+
if y is not None:
|
| 88 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
| 89 |
+
|
| 90 |
+
# embeddings
|
| 91 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 92 |
+
grid_sizes = torch.stack(
|
| 93 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 94 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 95 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 96 |
+
assert seq_lens.max() <= seq_len
|
| 97 |
+
x = torch.cat([
|
| 98 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
|
| 99 |
+
for u in x
|
| 100 |
+
])
|
| 101 |
+
|
| 102 |
+
# time embeddings
|
| 103 |
+
with amp.autocast(dtype=torch.float32):
|
| 104 |
+
e = self.time_embedding(
|
| 105 |
+
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
| 106 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
| 107 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 108 |
+
|
| 109 |
+
# context
|
| 110 |
+
context_lens = None
|
| 111 |
+
context = self.text_embedding(
|
| 112 |
+
torch.stack([
|
| 113 |
+
torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 114 |
+
for u in context
|
| 115 |
+
]))
|
| 116 |
+
|
| 117 |
+
if clip_fea is not None:
|
| 118 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
| 119 |
+
context = torch.concat([context_clip, context], dim=1)
|
| 120 |
+
|
| 121 |
+
# arguments
|
| 122 |
+
kwargs = dict(
|
| 123 |
+
e=e0,
|
| 124 |
+
seq_lens=seq_lens,
|
| 125 |
+
grid_sizes=grid_sizes,
|
| 126 |
+
freqs=self.freqs,
|
| 127 |
+
context=context,
|
| 128 |
+
context_lens=context_lens)
|
| 129 |
+
|
| 130 |
+
# Context Parallel
|
| 131 |
+
x = torch.chunk(
|
| 132 |
+
x, get_sequence_parallel_world_size(),
|
| 133 |
+
dim=1)[get_sequence_parallel_rank()]
|
| 134 |
+
|
| 135 |
+
for block in self.blocks:
|
| 136 |
+
x = block(x, **kwargs)
|
| 137 |
+
|
| 138 |
+
# head
|
| 139 |
+
x = self.head(x, e)
|
| 140 |
+
|
| 141 |
+
# Context Parallel
|
| 142 |
+
x = get_sp_group().all_gather(x, dim=1)
|
| 143 |
+
|
| 144 |
+
# unpatchify
|
| 145 |
+
x = self.unpatchify(x, grid_sizes)
|
| 146 |
+
return [u.float() for u in x]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def usp_attn_forward(self,
|
| 150 |
+
x,
|
| 151 |
+
seq_lens,
|
| 152 |
+
grid_sizes,
|
| 153 |
+
freqs,
|
| 154 |
+
dtype=torch.bfloat16):
|
| 155 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 156 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 157 |
+
|
| 158 |
+
def half(x):
|
| 159 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
| 160 |
+
|
| 161 |
+
# query, key, value function
|
| 162 |
+
def qkv_fn(x):
|
| 163 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 164 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 165 |
+
v = self.v(x).view(b, s, n, d)
|
| 166 |
+
return q, k, v
|
| 167 |
+
|
| 168 |
+
q, k, v = qkv_fn(x)
|
| 169 |
+
q = rope_apply(q, grid_sizes, freqs)
|
| 170 |
+
k = rope_apply(k, grid_sizes, freqs)
|
| 171 |
+
|
| 172 |
+
# TODO: We should use unpaded q,k,v for attention.
|
| 173 |
+
# k_lens = seq_lens // get_sequence_parallel_world_size()
|
| 174 |
+
# if k_lens is not None:
|
| 175 |
+
# q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
|
| 176 |
+
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
|
| 177 |
+
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
|
| 178 |
+
|
| 179 |
+
x = xFuserLongContextAttention()(
|
| 180 |
+
None,
|
| 181 |
+
query=half(q),
|
| 182 |
+
key=half(k),
|
| 183 |
+
value=half(v),
|
| 184 |
+
window_size=self.window_size)
|
| 185 |
+
|
| 186 |
+
# TODO: padding after attention.
|
| 187 |
+
# x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
|
| 188 |
+
|
| 189 |
+
# output
|
| 190 |
+
x = x.flatten(2)
|
| 191 |
+
x = self.o(x)
|
| 192 |
+
return x
|
wan/image2video.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import gc
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import sys
|
| 8 |
+
import types
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
from functools import partial
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.cuda.amp as amp
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
import torchvision.transforms.functional as TF
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
from .distributed.fsdp import shard_model
|
| 20 |
+
from .modules.clip import CLIPModel
|
| 21 |
+
from .modules.model import WanModel
|
| 22 |
+
from .modules.t5 import T5EncoderModel
|
| 23 |
+
from .modules.vae import WanVAE
|
| 24 |
+
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
| 25 |
+
get_sampling_sigmas, retrieve_timesteps)
|
| 26 |
+
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class WanI2V:
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
config,
|
| 34 |
+
checkpoint_dir,
|
| 35 |
+
device_id=0,
|
| 36 |
+
rank=0,
|
| 37 |
+
t5_fsdp=False,
|
| 38 |
+
dit_fsdp=False,
|
| 39 |
+
use_usp=False,
|
| 40 |
+
t5_cpu=False,
|
| 41 |
+
init_on_cpu=True,
|
| 42 |
+
):
|
| 43 |
+
r"""
|
| 44 |
+
Initializes the image-to-video generation model components.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
config (EasyDict):
|
| 48 |
+
Object containing model parameters initialized from config.py
|
| 49 |
+
checkpoint_dir (`str`):
|
| 50 |
+
Path to directory containing model checkpoints
|
| 51 |
+
device_id (`int`, *optional*, defaults to 0):
|
| 52 |
+
Id of target GPU device
|
| 53 |
+
rank (`int`, *optional*, defaults to 0):
|
| 54 |
+
Process rank for distributed training
|
| 55 |
+
t5_fsdp (`bool`, *optional*, defaults to False):
|
| 56 |
+
Enable FSDP sharding for T5 model
|
| 57 |
+
dit_fsdp (`bool`, *optional*, defaults to False):
|
| 58 |
+
Enable FSDP sharding for DiT model
|
| 59 |
+
use_usp (`bool`, *optional*, defaults to False):
|
| 60 |
+
Enable distribution strategy of USP.
|
| 61 |
+
t5_cpu (`bool`, *optional*, defaults to False):
|
| 62 |
+
Whether to place T5 model on CPU. Only works without t5_fsdp.
|
| 63 |
+
init_on_cpu (`bool`, *optional*, defaults to True):
|
| 64 |
+
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
|
| 65 |
+
"""
|
| 66 |
+
self.device = torch.device(f"cuda:{device_id}")
|
| 67 |
+
self.config = config
|
| 68 |
+
self.rank = rank
|
| 69 |
+
self.use_usp = use_usp
|
| 70 |
+
self.t5_cpu = t5_cpu
|
| 71 |
+
|
| 72 |
+
self.num_train_timesteps = config.num_train_timesteps
|
| 73 |
+
self.param_dtype = config.param_dtype
|
| 74 |
+
|
| 75 |
+
shard_fn = partial(shard_model, device_id=device_id)
|
| 76 |
+
self.text_encoder = T5EncoderModel(
|
| 77 |
+
text_len=config.text_len,
|
| 78 |
+
dtype=config.t5_dtype,
|
| 79 |
+
device=torch.device('cpu'),
|
| 80 |
+
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
|
| 81 |
+
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
| 82 |
+
shard_fn=shard_fn if t5_fsdp else None,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
self.vae_stride = config.vae_stride
|
| 86 |
+
self.patch_size = config.patch_size
|
| 87 |
+
self.vae = WanVAE(
|
| 88 |
+
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
|
| 89 |
+
device=self.device)
|
| 90 |
+
|
| 91 |
+
self.clip = CLIPModel(
|
| 92 |
+
dtype=config.clip_dtype,
|
| 93 |
+
device=self.device,
|
| 94 |
+
checkpoint_path=os.path.join(checkpoint_dir,
|
| 95 |
+
config.clip_checkpoint),
|
| 96 |
+
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
| 97 |
+
|
| 98 |
+
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
| 99 |
+
self.model = WanModel.from_pretrained(checkpoint_dir)
|
| 100 |
+
self.model.eval().requires_grad_(False)
|
| 101 |
+
|
| 102 |
+
if t5_fsdp or dit_fsdp or use_usp:
|
| 103 |
+
init_on_cpu = False
|
| 104 |
+
|
| 105 |
+
if use_usp:
|
| 106 |
+
from xfuser.core.distributed import \
|
| 107 |
+
get_sequence_parallel_world_size
|
| 108 |
+
|
| 109 |
+
from .distributed.xdit_context_parallel import (usp_attn_forward,
|
| 110 |
+
usp_dit_forward)
|
| 111 |
+
for block in self.model.blocks:
|
| 112 |
+
block.self_attn.forward = types.MethodType(
|
| 113 |
+
usp_attn_forward, block.self_attn)
|
| 114 |
+
self.model.forward = types.MethodType(usp_dit_forward, self.model)
|
| 115 |
+
self.sp_size = get_sequence_parallel_world_size()
|
| 116 |
+
else:
|
| 117 |
+
self.sp_size = 1
|
| 118 |
+
|
| 119 |
+
if dist.is_initialized():
|
| 120 |
+
dist.barrier()
|
| 121 |
+
if dit_fsdp:
|
| 122 |
+
self.model = shard_fn(self.model)
|
| 123 |
+
else:
|
| 124 |
+
if not init_on_cpu:
|
| 125 |
+
self.model.to(self.device)
|
| 126 |
+
|
| 127 |
+
self.sample_neg_prompt = config.sample_neg_prompt
|
| 128 |
+
|
| 129 |
+
def generate(self,
|
| 130 |
+
input_prompt,
|
| 131 |
+
img,
|
| 132 |
+
max_area=720 * 1280,
|
| 133 |
+
frame_num=81,
|
| 134 |
+
shift=5.0,
|
| 135 |
+
sample_solver='unipc',
|
| 136 |
+
sampling_steps=40,
|
| 137 |
+
guide_scale=5.0,
|
| 138 |
+
n_prompt="",
|
| 139 |
+
seed=-1,
|
| 140 |
+
offload_model=True):
|
| 141 |
+
r"""
|
| 142 |
+
Generates video frames from input image and text prompt using diffusion process.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
input_prompt (`str`):
|
| 146 |
+
Text prompt for content generation.
|
| 147 |
+
img (PIL.Image.Image):
|
| 148 |
+
Input image tensor. Shape: [3, H, W]
|
| 149 |
+
max_area (`int`, *optional*, defaults to 720*1280):
|
| 150 |
+
Maximum pixel area for latent space calculation. Controls video resolution scaling
|
| 151 |
+
frame_num (`int`, *optional*, defaults to 81):
|
| 152 |
+
How many frames to sample from a video. The number should be 4n+1
|
| 153 |
+
shift (`float`, *optional*, defaults to 5.0):
|
| 154 |
+
Noise schedule shift parameter. Affects temporal dynamics
|
| 155 |
+
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
|
| 156 |
+
sample_solver (`str`, *optional*, defaults to 'unipc'):
|
| 157 |
+
Solver used to sample the video.
|
| 158 |
+
sampling_steps (`int`, *optional*, defaults to 40):
|
| 159 |
+
Number of diffusion sampling steps. Higher values improve quality but slow generation
|
| 160 |
+
guide_scale (`float`, *optional*, defaults 5.0):
|
| 161 |
+
Classifier-free guidance scale. Controls prompt adherence vs. creativity
|
| 162 |
+
n_prompt (`str`, *optional*, defaults to ""):
|
| 163 |
+
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
|
| 164 |
+
seed (`int`, *optional*, defaults to -1):
|
| 165 |
+
Random seed for noise generation. If -1, use random seed
|
| 166 |
+
offload_model (`bool`, *optional*, defaults to True):
|
| 167 |
+
If True, offloads models to CPU during generation to save VRAM
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
torch.Tensor:
|
| 171 |
+
Generated video frames tensor. Dimensions: (C, N H, W) where:
|
| 172 |
+
- C: Color channels (3 for RGB)
|
| 173 |
+
- N: Number of frames (81)
|
| 174 |
+
- H: Frame height (from max_area)
|
| 175 |
+
- W: Frame width from max_area)
|
| 176 |
+
"""
|
| 177 |
+
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
|
| 178 |
+
|
| 179 |
+
F = frame_num
|
| 180 |
+
h, w = img.shape[1:]
|
| 181 |
+
aspect_ratio = h / w
|
| 182 |
+
lat_h = round(
|
| 183 |
+
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
|
| 184 |
+
self.patch_size[1] * self.patch_size[1])
|
| 185 |
+
lat_w = round(
|
| 186 |
+
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
|
| 187 |
+
self.patch_size[2] * self.patch_size[2])
|
| 188 |
+
h = lat_h * self.vae_stride[1]
|
| 189 |
+
w = lat_w * self.vae_stride[2]
|
| 190 |
+
|
| 191 |
+
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
|
| 192 |
+
self.patch_size[1] * self.patch_size[2])
|
| 193 |
+
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
|
| 194 |
+
|
| 195 |
+
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
| 196 |
+
seed_g = torch.Generator(device=self.device)
|
| 197 |
+
seed_g.manual_seed(seed)
|
| 198 |
+
noise = torch.randn(
|
| 199 |
+
16,
|
| 200 |
+
21,
|
| 201 |
+
lat_h,
|
| 202 |
+
lat_w,
|
| 203 |
+
dtype=torch.float32,
|
| 204 |
+
generator=seed_g,
|
| 205 |
+
device=self.device)
|
| 206 |
+
|
| 207 |
+
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
|
| 208 |
+
msk[:, 1:] = 0
|
| 209 |
+
msk = torch.concat([
|
| 210 |
+
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
|
| 211 |
+
],
|
| 212 |
+
dim=1)
|
| 213 |
+
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
| 214 |
+
msk = msk.transpose(1, 2)[0]
|
| 215 |
+
|
| 216 |
+
if n_prompt == "":
|
| 217 |
+
n_prompt = self.sample_neg_prompt
|
| 218 |
+
|
| 219 |
+
# preprocess
|
| 220 |
+
if not self.t5_cpu:
|
| 221 |
+
self.text_encoder.model.to(self.device)
|
| 222 |
+
context = self.text_encoder([input_prompt], self.device)
|
| 223 |
+
context_null = self.text_encoder([n_prompt], self.device)
|
| 224 |
+
if offload_model:
|
| 225 |
+
self.text_encoder.model.cpu()
|
| 226 |
+
else:
|
| 227 |
+
context = self.text_encoder([input_prompt], torch.device('cpu'))
|
| 228 |
+
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
| 229 |
+
context = [t.to(self.device) for t in context]
|
| 230 |
+
context_null = [t.to(self.device) for t in context_null]
|
| 231 |
+
|
| 232 |
+
self.clip.model.to(self.device)
|
| 233 |
+
clip_context = self.clip.visual([img[:, None, :, :]])
|
| 234 |
+
if offload_model:
|
| 235 |
+
self.clip.model.cpu()
|
| 236 |
+
|
| 237 |
+
y = self.vae.encode([
|
| 238 |
+
torch.concat([
|
| 239 |
+
torch.nn.functional.interpolate(
|
| 240 |
+
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
|
| 241 |
+
0, 1),
|
| 242 |
+
torch.zeros(3, 80, h, w)
|
| 243 |
+
],
|
| 244 |
+
dim=1).to(self.device)
|
| 245 |
+
])[0]
|
| 246 |
+
y = torch.concat([msk, y])
|
| 247 |
+
|
| 248 |
+
@contextmanager
|
| 249 |
+
def noop_no_sync():
|
| 250 |
+
yield
|
| 251 |
+
|
| 252 |
+
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
|
| 253 |
+
|
| 254 |
+
# evaluation mode
|
| 255 |
+
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
|
| 256 |
+
|
| 257 |
+
if sample_solver == 'unipc':
|
| 258 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
| 259 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 260 |
+
shift=1,
|
| 261 |
+
use_dynamic_shifting=False)
|
| 262 |
+
sample_scheduler.set_timesteps(
|
| 263 |
+
sampling_steps, device=self.device, shift=shift)
|
| 264 |
+
timesteps = sample_scheduler.timesteps
|
| 265 |
+
elif sample_solver == 'dpm++':
|
| 266 |
+
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
| 267 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 268 |
+
shift=1,
|
| 269 |
+
use_dynamic_shifting=False)
|
| 270 |
+
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
| 271 |
+
timesteps, _ = retrieve_timesteps(
|
| 272 |
+
sample_scheduler,
|
| 273 |
+
device=self.device,
|
| 274 |
+
sigmas=sampling_sigmas)
|
| 275 |
+
else:
|
| 276 |
+
raise NotImplementedError("Unsupported solver.")
|
| 277 |
+
|
| 278 |
+
# sample videos
|
| 279 |
+
latent = noise
|
| 280 |
+
|
| 281 |
+
arg_c = {
|
| 282 |
+
'context': [context[0]],
|
| 283 |
+
'clip_fea': clip_context,
|
| 284 |
+
'seq_len': max_seq_len,
|
| 285 |
+
'y': [y],
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
arg_null = {
|
| 289 |
+
'context': context_null,
|
| 290 |
+
'clip_fea': clip_context,
|
| 291 |
+
'seq_len': max_seq_len,
|
| 292 |
+
'y': [y],
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
if offload_model:
|
| 296 |
+
torch.cuda.empty_cache()
|
| 297 |
+
|
| 298 |
+
self.model.to(self.device)
|
| 299 |
+
for _, t in enumerate(tqdm(timesteps)):
|
| 300 |
+
latent_model_input = [latent.to(self.device)]
|
| 301 |
+
timestep = [t]
|
| 302 |
+
|
| 303 |
+
timestep = torch.stack(timestep).to(self.device)
|
| 304 |
+
|
| 305 |
+
noise_pred_cond = self.model(
|
| 306 |
+
latent_model_input, t=timestep, **arg_c)[0].to(
|
| 307 |
+
torch.device('cpu') if offload_model else self.device)
|
| 308 |
+
if offload_model:
|
| 309 |
+
torch.cuda.empty_cache()
|
| 310 |
+
noise_pred_uncond = self.model(
|
| 311 |
+
latent_model_input, t=timestep, **arg_null)[0].to(
|
| 312 |
+
torch.device('cpu') if offload_model else self.device)
|
| 313 |
+
if offload_model:
|
| 314 |
+
torch.cuda.empty_cache()
|
| 315 |
+
noise_pred = noise_pred_uncond + guide_scale * (
|
| 316 |
+
noise_pred_cond - noise_pred_uncond)
|
| 317 |
+
|
| 318 |
+
latent = latent.to(
|
| 319 |
+
torch.device('cpu') if offload_model else self.device)
|
| 320 |
+
|
| 321 |
+
temp_x0 = sample_scheduler.step(
|
| 322 |
+
noise_pred.unsqueeze(0),
|
| 323 |
+
t,
|
| 324 |
+
latent.unsqueeze(0),
|
| 325 |
+
return_dict=False,
|
| 326 |
+
generator=seed_g)[0]
|
| 327 |
+
latent = temp_x0.squeeze(0)
|
| 328 |
+
|
| 329 |
+
x0 = [latent.to(self.device)]
|
| 330 |
+
del latent_model_input, timestep
|
| 331 |
+
|
| 332 |
+
if offload_model:
|
| 333 |
+
self.model.cpu()
|
| 334 |
+
torch.cuda.empty_cache()
|
| 335 |
+
|
| 336 |
+
if self.rank == 0:
|
| 337 |
+
videos = self.vae.decode(x0)
|
| 338 |
+
|
| 339 |
+
del noise, latent
|
| 340 |
+
del sample_scheduler
|
| 341 |
+
if offload_model:
|
| 342 |
+
gc.collect()
|
| 343 |
+
torch.cuda.synchronize()
|
| 344 |
+
if dist.is_initialized():
|
| 345 |
+
dist.barrier()
|
| 346 |
+
|
| 347 |
+
return videos[0] if self.rank == 0 else None
|
wan/modules/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .attention import flash_attention
|
| 2 |
+
from .model import WanModel
|
| 3 |
+
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
|
| 4 |
+
from .tokenizers import HuggingfaceTokenizer
|
| 5 |
+
from .vae import WanVAE
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
'WanVAE',
|
| 9 |
+
'WanModel',
|
| 10 |
+
'T5Model',
|
| 11 |
+
'T5Encoder',
|
| 12 |
+
'T5Decoder',
|
| 13 |
+
'T5EncoderModel',
|
| 14 |
+
'HuggingfaceTokenizer',
|
| 15 |
+
'flash_attention',
|
| 16 |
+
]
|
wan/modules/attention.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
try:
|
| 5 |
+
import flash_attn_interface
|
| 6 |
+
|
| 7 |
+
def is_hopper_gpu():
|
| 8 |
+
if not torch.cuda.is_available():
|
| 9 |
+
return False
|
| 10 |
+
device_name = torch.cuda.get_device_name(0).lower()
|
| 11 |
+
return "h100" in device_name or "hopper" in device_name
|
| 12 |
+
FLASH_ATTN_3_AVAILABLE = is_hopper_gpu()
|
| 13 |
+
except ModuleNotFoundError:
|
| 14 |
+
FLASH_ATTN_3_AVAILABLE = False
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
import flash_attn
|
| 18 |
+
FLASH_ATTN_2_AVAILABLE = True
|
| 19 |
+
except ModuleNotFoundError:
|
| 20 |
+
FLASH_ATTN_2_AVAILABLE = False
|
| 21 |
+
|
| 22 |
+
# FLASH_ATTN_3_AVAILABLE = False
|
| 23 |
+
|
| 24 |
+
import warnings
|
| 25 |
+
|
| 26 |
+
__all__ = [
|
| 27 |
+
'flash_attention',
|
| 28 |
+
'attention',
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def flash_attention(
|
| 33 |
+
q,
|
| 34 |
+
k,
|
| 35 |
+
v,
|
| 36 |
+
q_lens=None,
|
| 37 |
+
k_lens=None,
|
| 38 |
+
dropout_p=0.,
|
| 39 |
+
softmax_scale=None,
|
| 40 |
+
q_scale=None,
|
| 41 |
+
causal=False,
|
| 42 |
+
window_size=(-1, -1),
|
| 43 |
+
deterministic=False,
|
| 44 |
+
dtype=torch.bfloat16,
|
| 45 |
+
version=None,
|
| 46 |
+
):
|
| 47 |
+
"""
|
| 48 |
+
q: [B, Lq, Nq, C1].
|
| 49 |
+
k: [B, Lk, Nk, C1].
|
| 50 |
+
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
|
| 51 |
+
q_lens: [B].
|
| 52 |
+
k_lens: [B].
|
| 53 |
+
dropout_p: float. Dropout probability.
|
| 54 |
+
softmax_scale: float. The scaling of QK^T before applying softmax.
|
| 55 |
+
causal: bool. Whether to apply causal attention mask.
|
| 56 |
+
window_size: (left right). If not (-1, -1), apply sliding window local attention.
|
| 57 |
+
deterministic: bool. If True, slightly slower and uses more memory.
|
| 58 |
+
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
|
| 59 |
+
"""
|
| 60 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 61 |
+
assert dtype in half_dtypes
|
| 62 |
+
assert q.device.type == 'cuda' and q.size(-1) <= 256
|
| 63 |
+
|
| 64 |
+
# params
|
| 65 |
+
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
|
| 66 |
+
|
| 67 |
+
def half(x):
|
| 68 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
| 69 |
+
|
| 70 |
+
# preprocess query
|
| 71 |
+
if q_lens is None:
|
| 72 |
+
q = half(q.flatten(0, 1))
|
| 73 |
+
q_lens = torch.tensor(
|
| 74 |
+
[lq] * b, dtype=torch.int32).to(
|
| 75 |
+
device=q.device, non_blocking=True)
|
| 76 |
+
else:
|
| 77 |
+
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
|
| 78 |
+
|
| 79 |
+
# preprocess key, value
|
| 80 |
+
if k_lens is None:
|
| 81 |
+
k = half(k.flatten(0, 1))
|
| 82 |
+
v = half(v.flatten(0, 1))
|
| 83 |
+
k_lens = torch.tensor(
|
| 84 |
+
[lk] * b, dtype=torch.int32).to(
|
| 85 |
+
device=k.device, non_blocking=True)
|
| 86 |
+
else:
|
| 87 |
+
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
|
| 88 |
+
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
|
| 89 |
+
|
| 90 |
+
q = q.to(v.dtype)
|
| 91 |
+
k = k.to(v.dtype)
|
| 92 |
+
|
| 93 |
+
if q_scale is not None:
|
| 94 |
+
q = q * q_scale
|
| 95 |
+
|
| 96 |
+
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
|
| 97 |
+
warnings.warn(
|
| 98 |
+
'Flash attention 3 is not available, use flash attention 2 instead.'
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# apply attention
|
| 102 |
+
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
|
| 103 |
+
# Note: dropout_p, window_size are not supported in FA3 now.
|
| 104 |
+
x = flash_attn_interface.flash_attn_varlen_func(
|
| 105 |
+
q=q,
|
| 106 |
+
k=k,
|
| 107 |
+
v=v,
|
| 108 |
+
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
|
| 109 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
| 110 |
+
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
|
| 111 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
| 112 |
+
max_seqlen_q=lq,
|
| 113 |
+
max_seqlen_k=lk,
|
| 114 |
+
softmax_scale=softmax_scale,
|
| 115 |
+
causal=causal,
|
| 116 |
+
deterministic=deterministic)[0].unflatten(0, (b, lq))
|
| 117 |
+
else:
|
| 118 |
+
assert FLASH_ATTN_2_AVAILABLE
|
| 119 |
+
x = flash_attn.flash_attn_varlen_func(
|
| 120 |
+
q=q,
|
| 121 |
+
k=k,
|
| 122 |
+
v=v,
|
| 123 |
+
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
|
| 124 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
| 125 |
+
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
|
| 126 |
+
0, dtype=torch.int32).to(q.device, non_blocking=True),
|
| 127 |
+
max_seqlen_q=lq,
|
| 128 |
+
max_seqlen_k=lk,
|
| 129 |
+
dropout_p=dropout_p,
|
| 130 |
+
softmax_scale=softmax_scale,
|
| 131 |
+
causal=causal,
|
| 132 |
+
window_size=window_size,
|
| 133 |
+
deterministic=deterministic).unflatten(0, (b, lq))
|
| 134 |
+
|
| 135 |
+
# output
|
| 136 |
+
return x.type(out_dtype)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def attention(
|
| 140 |
+
q,
|
| 141 |
+
k,
|
| 142 |
+
v,
|
| 143 |
+
q_lens=None,
|
| 144 |
+
k_lens=None,
|
| 145 |
+
dropout_p=0.,
|
| 146 |
+
softmax_scale=None,
|
| 147 |
+
q_scale=None,
|
| 148 |
+
causal=False,
|
| 149 |
+
window_size=(-1, -1),
|
| 150 |
+
deterministic=False,
|
| 151 |
+
dtype=torch.bfloat16,
|
| 152 |
+
fa_version=None,
|
| 153 |
+
):
|
| 154 |
+
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
|
| 155 |
+
return flash_attention(
|
| 156 |
+
q=q,
|
| 157 |
+
k=k,
|
| 158 |
+
v=v,
|
| 159 |
+
q_lens=q_lens,
|
| 160 |
+
k_lens=k_lens,
|
| 161 |
+
dropout_p=dropout_p,
|
| 162 |
+
softmax_scale=softmax_scale,
|
| 163 |
+
q_scale=q_scale,
|
| 164 |
+
causal=causal,
|
| 165 |
+
window_size=window_size,
|
| 166 |
+
deterministic=deterministic,
|
| 167 |
+
dtype=dtype,
|
| 168 |
+
version=fa_version,
|
| 169 |
+
)
|
| 170 |
+
else:
|
| 171 |
+
if q_lens is not None or k_lens is not None:
|
| 172 |
+
warnings.warn(
|
| 173 |
+
'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
|
| 174 |
+
)
|
| 175 |
+
attn_mask = None
|
| 176 |
+
|
| 177 |
+
q = q.transpose(1, 2).to(dtype)
|
| 178 |
+
k = k.transpose(1, 2).to(dtype)
|
| 179 |
+
v = v.transpose(1, 2).to(dtype)
|
| 180 |
+
|
| 181 |
+
out = torch.nn.functional.scaled_dot_product_attention(
|
| 182 |
+
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
|
| 183 |
+
|
| 184 |
+
out = out.transpose(1, 2).contiguous()
|
| 185 |
+
return out
|
wan/modules/causal_model.py
ADDED
|
@@ -0,0 +1,1127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from wan.modules.attention import attention
|
| 2 |
+
from wan.modules.model import (
|
| 3 |
+
WanRMSNorm,
|
| 4 |
+
rope_apply,
|
| 5 |
+
WanLayerNorm,
|
| 6 |
+
WAN_CROSSATTENTION_CLASSES,
|
| 7 |
+
rope_params,
|
| 8 |
+
MLPProj,
|
| 9 |
+
sinusoidal_embedding_1d
|
| 10 |
+
)
|
| 11 |
+
# from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
| 12 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 13 |
+
# from torch.nn.attention.flex_attention import BlockMask
|
| 14 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch
|
| 17 |
+
import math
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
|
| 20 |
+
# wan 1.3B model has a weird channel / head configurations and require max-autotune to work with flexattention
|
| 21 |
+
# see https://github.com/pytorch/pytorch/issues/133254
|
| 22 |
+
# change to default for other models
|
| 23 |
+
# flex_attention = torch.compile(
|
| 24 |
+
# flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
|
| 28 |
+
n, c = x.size(2), x.size(3) // 2
|
| 29 |
+
|
| 30 |
+
# split freqs
|
| 31 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
| 32 |
+
|
| 33 |
+
# loop over samples
|
| 34 |
+
output = []
|
| 35 |
+
|
| 36 |
+
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
| 37 |
+
seq_len = f * h * w
|
| 38 |
+
|
| 39 |
+
# precompute multipliers
|
| 40 |
+
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
|
| 41 |
+
seq_len, n, -1, 2))
|
| 42 |
+
freqs_i = torch.cat([
|
| 43 |
+
freqs[0][start_frame:start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 44 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 45 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 46 |
+
],
|
| 47 |
+
dim=-1).reshape(seq_len, 1, -1)
|
| 48 |
+
|
| 49 |
+
# apply rotary embedding
|
| 50 |
+
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
|
| 51 |
+
x_i = torch.cat([x_i, x[i, seq_len:]])
|
| 52 |
+
|
| 53 |
+
# append to collection
|
| 54 |
+
output.append(x_i)
|
| 55 |
+
return torch.stack(output).type_as(x)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class CausalWanSelfAttention(nn.Module):
|
| 59 |
+
|
| 60 |
+
def __init__(self,
|
| 61 |
+
dim,
|
| 62 |
+
num_heads,
|
| 63 |
+
local_attn_size=-1,
|
| 64 |
+
sink_size=1,
|
| 65 |
+
qk_norm=True,
|
| 66 |
+
eps=1e-6):
|
| 67 |
+
assert dim % num_heads == 0
|
| 68 |
+
super().__init__()
|
| 69 |
+
self.dim = dim
|
| 70 |
+
self.num_heads = num_heads
|
| 71 |
+
self.head_dim = dim // num_heads
|
| 72 |
+
self.local_attn_size = local_attn_size
|
| 73 |
+
self.qk_norm = qk_norm
|
| 74 |
+
self.eps = eps
|
| 75 |
+
self.frame_length = 1560
|
| 76 |
+
self.max_attention_size = 21 * self.frame_length
|
| 77 |
+
self.block_length = 3 * self.frame_length
|
| 78 |
+
|
| 79 |
+
# layers
|
| 80 |
+
self.q = nn.Linear(dim, dim)
|
| 81 |
+
self.k = nn.Linear(dim, dim)
|
| 82 |
+
self.v = nn.Linear(dim, dim)
|
| 83 |
+
self.o = nn.Linear(dim, dim)
|
| 84 |
+
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 85 |
+
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 86 |
+
|
| 87 |
+
def forward(
|
| 88 |
+
self,
|
| 89 |
+
x,
|
| 90 |
+
seq_lens,
|
| 91 |
+
grid_sizes,
|
| 92 |
+
freqs,
|
| 93 |
+
block_mask,
|
| 94 |
+
kv_cache=None,
|
| 95 |
+
current_start=0,
|
| 96 |
+
cache_start=None,
|
| 97 |
+
updating_cache=False
|
| 98 |
+
):
|
| 99 |
+
r"""
|
| 100 |
+
Args:
|
| 101 |
+
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
| 102 |
+
seq_lens(Tensor): Shape [B]
|
| 103 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
| 104 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
| 105 |
+
block_mask (BlockMask)
|
| 106 |
+
"""
|
| 107 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 108 |
+
if cache_start is None:
|
| 109 |
+
cache_start = current_start
|
| 110 |
+
|
| 111 |
+
# query, key, value function
|
| 112 |
+
def qkv_fn(x):
|
| 113 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d) # [B, L, 12, 128]
|
| 114 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d) # [B, L, 12, 128]
|
| 115 |
+
v = self.v(x).view(b, s, n, d) # [B, L, 12, 128]
|
| 116 |
+
return q, k, v
|
| 117 |
+
|
| 118 |
+
q, k, v = qkv_fn(x)
|
| 119 |
+
|
| 120 |
+
if kv_cache is None:
|
| 121 |
+
# if it is teacher forcing training?
|
| 122 |
+
is_tf = (s == seq_lens[0].item() * 2)
|
| 123 |
+
if is_tf:
|
| 124 |
+
q_chunk = torch.chunk(q, 2, dim=1)
|
| 125 |
+
k_chunk = torch.chunk(k, 2, dim=1)
|
| 126 |
+
roped_query = []
|
| 127 |
+
roped_key = []
|
| 128 |
+
# rope should be same for clean and noisy parts
|
| 129 |
+
for ii in range(2):
|
| 130 |
+
rq = rope_apply(q_chunk[ii], grid_sizes, freqs).type_as(v)
|
| 131 |
+
rk = rope_apply(k_chunk[ii], grid_sizes, freqs).type_as(v)
|
| 132 |
+
roped_query.append(rq)
|
| 133 |
+
roped_key.append(rk)
|
| 134 |
+
|
| 135 |
+
roped_query = torch.cat(roped_query, dim=1)
|
| 136 |
+
roped_key = torch.cat(roped_key, dim=1)
|
| 137 |
+
|
| 138 |
+
padded_length = math.ceil(q.shape[1] / 128) * 128 - q.shape[1]
|
| 139 |
+
padded_roped_query = torch.cat(
|
| 140 |
+
[roped_query,
|
| 141 |
+
torch.zeros([q.shape[0], padded_length, q.shape[2], q.shape[3]],
|
| 142 |
+
device=q.device, dtype=v.dtype)],
|
| 143 |
+
dim=1
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
padded_roped_key = torch.cat(
|
| 147 |
+
[roped_key, torch.zeros([k.shape[0], padded_length, k.shape[2], k.shape[3]],
|
| 148 |
+
device=k.device, dtype=v.dtype)],
|
| 149 |
+
dim=1
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
padded_v = torch.cat(
|
| 153 |
+
[v, torch.zeros([v.shape[0], padded_length, v.shape[2], v.shape[3]],
|
| 154 |
+
device=v.device, dtype=v.dtype)],
|
| 155 |
+
dim=1
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
x = flex_attention(
|
| 159 |
+
query=padded_roped_query.transpose(2, 1),
|
| 160 |
+
key=padded_roped_key.transpose(2, 1),
|
| 161 |
+
value=padded_v.transpose(2, 1),
|
| 162 |
+
block_mask=block_mask
|
| 163 |
+
)[:, :, :-padded_length].transpose(2, 1)
|
| 164 |
+
|
| 165 |
+
else:
|
| 166 |
+
roped_query = rope_apply(q, grid_sizes, freqs).type_as(v)
|
| 167 |
+
roped_key = rope_apply(k, grid_sizes, freqs).type_as(v)
|
| 168 |
+
|
| 169 |
+
padded_length = math.ceil(q.shape[1] / 128) * 128 - q.shape[1]
|
| 170 |
+
padded_roped_query = torch.cat(
|
| 171 |
+
[roped_query,
|
| 172 |
+
torch.zeros([q.shape[0], padded_length, q.shape[2], q.shape[3]],
|
| 173 |
+
device=q.device, dtype=v.dtype)],
|
| 174 |
+
dim=1
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
padded_roped_key = torch.cat(
|
| 178 |
+
[roped_key, torch.zeros([k.shape[0], padded_length, k.shape[2], k.shape[3]],
|
| 179 |
+
device=k.device, dtype=v.dtype)],
|
| 180 |
+
dim=1
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
padded_v = torch.cat(
|
| 184 |
+
[v, torch.zeros([v.shape[0], padded_length, v.shape[2], v.shape[3]],
|
| 185 |
+
device=v.device, dtype=v.dtype)],
|
| 186 |
+
dim=1
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
x = flex_attention(
|
| 190 |
+
query=padded_roped_query.transpose(2, 1),
|
| 191 |
+
key=padded_roped_key.transpose(2, 1),
|
| 192 |
+
value=padded_v.transpose(2, 1),
|
| 193 |
+
block_mask=block_mask
|
| 194 |
+
)[:, :, :-padded_length].transpose(2, 1)
|
| 195 |
+
else:
|
| 196 |
+
frame_seqlen = math.prod(grid_sizes[0][1:]).item()
|
| 197 |
+
current_start_frame = current_start // frame_seqlen
|
| 198 |
+
roped_query = causal_rope_apply(
|
| 199 |
+
q, grid_sizes, freqs, start_frame=current_start_frame).type_as(v) # [B, L, 12, 128]
|
| 200 |
+
roped_key = causal_rope_apply(
|
| 201 |
+
k, grid_sizes, freqs, start_frame=current_start_frame).type_as(v) # [B, L, 12, 128]
|
| 202 |
+
|
| 203 |
+
grid_sizes_one_block = grid_sizes.clone()
|
| 204 |
+
grid_sizes_one_block[:,0] = 3
|
| 205 |
+
|
| 206 |
+
# only caching the first block
|
| 207 |
+
cache_end = cache_start + self.block_length
|
| 208 |
+
num_new_tokens = cache_end - kv_cache["global_end_index"].item()
|
| 209 |
+
kv_cache_size = kv_cache["k"].shape[1]
|
| 210 |
+
|
| 211 |
+
sink_tokens = 1 * self.block_length # we keep the first block in the cache
|
| 212 |
+
|
| 213 |
+
if (num_new_tokens > 0) and (
|
| 214 |
+
num_new_tokens + kv_cache["local_end_index"].item() > kv_cache_size):
|
| 215 |
+
num_evicted_tokens = num_new_tokens + kv_cache["local_end_index"].item() - kv_cache_size
|
| 216 |
+
num_rolled_tokens = kv_cache["local_end_index"].item() - num_evicted_tokens - sink_tokens
|
| 217 |
+
kv_cache["k"][:, sink_tokens:sink_tokens + num_rolled_tokens] = \
|
| 218 |
+
kv_cache["k"][:, sink_tokens + num_evicted_tokens:sink_tokens + num_evicted_tokens + num_rolled_tokens].clone()
|
| 219 |
+
kv_cache["v"][:, sink_tokens:sink_tokens + num_rolled_tokens] = \
|
| 220 |
+
kv_cache["v"][:, sink_tokens + num_evicted_tokens:sink_tokens + num_evicted_tokens + num_rolled_tokens].clone()
|
| 221 |
+
|
| 222 |
+
local_end_index = kv_cache["local_end_index"].item() + cache_end - \
|
| 223 |
+
kv_cache["global_end_index"].item() - num_evicted_tokens
|
| 224 |
+
local_start_index = local_end_index - self.block_length
|
| 225 |
+
kv_cache["k"][:, local_start_index:local_end_index] = roped_key[:, :self.block_length]
|
| 226 |
+
kv_cache["v"][:, local_start_index:local_end_index] = v[:, :self.block_length]
|
| 227 |
+
else:
|
| 228 |
+
local_end_index = kv_cache["local_end_index"].item() + cache_end - kv_cache["global_end_index"].item()
|
| 229 |
+
local_start_index = local_end_index - self.block_length
|
| 230 |
+
if local_start_index == 0: # first block is not roped in the cache
|
| 231 |
+
kv_cache["k"][:, local_start_index:local_end_index] = k[:, :self.block_length]
|
| 232 |
+
else:
|
| 233 |
+
kv_cache["k"][:, local_start_index:local_end_index] = roped_key[:, :self.block_length]
|
| 234 |
+
|
| 235 |
+
kv_cache["v"][:, local_start_index:local_end_index] = v[:, :self.block_length]
|
| 236 |
+
|
| 237 |
+
if num_new_tokens > 0: # prevent updating when caching clean frame
|
| 238 |
+
kv_cache["global_end_index"].fill_(cache_end)
|
| 239 |
+
kv_cache["local_end_index"].fill_(local_end_index)
|
| 240 |
+
|
| 241 |
+
if local_start_index == 0:
|
| 242 |
+
# no kv attn with cache
|
| 243 |
+
x = attention(
|
| 244 |
+
roped_query,
|
| 245 |
+
roped_key,
|
| 246 |
+
v)
|
| 247 |
+
else:
|
| 248 |
+
if updating_cache: # updating working cache with clean frame
|
| 249 |
+
extract_cache_end = local_end_index
|
| 250 |
+
extract_cache_start = max(0, local_end_index-self.max_attention_size)
|
| 251 |
+
working_cache_key = kv_cache["k"][:, extract_cache_start:extract_cache_end].clone()
|
| 252 |
+
working_cache_v = kv_cache["v"][:, extract_cache_start:extract_cache_end]
|
| 253 |
+
|
| 254 |
+
if extract_cache_start == 0: # rope the global first block in working cache
|
| 255 |
+
working_cache_key[:,:self.block_length] = causal_rope_apply(
|
| 256 |
+
working_cache_key[:,:self.block_length], grid_sizes_one_block, freqs, start_frame=0).type_as(v)
|
| 257 |
+
|
| 258 |
+
x = attention(
|
| 259 |
+
roped_query,
|
| 260 |
+
working_cache_key,
|
| 261 |
+
working_cache_v
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
else:
|
| 265 |
+
# 1. extract working cache
|
| 266 |
+
# calculate the length of working cache
|
| 267 |
+
query_length = roped_query.shape[1]
|
| 268 |
+
working_cache_max_length = self.max_attention_size - query_length - self.block_length
|
| 269 |
+
|
| 270 |
+
extract_cache_end = local_start_index
|
| 271 |
+
extract_cache_start = max(self.block_length, local_start_index - working_cache_max_length) # working cache does not include the first anchor block
|
| 272 |
+
working_cache_key = kv_cache["k"][:, extract_cache_start:extract_cache_end]
|
| 273 |
+
working_cache_v = kv_cache["v"][:, extract_cache_start:extract_cache_end]
|
| 274 |
+
|
| 275 |
+
# 2. extract anchor cache, roped as the past frame
|
| 276 |
+
working_cache_frame_length = working_cache_key.shape[1] // self.frame_length
|
| 277 |
+
rope_start_frame = current_start_frame - working_cache_frame_length - 3
|
| 278 |
+
|
| 279 |
+
anchor_cache_key = causal_rope_apply(
|
| 280 |
+
kv_cache["k"][:, :self.block_length], grid_sizes_one_block, freqs, start_frame=rope_start_frame).type_as(v)
|
| 281 |
+
anchor_cache_v = kv_cache["v"][:, :self.block_length]
|
| 282 |
+
|
| 283 |
+
# 3. attention with working cache and anchor cache
|
| 284 |
+
input_key = torch.cat([
|
| 285 |
+
anchor_cache_key,
|
| 286 |
+
working_cache_key,
|
| 287 |
+
roped_key
|
| 288 |
+
], dim=1)
|
| 289 |
+
|
| 290 |
+
input_v = torch.cat([
|
| 291 |
+
anchor_cache_v,
|
| 292 |
+
working_cache_v,
|
| 293 |
+
v
|
| 294 |
+
], dim=1)
|
| 295 |
+
|
| 296 |
+
x = attention(
|
| 297 |
+
roped_query,
|
| 298 |
+
input_key,
|
| 299 |
+
input_v
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# output
|
| 304 |
+
x = x.flatten(2)
|
| 305 |
+
x = self.o(x)
|
| 306 |
+
return x
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class CausalWanAttentionBlock(nn.Module):
|
| 310 |
+
|
| 311 |
+
def __init__(self,
|
| 312 |
+
cross_attn_type,
|
| 313 |
+
dim,
|
| 314 |
+
ffn_dim,
|
| 315 |
+
num_heads,
|
| 316 |
+
local_attn_size=-1,
|
| 317 |
+
sink_size=0,
|
| 318 |
+
qk_norm=True,
|
| 319 |
+
cross_attn_norm=False,
|
| 320 |
+
eps=1e-6):
|
| 321 |
+
super().__init__()
|
| 322 |
+
self.dim = dim
|
| 323 |
+
self.ffn_dim = ffn_dim
|
| 324 |
+
self.num_heads = num_heads
|
| 325 |
+
self.local_attn_size = local_attn_size
|
| 326 |
+
self.qk_norm = qk_norm
|
| 327 |
+
self.cross_attn_norm = cross_attn_norm
|
| 328 |
+
self.eps = eps
|
| 329 |
+
|
| 330 |
+
# layers
|
| 331 |
+
self.norm1 = WanLayerNorm(dim, eps)
|
| 332 |
+
self.self_attn = CausalWanSelfAttention(dim, num_heads, local_attn_size, sink_size, qk_norm, eps)
|
| 333 |
+
self.norm3 = WanLayerNorm(
|
| 334 |
+
dim, eps,
|
| 335 |
+
elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
| 336 |
+
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
|
| 337 |
+
num_heads,
|
| 338 |
+
(-1, -1),
|
| 339 |
+
qk_norm,
|
| 340 |
+
eps)
|
| 341 |
+
self.norm2 = WanLayerNorm(dim, eps)
|
| 342 |
+
self.ffn = nn.Sequential(
|
| 343 |
+
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
|
| 344 |
+
nn.Linear(ffn_dim, dim))
|
| 345 |
+
|
| 346 |
+
# modulation
|
| 347 |
+
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
| 348 |
+
|
| 349 |
+
def forward(
|
| 350 |
+
self,
|
| 351 |
+
x,
|
| 352 |
+
e,
|
| 353 |
+
seq_lens,
|
| 354 |
+
grid_sizes,
|
| 355 |
+
freqs,
|
| 356 |
+
context,
|
| 357 |
+
context_lens,
|
| 358 |
+
block_mask,
|
| 359 |
+
updating_cache=False,
|
| 360 |
+
kv_cache=None,
|
| 361 |
+
crossattn_cache=None,
|
| 362 |
+
current_start=0,
|
| 363 |
+
cache_start=None
|
| 364 |
+
):
|
| 365 |
+
r"""
|
| 366 |
+
Args:
|
| 367 |
+
x(Tensor): Shape [B, L, C]
|
| 368 |
+
e(Tensor): Shape [B, F, 6, C]
|
| 369 |
+
seq_lens(Tensor): Shape [B], length of each sequence in batch
|
| 370 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
| 371 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
| 372 |
+
"""
|
| 373 |
+
num_frames, frame_seqlen = e.shape[1], x.shape[1] // e.shape[1]
|
| 374 |
+
# assert e.dtype == torch.float32
|
| 375 |
+
# with amp.autocast(dtype=torch.float32):
|
| 376 |
+
e = (self.modulation.unsqueeze(1) + e).chunk(6, dim=2)
|
| 377 |
+
# assert e[0].dtype == torch.float32
|
| 378 |
+
|
| 379 |
+
# self-attention
|
| 380 |
+
y = self.self_attn(
|
| 381 |
+
(self.norm1(x).unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * (1 + e[1]) + e[0]).flatten(1, 2),
|
| 382 |
+
seq_lens, grid_sizes,
|
| 383 |
+
freqs, block_mask, kv_cache, current_start, cache_start, updating_cache=updating_cache)
|
| 384 |
+
|
| 385 |
+
# with amp.autocast(dtype=torch.float32):
|
| 386 |
+
x = x + (y.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * e[2]).flatten(1, 2)
|
| 387 |
+
|
| 388 |
+
# cross-attention & ffn function
|
| 389 |
+
def cross_attn_ffn(x, context, context_lens, e, crossattn_cache=None):
|
| 390 |
+
x = x + self.cross_attn(self.norm3(x), context,
|
| 391 |
+
context_lens, crossattn_cache=crossattn_cache)
|
| 392 |
+
y = self.ffn(
|
| 393 |
+
(self.norm2(x).unflatten(dim=1, sizes=(num_frames,
|
| 394 |
+
frame_seqlen)) * (1 + e[4]) + e[3]).flatten(1, 2)
|
| 395 |
+
)
|
| 396 |
+
# with amp.autocast(dtype=torch.float32):
|
| 397 |
+
x = x + (y.unflatten(dim=1, sizes=(num_frames,
|
| 398 |
+
frame_seqlen)) * e[5]).flatten(1, 2)
|
| 399 |
+
return x
|
| 400 |
+
|
| 401 |
+
x = cross_attn_ffn(x, context, context_lens, e, crossattn_cache)
|
| 402 |
+
return x
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class CausalHead(nn.Module):
|
| 406 |
+
|
| 407 |
+
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
| 408 |
+
super().__init__()
|
| 409 |
+
self.dim = dim
|
| 410 |
+
self.out_dim = out_dim
|
| 411 |
+
self.patch_size = patch_size
|
| 412 |
+
self.eps = eps
|
| 413 |
+
|
| 414 |
+
# layers
|
| 415 |
+
out_dim = math.prod(patch_size) * out_dim
|
| 416 |
+
self.norm = WanLayerNorm(dim, eps)
|
| 417 |
+
self.head = nn.Linear(dim, out_dim)
|
| 418 |
+
|
| 419 |
+
# modulation
|
| 420 |
+
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
| 421 |
+
|
| 422 |
+
def forward(self, x, e):
|
| 423 |
+
r"""
|
| 424 |
+
Args:
|
| 425 |
+
x(Tensor): Shape [B, L1, C]
|
| 426 |
+
e(Tensor): Shape [B, F, 1, C]
|
| 427 |
+
"""
|
| 428 |
+
# assert e.dtype == torch.float32
|
| 429 |
+
# with amp.autocast(dtype=torch.float32):
|
| 430 |
+
num_frames, frame_seqlen = e.shape[1], x.shape[1] // e.shape[1]
|
| 431 |
+
e = (self.modulation.unsqueeze(1) + e).chunk(2, dim=2)
|
| 432 |
+
x = (self.head(self.norm(x).unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * (1 + e[1]) + e[0]))
|
| 433 |
+
return x
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class CausalWanModel(ModelMixin, ConfigMixin):
|
| 437 |
+
r"""
|
| 438 |
+
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
| 439 |
+
"""
|
| 440 |
+
|
| 441 |
+
ignore_for_config = [
|
| 442 |
+
'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim'
|
| 443 |
+
]
|
| 444 |
+
_no_split_modules = ['WanAttentionBlock']
|
| 445 |
+
_supports_gradient_checkpointing = True
|
| 446 |
+
|
| 447 |
+
@register_to_config
|
| 448 |
+
def __init__(self,
|
| 449 |
+
model_type='t2v',
|
| 450 |
+
patch_size=(1, 2, 2),
|
| 451 |
+
text_len=512,
|
| 452 |
+
in_dim=16,
|
| 453 |
+
dim=2048,
|
| 454 |
+
ffn_dim=8192,
|
| 455 |
+
freq_dim=256,
|
| 456 |
+
text_dim=4096,
|
| 457 |
+
out_dim=16,
|
| 458 |
+
num_heads=16,
|
| 459 |
+
num_layers=32,
|
| 460 |
+
local_attn_size=-1,
|
| 461 |
+
sink_size=0,
|
| 462 |
+
qk_norm=True,
|
| 463 |
+
cross_attn_norm=True,
|
| 464 |
+
eps=1e-6):
|
| 465 |
+
r"""
|
| 466 |
+
Initialize the diffusion model backbone.
|
| 467 |
+
|
| 468 |
+
Args:
|
| 469 |
+
model_type (`str`, *optional*, defaults to 't2v'):
|
| 470 |
+
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
|
| 471 |
+
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
| 472 |
+
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
| 473 |
+
text_len (`int`, *optional*, defaults to 512):
|
| 474 |
+
Fixed length for text embeddings
|
| 475 |
+
in_dim (`int`, *optional*, defaults to 16):
|
| 476 |
+
Input video channels (C_in)
|
| 477 |
+
dim (`int`, *optional*, defaults to 2048):
|
| 478 |
+
Hidden dimension of the transformer
|
| 479 |
+
ffn_dim (`int`, *optional*, defaults to 8192):
|
| 480 |
+
Intermediate dimension in feed-forward network
|
| 481 |
+
freq_dim (`int`, *optional*, defaults to 256):
|
| 482 |
+
Dimension for sinusoidal time embeddings
|
| 483 |
+
text_dim (`int`, *optional*, defaults to 4096):
|
| 484 |
+
Input dimension for text embeddings
|
| 485 |
+
out_dim (`int`, *optional*, defaults to 16):
|
| 486 |
+
Output video channels (C_out)
|
| 487 |
+
num_heads (`int`, *optional*, defaults to 16):
|
| 488 |
+
Number of attention heads
|
| 489 |
+
num_layers (`int`, *optional*, defaults to 32):
|
| 490 |
+
Number of transformer blocks
|
| 491 |
+
local_attn_size (`int`, *optional*, defaults to -1):
|
| 492 |
+
Window size for temporal local attention (-1 indicates global attention)
|
| 493 |
+
sink_size (`int`, *optional*, defaults to 0):
|
| 494 |
+
Size of the attention sink, we keep the first `sink_size` frames unchanged when rolling the KV cache
|
| 495 |
+
qk_norm (`bool`, *optional*, defaults to True):
|
| 496 |
+
Enable query/key normalization
|
| 497 |
+
cross_attn_norm (`bool`, *optional*, defaults to False):
|
| 498 |
+
Enable cross-attention normalization
|
| 499 |
+
eps (`float`, *optional*, defaults to 1e-6):
|
| 500 |
+
Epsilon value for normalization layers
|
| 501 |
+
"""
|
| 502 |
+
|
| 503 |
+
super().__init__()
|
| 504 |
+
|
| 505 |
+
assert model_type in ['t2v', 'i2v']
|
| 506 |
+
self.model_type = model_type
|
| 507 |
+
|
| 508 |
+
self.patch_size = patch_size
|
| 509 |
+
self.text_len = text_len
|
| 510 |
+
self.in_dim = in_dim
|
| 511 |
+
self.dim = dim
|
| 512 |
+
self.ffn_dim = ffn_dim
|
| 513 |
+
self.freq_dim = freq_dim
|
| 514 |
+
self.text_dim = text_dim
|
| 515 |
+
self.out_dim = out_dim
|
| 516 |
+
self.num_heads = num_heads
|
| 517 |
+
self.num_layers = num_layers
|
| 518 |
+
self.local_attn_size = local_attn_size
|
| 519 |
+
self.qk_norm = qk_norm
|
| 520 |
+
self.cross_attn_norm = cross_attn_norm
|
| 521 |
+
self.eps = eps
|
| 522 |
+
|
| 523 |
+
# embeddings
|
| 524 |
+
self.patch_embedding = nn.Conv3d(
|
| 525 |
+
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
| 526 |
+
self.text_embedding = nn.Sequential(
|
| 527 |
+
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
|
| 528 |
+
nn.Linear(dim, dim))
|
| 529 |
+
|
| 530 |
+
self.time_embedding = nn.Sequential(
|
| 531 |
+
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
| 532 |
+
self.time_projection = nn.Sequential(
|
| 533 |
+
nn.SiLU(), nn.Linear(dim, dim * 6))
|
| 534 |
+
|
| 535 |
+
# blocks
|
| 536 |
+
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
| 537 |
+
self.blocks = nn.ModuleList([
|
| 538 |
+
CausalWanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
|
| 539 |
+
local_attn_size, sink_size, qk_norm, cross_attn_norm, eps)
|
| 540 |
+
for _ in range(num_layers)
|
| 541 |
+
])
|
| 542 |
+
|
| 543 |
+
# head
|
| 544 |
+
self.head = CausalHead(dim, out_dim, patch_size, eps)
|
| 545 |
+
|
| 546 |
+
# buffers (don't use register_buffer otherwise dtype will be changed in to())
|
| 547 |
+
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
| 548 |
+
d = dim // num_heads
|
| 549 |
+
self.freqs = torch.cat([
|
| 550 |
+
rope_params(1024, d - 4 * (d // 6)),
|
| 551 |
+
rope_params(1024, 2 * (d // 6)),
|
| 552 |
+
rope_params(1024, 2 * (d // 6))
|
| 553 |
+
],
|
| 554 |
+
dim=1)
|
| 555 |
+
|
| 556 |
+
if model_type == 'i2v':
|
| 557 |
+
self.img_emb = MLPProj(1280, dim)
|
| 558 |
+
|
| 559 |
+
# initialize weights
|
| 560 |
+
self.init_weights()
|
| 561 |
+
|
| 562 |
+
self.gradient_checkpointing = False
|
| 563 |
+
|
| 564 |
+
self.block_mask = None
|
| 565 |
+
|
| 566 |
+
self.num_frame_per_block = 1
|
| 567 |
+
self.independent_first_frame = False
|
| 568 |
+
|
| 569 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 570 |
+
self.gradient_checkpointing = value
|
| 571 |
+
|
| 572 |
+
@staticmethod
|
| 573 |
+
def _prepare_blockwise_causal_attn_mask(
|
| 574 |
+
device: torch.device | str, num_frames: int = 21,
|
| 575 |
+
frame_seqlen: int = 1560, num_frame_per_block=1, local_attn_size=-1
|
| 576 |
+
):
|
| 577 |
+
"""
|
| 578 |
+
we will divide the token sequence into the following format
|
| 579 |
+
[1 latent frame] [1 latent frame] ... [1 latent frame]
|
| 580 |
+
We use flexattention to construct the attention mask
|
| 581 |
+
"""
|
| 582 |
+
total_length = num_frames * frame_seqlen
|
| 583 |
+
|
| 584 |
+
# we do right padding to get to a multiple of 128
|
| 585 |
+
padded_length = math.ceil(total_length / 128) * 128 - total_length
|
| 586 |
+
|
| 587 |
+
ends = torch.zeros(total_length + padded_length,
|
| 588 |
+
device=device, dtype=torch.long)
|
| 589 |
+
|
| 590 |
+
# Block-wise causal mask will attend to all elements that are before the end of the current chunk
|
| 591 |
+
frame_indices = torch.arange(
|
| 592 |
+
start=0,
|
| 593 |
+
end=total_length,
|
| 594 |
+
step=frame_seqlen * num_frame_per_block,
|
| 595 |
+
device=device
|
| 596 |
+
)
|
| 597 |
+
|
| 598 |
+
for tmp in frame_indices:
|
| 599 |
+
ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + \
|
| 600 |
+
frame_seqlen * num_frame_per_block
|
| 601 |
+
|
| 602 |
+
def attention_mask(b, h, q_idx, kv_idx):
|
| 603 |
+
if local_attn_size == -1:
|
| 604 |
+
return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
|
| 605 |
+
else:
|
| 606 |
+
return ((kv_idx < ends[q_idx]) & (kv_idx >= (ends[q_idx] - local_attn_size * frame_seqlen))) | (q_idx == kv_idx)
|
| 607 |
+
# return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
|
| 608 |
+
|
| 609 |
+
block_mask = create_block_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
|
| 610 |
+
KV_LEN=total_length + padded_length, _compile=False, device=device)
|
| 611 |
+
|
| 612 |
+
import torch.distributed as dist
|
| 613 |
+
if not dist.is_initialized() or dist.get_rank() == 0:
|
| 614 |
+
print(
|
| 615 |
+
f" cache a block wise causal mask with block size of {num_frame_per_block} frames")
|
| 616 |
+
print(block_mask)
|
| 617 |
+
|
| 618 |
+
# import imageio
|
| 619 |
+
# import numpy as np
|
| 620 |
+
# from torch.nn.attention.flex_attention import create_mask
|
| 621 |
+
|
| 622 |
+
# mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length +
|
| 623 |
+
# padded_length, KV_LEN=total_length + padded_length, device=device)
|
| 624 |
+
# import cv2
|
| 625 |
+
# mask = cv2.resize(mask[0, 0].cpu().float().numpy(), (1024, 1024))
|
| 626 |
+
# imageio.imwrite("mask_%d.jpg" % (0), np.uint8(255. * mask))
|
| 627 |
+
|
| 628 |
+
return block_mask
|
| 629 |
+
|
| 630 |
+
@staticmethod
|
| 631 |
+
def _prepare_teacher_forcing_mask(
|
| 632 |
+
device: torch.device | str, num_frames: int = 21,
|
| 633 |
+
frame_seqlen: int = 1560, num_frame_per_block=1
|
| 634 |
+
):
|
| 635 |
+
"""
|
| 636 |
+
we will divide the token sequence into the following format
|
| 637 |
+
[1 latent frame] [1 latent frame] ... [1 latent frame]
|
| 638 |
+
We use flexattention to construct the attention mask
|
| 639 |
+
"""
|
| 640 |
+
# debug
|
| 641 |
+
DEBUG = False
|
| 642 |
+
if DEBUG:
|
| 643 |
+
num_frames = 9
|
| 644 |
+
frame_seqlen = 256
|
| 645 |
+
|
| 646 |
+
total_length = num_frames * frame_seqlen * 2
|
| 647 |
+
|
| 648 |
+
# we do right padding to get to a multiple of 128
|
| 649 |
+
padded_length = math.ceil(total_length / 128) * 128 - total_length
|
| 650 |
+
|
| 651 |
+
clean_ends = num_frames * frame_seqlen
|
| 652 |
+
# for clean context frames, we can construct their flex attention mask based on a [start, end] interval
|
| 653 |
+
context_ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
|
| 654 |
+
# for noisy frames, we need two intervals to construct the flex attention mask [context_start, context_end] [noisy_start, noisy_end]
|
| 655 |
+
noise_context_starts = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
|
| 656 |
+
noise_context_ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
|
| 657 |
+
noise_noise_starts = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
|
| 658 |
+
noise_noise_ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
|
| 659 |
+
|
| 660 |
+
# Block-wise causal mask will attend to all elements that are before the end of the current chunk
|
| 661 |
+
attention_block_size = frame_seqlen * num_frame_per_block
|
| 662 |
+
frame_indices = torch.arange(
|
| 663 |
+
start=0,
|
| 664 |
+
end=num_frames * frame_seqlen,
|
| 665 |
+
step=attention_block_size,
|
| 666 |
+
device=device, dtype=torch.long
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
+
# attention for clean context frames
|
| 670 |
+
for start in frame_indices:
|
| 671 |
+
context_ends[start:start + attention_block_size] = start + attention_block_size
|
| 672 |
+
|
| 673 |
+
noisy_image_start_list = torch.arange(
|
| 674 |
+
num_frames * frame_seqlen, total_length,
|
| 675 |
+
step=attention_block_size,
|
| 676 |
+
device=device, dtype=torch.long
|
| 677 |
+
)
|
| 678 |
+
noisy_image_end_list = noisy_image_start_list + attention_block_size
|
| 679 |
+
|
| 680 |
+
# attention for noisy frames
|
| 681 |
+
for block_index, (start, end) in enumerate(zip(noisy_image_start_list, noisy_image_end_list)):
|
| 682 |
+
# attend to noisy tokens within the same block
|
| 683 |
+
noise_noise_starts[start:end] = start
|
| 684 |
+
noise_noise_ends[start:end] = end
|
| 685 |
+
# attend to context tokens in previous blocks
|
| 686 |
+
# noise_context_starts[start:end] = 0
|
| 687 |
+
noise_context_ends[start:end] = block_index * attention_block_size
|
| 688 |
+
|
| 689 |
+
def attention_mask(b, h, q_idx, kv_idx):
|
| 690 |
+
# first design the mask for clean frames
|
| 691 |
+
clean_mask = (q_idx < clean_ends) & (kv_idx < context_ends[q_idx])
|
| 692 |
+
# then design the mask for noisy frames
|
| 693 |
+
# noisy frames will attend to all clean preceeding clean frames + itself
|
| 694 |
+
C1 = (kv_idx < noise_noise_ends[q_idx]) & (kv_idx >= noise_noise_starts[q_idx])
|
| 695 |
+
C2 = (kv_idx < noise_context_ends[q_idx]) & (kv_idx >= noise_context_starts[q_idx])
|
| 696 |
+
noise_mask = (q_idx >= clean_ends) & (C1 | C2)
|
| 697 |
+
|
| 698 |
+
eye_mask = q_idx == kv_idx
|
| 699 |
+
return eye_mask | clean_mask | noise_mask
|
| 700 |
+
|
| 701 |
+
block_mask = create_block_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
|
| 702 |
+
KV_LEN=total_length + padded_length, _compile=False, device=device)
|
| 703 |
+
|
| 704 |
+
if DEBUG:
|
| 705 |
+
print(block_mask)
|
| 706 |
+
import imageio
|
| 707 |
+
import numpy as np
|
| 708 |
+
from torch.nn.attention.flex_attention import create_mask
|
| 709 |
+
|
| 710 |
+
mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length +
|
| 711 |
+
padded_length, KV_LEN=total_length + padded_length, device=device)
|
| 712 |
+
import cv2
|
| 713 |
+
mask = cv2.resize(mask[0, 0].cpu().float().numpy(), (1024, 1024))
|
| 714 |
+
imageio.imwrite("mask_%d.jpg" % (0), np.uint8(255. * mask))
|
| 715 |
+
|
| 716 |
+
return block_mask
|
| 717 |
+
|
| 718 |
+
@staticmethod
|
| 719 |
+
def _prepare_blockwise_causal_attn_mask_i2v(
|
| 720 |
+
device: torch.device | str, num_frames: int = 21,
|
| 721 |
+
frame_seqlen: int = 1560, num_frame_per_block=4, local_attn_size=-1
|
| 722 |
+
):
|
| 723 |
+
"""
|
| 724 |
+
we will divide the token sequence into the following format
|
| 725 |
+
[1 latent frame] [N latent frame] ... [N latent frame]
|
| 726 |
+
The first frame is separated out to support I2V generation
|
| 727 |
+
We use flexattention to construct the attention mask
|
| 728 |
+
"""
|
| 729 |
+
total_length = num_frames * frame_seqlen
|
| 730 |
+
|
| 731 |
+
# we do right padding to get to a multiple of 128
|
| 732 |
+
padded_length = math.ceil(total_length / 128) * 128 - total_length
|
| 733 |
+
|
| 734 |
+
ends = torch.zeros(total_length + padded_length,
|
| 735 |
+
device=device, dtype=torch.long)
|
| 736 |
+
|
| 737 |
+
# special handling for the first frame
|
| 738 |
+
ends[:frame_seqlen] = frame_seqlen
|
| 739 |
+
|
| 740 |
+
# Block-wise causal mask will attend to all elements that are before the end of the current chunk
|
| 741 |
+
frame_indices = torch.arange(
|
| 742 |
+
start=frame_seqlen,
|
| 743 |
+
end=total_length,
|
| 744 |
+
step=frame_seqlen * num_frame_per_block,
|
| 745 |
+
device=device
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
for idx, tmp in enumerate(frame_indices):
|
| 749 |
+
ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + \
|
| 750 |
+
frame_seqlen * num_frame_per_block
|
| 751 |
+
|
| 752 |
+
def attention_mask(b, h, q_idx, kv_idx):
|
| 753 |
+
if local_attn_size == -1:
|
| 754 |
+
return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
|
| 755 |
+
else:
|
| 756 |
+
return ((kv_idx < ends[q_idx]) & (kv_idx >= (ends[q_idx] - local_attn_size * frame_seqlen))) | \
|
| 757 |
+
(q_idx == kv_idx)
|
| 758 |
+
|
| 759 |
+
block_mask = create_block_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
|
| 760 |
+
KV_LEN=total_length + padded_length, _compile=False, device=device)
|
| 761 |
+
|
| 762 |
+
if not dist.is_initialized() or dist.get_rank() == 0:
|
| 763 |
+
print(
|
| 764 |
+
f" cache a block wise causal mask with block size of {num_frame_per_block} frames")
|
| 765 |
+
print(block_mask)
|
| 766 |
+
|
| 767 |
+
# import imageio
|
| 768 |
+
# import numpy as np
|
| 769 |
+
# from torch.nn.attention.flex_attention import create_mask
|
| 770 |
+
|
| 771 |
+
# mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length +
|
| 772 |
+
# padded_length, KV_LEN=total_length + padded_length, device=device)
|
| 773 |
+
# import cv2
|
| 774 |
+
# mask = cv2.resize(mask[0, 0].cpu().float().numpy(), (1024, 1024))
|
| 775 |
+
# imageio.imwrite("mask_%d.jpg" % (0), np.uint8(255. * mask))
|
| 776 |
+
|
| 777 |
+
return block_mask
|
| 778 |
+
|
| 779 |
+
def _forward_inference(
|
| 780 |
+
self,
|
| 781 |
+
x,
|
| 782 |
+
t,
|
| 783 |
+
context,
|
| 784 |
+
seq_len,
|
| 785 |
+
updating_cache=False,
|
| 786 |
+
clip_fea=None,
|
| 787 |
+
y=None,
|
| 788 |
+
kv_cache: dict = None,
|
| 789 |
+
crossattn_cache: dict = None,
|
| 790 |
+
current_start: int = 0,
|
| 791 |
+
cache_start: int = 0,
|
| 792 |
+
):
|
| 793 |
+
r"""
|
| 794 |
+
Run the diffusion model with kv caching.
|
| 795 |
+
See Algorithm 2 of CausVid paper https://arxiv.org/abs/2412.07772 for details.
|
| 796 |
+
This function will be run for num_frame times.
|
| 797 |
+
Process the latent frames one by one (1560 tokens each)
|
| 798 |
+
|
| 799 |
+
Args:
|
| 800 |
+
x (List[Tensor]):
|
| 801 |
+
List of input video tensors, each with shape [C_in, F, H, W]
|
| 802 |
+
t (Tensor):
|
| 803 |
+
Diffusion timesteps tensor of shape [B]
|
| 804 |
+
context (List[Tensor]):
|
| 805 |
+
List of text embeddings each with shape [L, C]
|
| 806 |
+
seq_len (`int`):
|
| 807 |
+
Maximum sequence length for positional encoding
|
| 808 |
+
clip_fea (Tensor, *optional*):
|
| 809 |
+
CLIP image features for image-to-video mode
|
| 810 |
+
y (List[Tensor], *optional*):
|
| 811 |
+
Conditional video inputs for image-to-video mode, same shape as x
|
| 812 |
+
|
| 813 |
+
Returns:
|
| 814 |
+
List[Tensor]:
|
| 815 |
+
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
| 816 |
+
"""
|
| 817 |
+
|
| 818 |
+
if self.model_type == 'i2v':
|
| 819 |
+
assert clip_fea is not None and y is not None
|
| 820 |
+
# params
|
| 821 |
+
device = self.patch_embedding.weight.device
|
| 822 |
+
if self.freqs.device != device:
|
| 823 |
+
self.freqs = self.freqs.to(device)
|
| 824 |
+
|
| 825 |
+
if y is not None:
|
| 826 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
| 827 |
+
|
| 828 |
+
# embeddings
|
| 829 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 830 |
+
grid_sizes = torch.stack(
|
| 831 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 832 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 833 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 834 |
+
assert seq_lens.max() <= seq_len
|
| 835 |
+
x = torch.cat(x)
|
| 836 |
+
"""
|
| 837 |
+
torch.cat([
|
| 838 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
| 839 |
+
dim=1) for u in x
|
| 840 |
+
])
|
| 841 |
+
"""
|
| 842 |
+
|
| 843 |
+
# time embeddings
|
| 844 |
+
# with amp.autocast(dtype=torch.float32):
|
| 845 |
+
e = self.time_embedding(
|
| 846 |
+
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x))
|
| 847 |
+
e0 = self.time_projection(e).unflatten(
|
| 848 |
+
1, (6, self.dim)).unflatten(dim=0, sizes=t.shape)
|
| 849 |
+
# assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 850 |
+
|
| 851 |
+
# context
|
| 852 |
+
context_lens = None
|
| 853 |
+
context = self.text_embedding(
|
| 854 |
+
torch.stack([
|
| 855 |
+
torch.cat(
|
| 856 |
+
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 857 |
+
for u in context
|
| 858 |
+
]))
|
| 859 |
+
|
| 860 |
+
if clip_fea is not None:
|
| 861 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
| 862 |
+
context = torch.concat([context_clip, context], dim=1)
|
| 863 |
+
|
| 864 |
+
# arguments
|
| 865 |
+
kwargs = dict(
|
| 866 |
+
e=e0,
|
| 867 |
+
seq_lens=seq_lens,
|
| 868 |
+
grid_sizes=grid_sizes,
|
| 869 |
+
freqs=self.freqs,
|
| 870 |
+
context=context,
|
| 871 |
+
context_lens=context_lens,
|
| 872 |
+
block_mask=self.block_mask,
|
| 873 |
+
updating_cache=updating_cache,
|
| 874 |
+
)
|
| 875 |
+
|
| 876 |
+
def create_custom_forward(module):
|
| 877 |
+
def custom_forward(*inputs, **kwargs):
|
| 878 |
+
return module(*inputs, **kwargs)
|
| 879 |
+
return custom_forward
|
| 880 |
+
|
| 881 |
+
for block_index, block in enumerate(self.blocks):
|
| 882 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 883 |
+
kwargs.update(
|
| 884 |
+
{
|
| 885 |
+
"kv_cache": kv_cache[block_index],
|
| 886 |
+
"current_start": current_start,
|
| 887 |
+
"cache_start": cache_start
|
| 888 |
+
}
|
| 889 |
+
)
|
| 890 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 891 |
+
create_custom_forward(block),
|
| 892 |
+
x, **kwargs,
|
| 893 |
+
use_reentrant=False,
|
| 894 |
+
)
|
| 895 |
+
else:
|
| 896 |
+
kwargs.update(
|
| 897 |
+
{
|
| 898 |
+
"kv_cache": kv_cache[block_index],
|
| 899 |
+
"crossattn_cache": crossattn_cache[block_index],
|
| 900 |
+
"current_start": current_start,
|
| 901 |
+
"cache_start": cache_start
|
| 902 |
+
}
|
| 903 |
+
)
|
| 904 |
+
x = block(x, **kwargs)
|
| 905 |
+
|
| 906 |
+
# head
|
| 907 |
+
x = self.head(x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2))
|
| 908 |
+
# unpatchify
|
| 909 |
+
x = self.unpatchify(x, grid_sizes)
|
| 910 |
+
return torch.stack(x)
|
| 911 |
+
|
| 912 |
+
def _forward_train(
|
| 913 |
+
self,
|
| 914 |
+
x,
|
| 915 |
+
t,
|
| 916 |
+
context,
|
| 917 |
+
seq_len,
|
| 918 |
+
clean_x=None,
|
| 919 |
+
aug_t=None,
|
| 920 |
+
clip_fea=None,
|
| 921 |
+
y=None,
|
| 922 |
+
):
|
| 923 |
+
r"""
|
| 924 |
+
Forward pass through the diffusion model
|
| 925 |
+
|
| 926 |
+
Args:
|
| 927 |
+
x (List[Tensor]):
|
| 928 |
+
List of input video tensors, each with shape [C_in, F, H, W]
|
| 929 |
+
t (Tensor):
|
| 930 |
+
Diffusion timesteps tensor of shape [B]
|
| 931 |
+
context (List[Tensor]):
|
| 932 |
+
List of text embeddings each with shape [L, C]
|
| 933 |
+
seq_len (`int`):
|
| 934 |
+
Maximum sequence length for positional encoding
|
| 935 |
+
clip_fea (Tensor, *optional*):
|
| 936 |
+
CLIP image features for image-to-video mode
|
| 937 |
+
y (List[Tensor], *optional*):
|
| 938 |
+
Conditional video inputs for image-to-video mode, same shape as x
|
| 939 |
+
|
| 940 |
+
Returns:
|
| 941 |
+
List[Tensor]:
|
| 942 |
+
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
| 943 |
+
"""
|
| 944 |
+
if self.model_type == 'i2v':
|
| 945 |
+
assert clip_fea is not None and y is not None
|
| 946 |
+
# params
|
| 947 |
+
device = self.patch_embedding.weight.device
|
| 948 |
+
if self.freqs.device != device:
|
| 949 |
+
self.freqs = self.freqs.to(device)
|
| 950 |
+
|
| 951 |
+
# Construct blockwise causal attn mask
|
| 952 |
+
if self.block_mask is None:
|
| 953 |
+
if clean_x is not None:
|
| 954 |
+
if self.independent_first_frame:
|
| 955 |
+
raise NotImplementedError()
|
| 956 |
+
else:
|
| 957 |
+
self.block_mask = self._prepare_teacher_forcing_mask(
|
| 958 |
+
device, num_frames=x.shape[2],
|
| 959 |
+
frame_seqlen=x.shape[-2] * x.shape[-1] // (self.patch_size[1] * self.patch_size[2]),
|
| 960 |
+
num_frame_per_block=self.num_frame_per_block
|
| 961 |
+
)
|
| 962 |
+
else:
|
| 963 |
+
if self.independent_first_frame:
|
| 964 |
+
self.block_mask = self._prepare_blockwise_causal_attn_mask_i2v(
|
| 965 |
+
device, num_frames=x.shape[2],
|
| 966 |
+
frame_seqlen=x.shape[-2] * x.shape[-1] // (self.patch_size[1] * self.patch_size[2]),
|
| 967 |
+
num_frame_per_block=self.num_frame_per_block,
|
| 968 |
+
local_attn_size=self.local_attn_size
|
| 969 |
+
)
|
| 970 |
+
else:
|
| 971 |
+
self.block_mask = self._prepare_blockwise_causal_attn_mask(
|
| 972 |
+
device, num_frames=x.shape[2],
|
| 973 |
+
frame_seqlen=x.shape[-2] * x.shape[-1] // (self.patch_size[1] * self.patch_size[2]),
|
| 974 |
+
num_frame_per_block=self.num_frame_per_block,
|
| 975 |
+
local_attn_size=self.local_attn_size
|
| 976 |
+
)
|
| 977 |
+
|
| 978 |
+
if y is not None:
|
| 979 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
| 980 |
+
|
| 981 |
+
# embeddings
|
| 982 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 983 |
+
|
| 984 |
+
grid_sizes = torch.stack(
|
| 985 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 986 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 987 |
+
|
| 988 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 989 |
+
assert seq_lens.max() <= seq_len
|
| 990 |
+
x = torch.cat([
|
| 991 |
+
torch.cat([u, u.new_zeros(1, seq_lens[0] - u.size(1), u.size(2))],
|
| 992 |
+
dim=1) for u in x
|
| 993 |
+
])
|
| 994 |
+
|
| 995 |
+
# time embeddings
|
| 996 |
+
# with amp.autocast(dtype=torch.float32):
|
| 997 |
+
e = self.time_embedding(
|
| 998 |
+
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x))
|
| 999 |
+
e0 = self.time_projection(e).unflatten(
|
| 1000 |
+
1, (6, self.dim)).unflatten(dim=0, sizes=t.shape)
|
| 1001 |
+
# assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 1002 |
+
|
| 1003 |
+
# context
|
| 1004 |
+
context_lens = None
|
| 1005 |
+
context = self.text_embedding(
|
| 1006 |
+
torch.stack([
|
| 1007 |
+
torch.cat(
|
| 1008 |
+
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 1009 |
+
for u in context
|
| 1010 |
+
]))
|
| 1011 |
+
|
| 1012 |
+
if clip_fea is not None:
|
| 1013 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
| 1014 |
+
context = torch.concat([context_clip, context], dim=1)
|
| 1015 |
+
|
| 1016 |
+
if clean_x is not None:
|
| 1017 |
+
clean_x = [self.patch_embedding(u.unsqueeze(0)) for u in clean_x]
|
| 1018 |
+
clean_x = [u.flatten(2).transpose(1, 2) for u in clean_x]
|
| 1019 |
+
|
| 1020 |
+
seq_lens_clean = torch.tensor([u.size(1) for u in clean_x], dtype=torch.long)
|
| 1021 |
+
assert seq_lens_clean.max() <= seq_len
|
| 1022 |
+
clean_x = torch.cat([
|
| 1023 |
+
torch.cat([u, u.new_zeros(1, seq_lens_clean[0] - u.size(1), u.size(2))], dim=1) for u in clean_x
|
| 1024 |
+
])
|
| 1025 |
+
|
| 1026 |
+
x = torch.cat([clean_x, x], dim=1)
|
| 1027 |
+
if aug_t is None:
|
| 1028 |
+
aug_t = torch.zeros_like(t)
|
| 1029 |
+
e_clean = self.time_embedding(
|
| 1030 |
+
sinusoidal_embedding_1d(self.freq_dim, aug_t.flatten()).type_as(x))
|
| 1031 |
+
e0_clean = self.time_projection(e_clean).unflatten(
|
| 1032 |
+
1, (6, self.dim)).unflatten(dim=0, sizes=t.shape)
|
| 1033 |
+
e0 = torch.cat([e0_clean, e0], dim=1)
|
| 1034 |
+
|
| 1035 |
+
# arguments
|
| 1036 |
+
kwargs = dict(
|
| 1037 |
+
e=e0,
|
| 1038 |
+
seq_lens=seq_lens,
|
| 1039 |
+
grid_sizes=grid_sizes,
|
| 1040 |
+
freqs=self.freqs,
|
| 1041 |
+
context=context,
|
| 1042 |
+
context_lens=context_lens,
|
| 1043 |
+
block_mask=self.block_mask)
|
| 1044 |
+
|
| 1045 |
+
def create_custom_forward(module):
|
| 1046 |
+
def custom_forward(*inputs, **kwargs):
|
| 1047 |
+
return module(*inputs, **kwargs)
|
| 1048 |
+
return custom_forward
|
| 1049 |
+
|
| 1050 |
+
for block in self.blocks:
|
| 1051 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1052 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 1053 |
+
create_custom_forward(block),
|
| 1054 |
+
x, **kwargs,
|
| 1055 |
+
use_reentrant=False,
|
| 1056 |
+
)
|
| 1057 |
+
else:
|
| 1058 |
+
x = block(x, **kwargs)
|
| 1059 |
+
|
| 1060 |
+
if clean_x is not None:
|
| 1061 |
+
x = x[:, x.shape[1] // 2:]
|
| 1062 |
+
|
| 1063 |
+
# head
|
| 1064 |
+
x = self.head(x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2))
|
| 1065 |
+
|
| 1066 |
+
# unpatchify
|
| 1067 |
+
x = self.unpatchify(x, grid_sizes)
|
| 1068 |
+
return torch.stack(x)
|
| 1069 |
+
|
| 1070 |
+
def forward(
|
| 1071 |
+
self,
|
| 1072 |
+
*args,
|
| 1073 |
+
**kwargs
|
| 1074 |
+
):
|
| 1075 |
+
if kwargs.get('kv_cache', None) is not None:
|
| 1076 |
+
return self._forward_inference(*args, **kwargs)
|
| 1077 |
+
else:
|
| 1078 |
+
return self._forward_train(*args, **kwargs)
|
| 1079 |
+
|
| 1080 |
+
def unpatchify(self, x, grid_sizes):
|
| 1081 |
+
r"""
|
| 1082 |
+
Reconstruct video tensors from patch embeddings.
|
| 1083 |
+
|
| 1084 |
+
Args:
|
| 1085 |
+
x (List[Tensor]):
|
| 1086 |
+
List of patchified features, each with shape [L, C_out * prod(patch_size)]
|
| 1087 |
+
grid_sizes (Tensor):
|
| 1088 |
+
Original spatial-temporal grid dimensions before patching,
|
| 1089 |
+
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
|
| 1090 |
+
|
| 1091 |
+
Returns:
|
| 1092 |
+
List[Tensor]:
|
| 1093 |
+
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
|
| 1094 |
+
"""
|
| 1095 |
+
|
| 1096 |
+
c = self.out_dim
|
| 1097 |
+
out = []
|
| 1098 |
+
for u, v in zip(x, grid_sizes.tolist()):
|
| 1099 |
+
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
|
| 1100 |
+
u = torch.einsum('fhwpqrc->cfphqwr', u)
|
| 1101 |
+
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
|
| 1102 |
+
out.append(u)
|
| 1103 |
+
return out
|
| 1104 |
+
|
| 1105 |
+
def init_weights(self):
|
| 1106 |
+
r"""
|
| 1107 |
+
Initialize model parameters using Xavier initialization.
|
| 1108 |
+
"""
|
| 1109 |
+
|
| 1110 |
+
# basic init
|
| 1111 |
+
for m in self.modules():
|
| 1112 |
+
if isinstance(m, nn.Linear):
|
| 1113 |
+
nn.init.xavier_uniform_(m.weight)
|
| 1114 |
+
if m.bias is not None:
|
| 1115 |
+
nn.init.zeros_(m.bias)
|
| 1116 |
+
|
| 1117 |
+
# init embeddings
|
| 1118 |
+
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
|
| 1119 |
+
for m in self.text_embedding.modules():
|
| 1120 |
+
if isinstance(m, nn.Linear):
|
| 1121 |
+
nn.init.normal_(m.weight, std=.02)
|
| 1122 |
+
for m in self.time_embedding.modules():
|
| 1123 |
+
if isinstance(m, nn.Linear):
|
| 1124 |
+
nn.init.normal_(m.weight, std=.02)
|
| 1125 |
+
|
| 1126 |
+
# init output layer
|
| 1127 |
+
nn.init.zeros_(self.head.head.weight)
|