diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..9e64b16b720bce98764b06bef9db5b4074b460c6 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+.github/workflows/logo.gif filter=lfs diff=lfs merge=lfs -text
+diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/workflows/logo.gif b/.github/workflows/logo.gif
new file mode 100644
index 0000000000000000000000000000000000000000..ef5717efc17bbb2018a37a530a9d7a09e86277d9
--- /dev/null
+++ b/.github/workflows/logo.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:36a7627b7f0f0a508ec64aba72e5d95d38dfe7958bd8cf42d2a63f6ac2641529
+size 149067
diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f31e6bb7bf1a10f9394f9a395c093aa37eba4855
--- /dev/null
+++ b/.github/workflows/publish.yaml
@@ -0,0 +1,29 @@
+name: release
+
+on:
+ push:
+ tags:
+ - 'v**'
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}-publish
+ cancel-in-progress: true
+
+jobs:
+ build-n-publish:
+ runs-on: ubuntu-20.04
+ #if: startsWith(github.event.ref, 'refs/tags')
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python 3.10
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.10'
+ - name: Install wheel
+ run: pip install wheel==0.44.0 && pip install -r requirements.txt
+ - name: Build DiffSynth
+ run: python setup.py sdist bdist_wheel
+ - name: Publish package to PyPI
+ run: |
+ pip install twine
+ twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..d484f7458c0c5addfadb614125a4abc7e3041d72
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+*.png
+*pycache*
+*.safetensors
+*.ckpt
+models/
+*.log
+*.html
+*.jpg
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000000000000000000000000000000000000..11f7607fead5f64761524f75899365ac37478e0b
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,36 @@
+{
+ "version": "0.2.0",
+ "configurations": [
+ {
+ "name": "KontextT2I 1201 debug",
+ "type": "python",
+ "request": "launch",
+            "program": "examples/flux/model_training/train.py", // key setting: the training entry script to launch
+ "args": [
+ "--dataset_base_path", "/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/multi_frame",
+ "--dataset_metadata_path", "/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/multi_frame/pairs.txt",
+ "--data_file_keys", "image,prompt",
+ "--max_pixels", "1048576",
+ "--dataset_repeat", "400",
+ "--model_id_with_origin_paths", "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-Kontext-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-Kontext-dev:text_encoder_2/,black-forest-labs/FLUX.1-Kontext-dev:ae.safetensors",
+ "--learning_rate", "1e-5",
+ "--num_epochs", "5",
+ "--remove_prefix_in_ckpt", "pipe.dit.",
+ "--output_path", "./models/train/FLUX.1_lora_1127_mbti",
+ "--lora_base_model", "dit",
+ "--lora_target_modules", "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
+ "--lora_rank", "32",
+ "--align_to_opensource_format",
+ "--use_gradient_checkpointing",
+ "--default_caption", "Convert this real photo into a mbti style."
+ ],
+            "console": "integratedTerminal", // write output to the VS Code integrated terminal (easier to inspect)
+            "justMyCode": false, // allow debugging into third-party libraries (e.g. accelerate, transformers)
+            "cwd": "${workspaceFolder}", // set the working directory to the project root (keeps relative paths correct)
+ "env": {
+                "PYTHONUNBUFFERED": "1", // disable output buffering so logs appear in real time
+ "CUDA_VISIBLE_DEVICES": "7"
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/FLUX.1-Kontext-dev.py b/FLUX.1-Kontext-dev.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b5f00e011ae59a8731957a5c69aa46b0f03782a
--- /dev/null
+++ b/FLUX.1-Kontext-dev.py
@@ -0,0 +1,68 @@
+import torch
+from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
+from PIL import Image
+import os
+import json
+
+
+
+# for i in range(2):
+# pipe.load_lora(pipe.dit, f"models/train/FLUX.1_lora_1126/epoch-{i}.safetensors", alpha=1)
+# step = 25
+# input_path = "dataset/multi_frame"
+# base_path = f"validate_result/multi_frame{step}"
+# save_path = f"{base_path}/epoch{i}"
+# save_path_GT = f"{base_path}/GT"
+# os.makedirs(save_path, exist_ok=True)
+# os.makedirs(save_path_GT, exist_ok=True)
+# for img in os.listdir(input_path):
+# image = Image.open(os.path.join(input_path,img))
+# image.save(os.path.join(save_path_GT,img))
+# prompt="Convert this image into a line art style: retain the original scenes and characters unchanged, present it as a black-and-white sketch effect, and make it suitable for storyboard design. Requirements: use bold and powerful lines, highlight structures and textures with concise strokes, adopt a style close to comic sketching, roughly outline the scenes and character movements with simple lines, prohibit the depiction of details, and represent the characters' facial features with the simplest lines.",
+# # prompt = "Convert this image into a mbti style"
+# for fig in os.listdir(input_path):
+# if not fig.endswith(".png"):
+# continue
+# image = pipe(
+# prompt = prompt,
+# kontext_images=Image.open(os.path.join(input_path,fig)).resize((768, 768)),
+# height=768, width=768,
+# seed=0,
+# num_inference_steps=step
+# )
+# image.save(os.path.join(save_path,fig))
+
+for i in range(2):
+ pipe = FluxImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="text_encoder/model.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="text_encoder_2/"),
+ ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="ae.safetensors"),
+ ],
+ )
+ pipe.load_lora(pipe.dit, f"models/train/FLUX.1_lora_1126/epoch-{i}.safetensors", alpha=1)
+ step = 25
+    base_path = f"/fi-lib/workspace/sjx/DiffSynth-Studio/validate_result/t2i_1201{step}"
+ save_path = f"{base_path}/epoch{i}"
+ os.makedirs(save_path, exist_ok=True)
+ with open("nano_comprehension_1201.txt", "r") as f:
+ prompts = f.readlines()
+ for prompt in prompts:
+ prompt = prompt.strip()
+ if prompt == "":
+ continue
+ prompt_dict = json.loads(prompt)
+        fig = f"{prompt_dict['Image_Name']}.png"  # "Image_Name" names the output file
+ del prompt_dict["Image_Name"]
+ prompt = json.dumps(prompt_dict, ensure_ascii=False)
+ image = pipe(
+ prompt = prompt,
+ height=768, width=768,
+ seed=0,
+ num_inference_steps=step
+ )
+ image.save(os.path.join(save_path,fig))
+
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..0e5a49f5f70af9e9d37278f72315d1b1afd34895
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [2023] [Zhongjie Duan]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a04fabbf7a4c4abc2694cde80253168762a07f5a
--- /dev/null
+++ b/README.md
@@ -0,0 +1,522 @@
+# DiffSynth-Studio
+
+
+
+
+[](https://pypi.org/project/DiffSynth/)
+[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
+[](https://github.com/modelscope/DiffSynth-Studio/issues)
+[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
+[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
+
+[Switch to Chinese](./README_zh.md)
+
+## Introduction
+
+Welcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope](https://www.modelscope.cn/) team. We aim to foster technical innovation through framework development, bring together the power of the open-source community, and explore the limits of generative models!
+
+DiffSynth currently includes two open-source projects:
+* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technical exploration, aimed at academia, and providing support for cutting-edge model capabilities.
+* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment, aimed at industry, and offering higher computational performance and more stable features.
+
+[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) and [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) are the core projects behind the ModelScope [AIGC zone](https://modelscope.cn/aigc/home), offering powerful AI content generation capabilities. Come and try our carefully designed features and start your AI creation journey!
+
+## Installation
+
+Install from source (recommended):
+
+```
+git clone https://github.com/modelscope/DiffSynth-Studio.git
+cd DiffSynth-Studio
+pip install -e .
+```
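+
+After installation, a quick smoke test confirms that the package is importable (an import check only; no GPU is required for this step):
+
+```python
+# minimal smoke test: verifies that the installed package is visible to Python
+import diffsynth
+
+print("DiffSynth-Studio imported:", diffsynth.__name__)
+```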
+
+
+**Other installation methods**
+
+Install from PyPI (version updates may be delayed; for the latest features, install from source):
+
+```
+pip install diffsynth
+```
+
+If you run into problems during installation, they may be caused by upstream dependencies. Please check the documentation of these packages:
+
+* [torch](https://pytorch.org/get-started/locally/)
+* [sentencepiece](https://github.com/google/sentencepiece)
+* [cmake](https://cmake.org)
+* [cupy](https://docs.cupy.dev/en/stable/install.html)
+
+
+
+## Basic Framework
+
+DiffSynth-Studio redesigns the inference and training pipelines for mainstream Diffusion models (including FLUX, Wan, etc.), enabling efficient memory management and flexible model training.
+
+### Qwen-Image Series (🔥New Model)
+
+Detail page: [./examples/qwen_image/](./examples/qwen_image/)
+
+
+
+
+
+**Quick Start**
+
+```python
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+from PIL import Image
+import torch
+
+pipe = QwenImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
+)
+prompt = "A detailed portrait of a girl underwater, wearing a blue flowing dress, hair gently floating, clear light and shadow, surrounded by bubbles, calm expression, fine details, dreamy and beautiful."
+image = pipe(
+ prompt, seed=0, num_inference_steps=40,
+ # edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit
+)
+image.save("image.jpg")
+```
+
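+For the Qwen-Image-Edit variants listed in the Model Overview below, the call site mainly differs in the extra `edit_image` argument hinted at in the commented line above. A minimal sketch, assuming `pipe` has already been built from the corresponding `Qwen/Qwen-Image-Edit` checkpoints (see the linked example scripts for the exact `ModelConfig` entries; the input path is illustrative):
+
+```python
+from PIL import Image
+
+# `pipe` is assumed to be a QwenImagePipeline loaded from Qwen/Qwen-Image-Edit checkpoints.
+reference = Image.open("input.jpg").resize((1328, 1328))  # image to be edited (hypothetical path)
+image = pipe(
+    "Replace the background with a snowy mountain landscape.",  # editing instruction
+    edit_image=reference,
+    seed=0, num_inference_steps=40,
+)
+image.save("edited.jpg")
+```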
+
+
+
+
+**Model Overview**
+
+|Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
+|-|-|-|-|-|-|-|
+|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
+|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
+|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
+|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
+|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
+|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
+|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
+|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
+|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
+|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
+|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
+|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
+|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
+
+
+
+### FLUX Series
+
+Detail page: [./examples/flux/](./examples/flux/)
+
+
+
+
+
+**Quick Start**
+
+```python
+import torch
+from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
+
+pipe = FluxImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
+ ],
+)
+
+image = pipe(prompt="a cat", seed=0)
+image.save("image.jpg")
+```
+
+
+
+
+
+**Model Overview**
+
+| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
+|-|-|-|-|-|-|-|-|
+|[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
+|[FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](./examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
+|[FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
+|[FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
+|[FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
+|[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
+|[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
+|[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
+|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
+|[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
+|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
+|[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)|
+|[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
+|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](./examples/flux/model_training/lora/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen.py)|
+
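+Editing and control models plug into the same pipeline through the extra parameters listed above. For example, FLUX.1-Kontext-dev accepts a `kontext_images` argument; a minimal sketch, reusing the Quick Start setup but loading the Kontext checkpoints (the file patterns follow the FLUX.1-Kontext-dev inference script in this repository; the input path is illustrative):
+
+```python
+import torch
+from PIL import Image
+from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
+
+pipe = FluxImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="text_encoder_2/"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="ae.safetensors"),
+    ],
+)
+image = pipe(
+    prompt="Convert this photo into a black-and-white line art sketch.",  # editing instruction
+    kontext_images=Image.open("input.jpg").resize((768, 768)),  # reference image to edit
+    height=768, width=768, seed=0,
+)
+image.save("kontext_image.jpg")
+```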
+
+
+
+
+### Wan Series
+
+Detail page: [./examples/wanvideo/](./examples/wanvideo/)
+
+https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
+
+
+
+**Quick Start**
+
+```python
+import torch
+from diffsynth import save_video
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ],
+)
+pipe.enable_vram_management()
+
+video = pipe(
+ prompt="A documentary photography style scene: a lively puppy rapidly running on green grass. The puppy has brown-yellow fur, upright ears, and looks focused and joyful. Sunlight shines on its body, making the fur appear soft and shiny. The background is an open field with occasional wildflowers, and faint blue sky and clouds in the distance. Strong sense of perspective captures the motion of the puppy and the vitality of the surrounding grass. Mid-shot side-moving view.",
+ negative_prompt="Bright colors, overexposed, static, blurry details, subtitles, style, artwork, image, still, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, deformed, extra fingers, poorly drawn hands, poorly drawn face, malformed limbs, fused fingers, still frame, messy background, three legs, crowded background people, walking backwards",
+ seed=0, tiled=True,
+)
+save_video(video, "video1.mp4", fps=15, quality=5)
+```
+
+
+
+
+
+**Model Overview**
+
+| Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
+|-|-|-|-|-|-|-|
+|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
+|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
+|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
+|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
+|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
+|[Wan-AI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
+|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
+|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
+|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
+|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
+|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
+|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
+|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
+|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
+|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
+|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
+|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
+|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
+|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
+|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
+|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
+|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
+|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
+|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
+|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
+|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
+|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
+|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
+|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./examples/wanvideo/model_inference/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
+|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./examples/wanvideo/model_inference/LongCat-Video.py)|[code](./examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](./examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
+|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](./examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
+
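+Most of the conditional variants above only add extra keyword arguments on top of the Quick Start call. As a sketch, for the image-to-video models (extra parameter `input_image`) the generation call could look like the following; the `ModelConfig` list for the I2V checkpoints is omitted here, so reuse the configuration from the linked inference script (the first-frame path is illustrative):
+
+```python
+from PIL import Image
+from diffsynth import save_video
+
+# `pipe` is assumed to be a WanVideoPipeline built from an I2V checkpoint
+# (e.g. Wan2.1-I2V-14B-480P); see the linked inference script for the exact ModelConfig list.
+video = pipe(
+    prompt="A lively puppy rapidly running on green grass, documentary photography style.",
+    negative_prompt="static, blurry details, worst quality, low quality",
+    input_image=Image.open("first_frame.jpg"),  # conditioning image for image-to-video
+    seed=0, tiled=True,
+)
+save_video(video, "video_i2v.mp4", fps=15, quality=5)
+```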
+
+
+### More Models
+
+
+
+
+**Image Generation Models**
+
+Detail page: [./examples/image_synthesis/](./examples/image_synthesis/)
+
+|FLUX|Stable Diffusion 3|
+|-|-|
+|||
+
+|Kolors|Hunyuan-DiT|
+|-|-|
+|||
+
+|Stable Diffusion|Stable Diffusion XL|
+|-|-|
+|||
+
+
+
+
+
+
+**Video Generation Models**
+
+- HunyuanVideo: [./examples/HunyuanVideo/](./examples/HunyuanVideo/)
+
+https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9
+
+- StepVideo: [./examples/stepvideo/](./examples/stepvideo/)
+
+https://github.com/user-attachments/assets/5954fdaa-a3cf-45a3-bd35-886e3cc4581b
+
+- CogVideoX: [./examples/CogVideoX/](./examples/CogVideoX/)
+
+https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
+
+
+
+
+
+
+**Image Quality Assessment Models**
+
+We have integrated a series of image quality assessment models. These models can be used for evaluating image generation models, for alignment training, and for similar tasks.
+
+Detail page: [./examples/image_quality_metric/](./examples/image_quality_metric/)
+
+* [ImageReward](https://github.com/THUDM/ImageReward)
+* [Aesthetic](https://github.com/christophschuhmann/improved-aesthetic-predictor)
+* [PickScore](https://github.com/yuvalkirstain/pickscore)
+* [CLIP](https://github.com/openai/CLIP)
+* [HPSv2](https://github.com/tgxs002/HPSv2)
+* [HPSv2.1](https://github.com/tgxs002/HPSv2)
+* [MPS](https://github.com/Kwai-Kolors/MPS)
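+
+As a point of reference, the upstream [ImageReward](https://github.com/THUDM/ImageReward) package linked above can also be used directly to score generated images against their prompts. A minimal sketch based on that package's published interface (install it separately, e.g. `pip install image-reward`; see the detail page above for DiffSynth's own wrappers):
+
+```python
+import ImageReward as RM
+
+# load the ImageReward-v1.0 checkpoint (downloaded automatically on first use)
+model = RM.load("ImageReward-v1.0")
+
+prompt = "a cat"
+# score one or more candidate images for the same prompt; higher is better
+rewards = model.score(prompt, ["image.jpg"])
+print(rewards)
+```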
+
+
+
+
+
+## Innovative Achievements
+
+DiffSynth-Studio is not just an engineering framework for models, but also a platform for incubating innovative research results.
+
+
+**Nexus-Gen: Unified Architecture for Image Understanding, Generation, and Editing**
+
+- Detail page: https://github.com/modelscope/Nexus-Gen
+- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
+- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
+- Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
+- Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
+
+
+
+
+
+
+**ArtAug: Aesthetic Enhancement for Image Generation Models**
+
+- Detail page: [./examples/ArtAug/](./examples/ArtAug/)
+- Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
+- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
+- Online Demo: [ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
+
+|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
+|-|-|
+|||
+
+
+
+
+EliGen: Precise Image Region Control
+
+- Detail page: [./examples/EntityControl/](./examples/EntityControl/)
+- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
+- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
+- Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
+- Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
+
+|Entity Control Mask|Generated Image|
+|-|-|
+|||
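+
+As a rough usage sketch (not taken from the official example scripts), entity-level control of this kind is driven by passing per-entity prompts and per-entity masks to the FLUX pipeline through the `eligen_entity_prompts` / `eligen_entity_masks` arguments listed for FLUX.1-dev-EliGen in the FLUX model table of this README. Loading the EliGen LoRA weights is omitted below, and the white-on-black PIL mask format is an assumption; the linked detail page contains the full, authoritative example.
+
+```python
+# Hedged sketch of EliGen-style entity control with the FLUX pipeline.
+# Assumptions: masks are white-on-black PIL images at the output resolution, and the
+# EliGen LoRA has been loaded as shown in ./examples/EntityControl/ (omitted here;
+# without it the entity arguments have no effect).
+import torch
+from PIL import Image, ImageDraw
+from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
+
+pipe = FluxImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
+    ],
+)
+
+def rect_mask(box, size=(1024, 1024)):
+    # Build a simple rectangular white-on-black mask (assumed mask format).
+    mask = Image.new("L", size, 0)
+    ImageDraw.Draw(mask).rectangle(box, fill=255)
+    return mask
+
+image = pipe(
+    prompt="a cozy living room with a cat and a floor lamp",
+    eligen_entity_prompts=["an orange cat", "a green floor lamp"],
+    eligen_entity_masks=[rect_mask((100, 500, 500, 950)), rect_mask((650, 150, 900, 950))],
+    seed=0,
+)
+image.save("eligen_sketch.jpg")
+```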
+
+
+
+
+ExVideo: Extended Training for Video Generation Models
+
+- Project Page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
+- Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
+- Code Example: [./examples/ExVideo/](./examples/ExVideo/)
+- Model: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
+
+https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
+
+
+
+
+Diffutoon: High-Resolution Anime-Style Video Rendering
+
+- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
+- Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
+- Code Example: [./examples/Diffutoon/](./examples/Diffutoon/)
+
+https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
+
+
+
+
+DiffSynth: The Initial Version of This Project
+
+- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
+- Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
+- Code Example: [./examples/diffsynth/](./examples/diffsynth/)
+
+https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
+
+
+
+
+
+## Update History
+
+- **November 4, 2025**: We support the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained on Wan 2.1 and enables motion generation conditioned on reference videos.
+
+- **October 30, 2025**: We support the [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which enables text-to-video, image-to-video, and video continuation. In this project, it reuses the Wan framework for both inference and training.
+
+- **October 27, 2025**: We support the [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, further expanding the Wan ecosystem.
+
+- **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) is released! Jointly developed and open-sourced by us and the Taobao Design Team, this model is built upon Qwen-Image, designed specifically for e-commerce poster scenarios, and supports precise partition layout control. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).
+
+- **September 9, 2025**: Our training framework now supports multiple training modes and has been adapted for Qwen-Image. In addition to the standard SFT training mode, Direct Distill is now also supported; please refer to [our example code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support comprehensive model training capabilities.
+
+- **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
+
+- **August 21, 2025**: [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) is released! Compared to the V1 version, the training dataset has been updated to the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), enabling generated images to better align with the inherent image distribution and style of Qwen-Image. Please refer to [our sample code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).
+
+- **August 21, 2025**: We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structure control LoRA model. Following the "In Context" approach, it supports various types of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).
+
+- **August 20, 2025** We open-sourced [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix), which improves the editing performance of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py).
+
+- **August 19, 2025** 🔥 Qwen-Image-Edit is now open source. We welcome the newest member of the image editing model family!
+
+- **August 18, 2025** We trained and open-sourced the Inpaint ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).
+
+- **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset). This is an image dataset generated with the Qwen-Image model, containing 160,000 `1024 x 1024` images across general, English text rendering, and Chinese text rendering subsets. We provide caption, entity, and control image annotations for each image. Developers can use this dataset to train models such as ControlNet and EliGen for the Qwen-Image model. We aim to promote technological development through open-source contributions!
+
+- **August 13, 2025** We trained and open-sourced the ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py).
+
+- **August 12, 2025** We trained and open-sourced the ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py).
+
+- **August 11, 2025** We released another distilled acceleration model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA). It uses the same training process as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), but the model structure is changed to LoRA, which makes it easier to combine with other open-source models.
+
+- **August 7, 2025** We open-sourced the entity control LoRA of Qwen-Image, [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen). Qwen-Image-EliGen is able to achieve entity-level controlled text-to-image generation. See the [paper](https://arxiv.org/abs/2501.01097) for technical details. Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).
+
+- **August 5, 2025** We open-sourced the distilled acceleration model of Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), achieving approximately 5x speedup.
+
+- **August 4, 2025** 🔥 Qwen-Image is now open source. We welcome the newest member of the image generation model family!
+
+- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev), a model focused on aesthetic photography, is comprehensively supported, including low-GPU-memory layer-by-layer offload, LoRA training, and full training. See [./examples/flux/](./examples/flux/).
+
+- **July 28, 2025** With the open-sourcing of Wan 2.2, we immediately provided comprehensive support, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, and full training. See [./examples/wanvideo/](./examples/wanvideo/).
+
+- **July 11, 2025** We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks.
+ - Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
+ - Github Repo: https://github.com/modelscope/Nexus-Gen
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
+ - Training Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
+ - Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
+
+
+More
+
+- **June 15, 2025** ModelScope's official evaluation framework, [EvalScope](https://github.com/modelscope/evalscope), now supports text-to-image generation evaluation. Try it with the [Best Practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide.
+
+- **March 31, 2025** We support InfiniteYou, an identity-preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
+
+- **March 25, 2025** Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! It focuses on stable model deployment, is geared toward industry, and offers better engineering support, higher computational performance, and more stable functionality.
+
+- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
+
+- **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
+
+- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary), a state-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
+
+- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
+ - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
+ - Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
+ - Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
+
+- **December 19, 2024** We implement advanced VRAM management for HunyuanVideo, making it possible to generate videos at a resolution of 129x720x1280 using 24GB of VRAM, or at 129x512x384 resolution with just 6GB of VRAM. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
+
+- **December 18, 2024** We propose ArtAug, an approach designed to improve text-to-image synthesis models through synthesis-understanding interactions. We have trained an ArtAug enhancement module for FLUX.1-dev in the format of LoRA. This model integrates the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, leading to an improvement in the quality of generated images.
+ - Paper: https://arxiv.org/abs/2412.12888
+ - Examples: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
+ - Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (Coming soon)
+
+- **October 25, 2024** We provide extensive FLUX ControlNet support. This project supports many different ControlNet models that can be freely combined, even if their structures differ. Additionally, ControlNet models are compatible with high-resolution refinement and partition control techniques, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/).
+
+- **October 8, 2024.** We release the extended LoRA based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).
+
+- **August 22, 2024.** CogVideoX-5B is supported in this project. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including
+ - Text to video
+ - Video editing
+ - Self-upscaling
+ - Video interpolation
+
+- **August 22, 2024.** We have implemented an interesting painter that supports all text-to-image models. Now you can create stunning images using the painter, with assistance from AI!
+ - Use it in our [WebUI](#usage-in-webui).
+
+- **August 21, 2024.** FLUX is supported in DiffSynth-Studio.
+ - Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
+ - LoRA, ControlNet, and additional models will be available soon.
+
+- **June 21, 2024.** We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to generate long videos of up to 128 frames.
+ - [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
+ - Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
+ - Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
+ - Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
+ - You can try ExVideo in this [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
+
+- **June 13, 2024.** DiffSynth Studio has been transferred to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
+
+- **Jan 29, 2024.** We propose Diffutoon, a fantastic solution for toon shading.
+ - [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
+ - The source codes are released in this project.
+ - The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
+
+- **Dec 8, 2023.** We decide to develop a new project aiming to unleash the potential of diffusion models, especially in video synthesis. Development of this project has started.
+
+- **Nov 15, 2023.** We propose FastBlend, a powerful video deflickering algorithm.
+ - The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
+  - Demo videos are shown on Bilibili, covering three tasks:
+ - [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
+ - [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
+ - [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
+ - The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
+ - An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
+
+- **Oct 1, 2023.** We release an early version of this project, namely FastSDXL, as a first attempt at building a diffusion engine.
+ - The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
+ - FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
+ - The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
+ - The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
+ - A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
+ - Since OLSS requires additional training, we don't implement it in this project.
+
+- **Aug 29, 2023.** We propose DiffSynth, a video synthesis framework.
+ - [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
+ - The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
+ - The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
+
+
\ No newline at end of file
diff --git a/README_zh.md b/README_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..22a879d647741e42ffe2a883f0a682cda99da6c7
--- /dev/null
+++ b/README_zh.md
@@ -0,0 +1,538 @@
+# DiffSynth-Studio
+
+
+
+[](https://pypi.org/project/DiffSynth/)
+[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
+[](https://github.com/modelscope/DiffSynth-Studio/issues)
+[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
+[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
+
+[Switch to English](./README.md)
+
+## 简介
+
+欢迎来到 Diffusion 模型的魔法世界!DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界!
+
+DiffSynth 目前包括两个开源项目:
+* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): 聚焦于激进的技术探索,面向学术界,提供更前沿的模型能力支持。
+* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): 聚焦于稳定的模型部署,面向工业界,提供更高的计算性能与更稳定的功能。
+
+[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) 与 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 作为魔搭社区 [AIGC 专区](https://modelscope.cn/aigc/home) 的核心技术支撑,提供了强大的AI生成内容能力。欢迎体验我们精心打造的产品化功能,开启您的AI创作之旅!
+
+## 安装
+
+从源码安装(推荐):
+
+```
+git clone https://github.com/modelscope/DiffSynth-Studio.git
+cd DiffSynth-Studio
+pip install -e .
+```
+
+
+其他安装方式
+
+从 PyPI 安装(存在版本更新延迟,如需使用最新功能,请从源码安装)
+
+```
+pip install diffsynth
+```
+
+如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档:
+
+* [torch](https://pytorch.org/get-started/locally/)
+* [sentencepiece](https://github.com/google/sentencepiece)
+* [cmake](https://cmake.org)
+* [cupy](https://docs.cupy.dev/en/stable/install.html)
+
+
+
+
+
+## 基础框架
+
+DiffSynth-Studio 为主流 Diffusion 模型(包括 FLUX、Wan 等)重新设计了推理和训练流水线,能够实现高效的显存管理、灵活的模型训练。
+
+### Qwen-Image 系列 (🔥新模型)
+
+详细页面:[./examples/qwen_image/](./examples/qwen_image/)
+
+
+
+
+
+快速开始
+
+```python
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+from PIL import Image
+import torch
+
+pipe = QwenImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
+)
+prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
+image = pipe(
+ prompt, seed=0, num_inference_steps=40,
+ # edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit
+)
+image.save("image.jpg")
+```
+
+
+
+
+
+模型总览
+
+|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
+|-|-|-|-|-|-|-|
+|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
+|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
+|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
+|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
+|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
+|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
+|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
+|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
+|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
+|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
+|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
+|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
+|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
+
+
+
+### FLUX 系列
+
+详细页面:[./examples/flux/](./examples/flux/)
+
+
+
+
+
+快速开始
+
+```python
+import torch
+from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
+
+pipe = FluxImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
+ ],
+)
+
+image = pipe(prompt="a cat", seed=0)
+image.save("image.jpg")
+```
+
+
+
+
+
+模型总览
+
+|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
+|-|-|-|-|-|-|-|-|
+|[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
+|[FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](./examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
+|[FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
+|[FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
+|[FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
+|[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
+|[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
+|[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
+|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
+|[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
+|[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
+|[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)|
+|[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
+|[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](./examples/flux/model_training/lora/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen.py)|
+
+
+
+### Wan 系列
+
+详细页面:[./examples/wanvideo/](./examples/wanvideo/)
+
+https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
+
+
+
+快速开始
+
+```python
+import torch
+from diffsynth import save_video
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+pipe = WanVideoPipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+ ],
+)
+pipe.enable_vram_management()
+
+video = pipe(
+ prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ seed=0, tiled=True,
+)
+save_video(video, "video1.mp4", fps=15, quality=5)
+```
+
+
+
+
+
+模型总览
+
+|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
+|-|-|-|-|-|-|-|
+|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
+|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
+|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
+|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
+|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
+|[Wan-AI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
+|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
+|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
+|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
+|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
+|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
+|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
+|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
+|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
+|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
+|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
+|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
+|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
+|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
+|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
+|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
+|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
+|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
+|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
+|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
+|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
+|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
+|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
+|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./examples/wanvideo/model_inference/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
+|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./examples/wanvideo/model_inference/LongCat-Video.py)|[code](./examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](./examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
+|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](./examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
+
+
+
+
+
+### 更多模型
+
+
+
+
+图像生成模型
+
+详细页面:[./examples/image_synthesis/](./examples/image_synthesis/)
+
+|FLUX|Stable Diffusion 3|
+|-|-|
+|||
+
+|Kolors|Hunyuan-DiT|
+|-|-|
+|||
+
+|Stable Diffusion|Stable Diffusion XL|
+|-|-|
+|||
+
+
+
+
+
+
+视频生成模型
+
+- HunyuanVideo:[./examples/HunyuanVideo/](./examples/HunyuanVideo/)
+
+https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9
+
+- StepVideo:[./examples/stepvideo/](./examples/stepvideo/)
+
+https://github.com/user-attachments/assets/5954fdaa-a3cf-45a3-bd35-886e3cc4581b
+
+- CogVideoX:[./examples/CogVideoX/](./examples/CogVideoX/)
+
+https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
+
+
+
+
+
+
+图像质量评估模型
+
+我们集成了一系列图像质量评估模型,这些模型可以用于图像生成模型的评测、对齐训练等场景中。
+
+详细页面:[./examples/image_quality_metric/](./examples/image_quality_metric/)
+
+* [ImageReward](https://github.com/THUDM/ImageReward)
+* [Aesthetic](https://github.com/christophschuhmann/improved-aesthetic-predictor)
+* [PickScore](https://github.com/yuvalkirstain/pickscore)
+* [CLIP](https://github.com/openai/CLIP)
+* [HPSv2](https://github.com/tgxs002/HPSv2)
+* [HPSv2.1](https://github.com/tgxs002/HPSv2)
+* [MPS](https://github.com/Kwai-Kolors/MPS)
+
+
+
+
+
+## 创新成果
+
+DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
+
+
+Nexus-Gen: 统一架构的图像理解、生成、编辑
+
+- 详细页面:https://github.com/modelscope/Nexus-Gen
+- 论文:[Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
+- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
+- 数据集:[ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
+- 在线体验:[ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
+
+
+
+
+
+
+
+
+ArtAug: 图像生成模型的美学提升
+
+- 详细页面:[./examples/ArtAug/](./examples/ArtAug/)
+- 论文:[ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
+- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
+- 在线体验:[ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
+
+|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
+|-|-|
+|||
+
+
+
+
+
+
+
+EliGen: 精准的图像分区控制
+
+- 详细页面:[./examples/EntityControl/](./examples/EntityControl/)
+- 论文:[EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
+- 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
+- 在线体验:[ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
+- 数据集:[EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
+
+|实体控制区域|生成图像|
+|-|-|
+|||
+
+
+
+
+
+
+
+ExVideo: 视频生成模型的扩展训练
+
+- 项目页面:[Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
+- 论文:[ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
+- 代码样例:[./examples/ExVideo/](./examples/ExVideo/)
+- 模型:[ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
+
+https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
+
+
+
+
+
+
+
+Diffutoon: 高分辨率动漫风格视频渲染
+
+- 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
+- 论文:[Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
+- 代码样例:[./examples/Diffutoon/](./examples/Diffutoon/)
+
+https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
+
+### DiffSynth: 本项目的初代版本
+
+- 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
+- 论文:[DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
+- 代码样例:[./examples/diffsynth/](./examples/diffsynth/)
+
+https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
+
+## 更新历史
+
+- **2025年11月4日** 支持了 [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) 模型,该模型基于 Wan 2.1 训练,支持根据参考视频生成相应的动作。
+
+- **2025年10月30日** 支持了 [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) 模型,该模型支持文生视频、图生视频、视频续写。这个模型在本项目中沿用 Wan 的框架进行推理和训练。
+
+- **2025年10月27日** 支持了 [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) 模型,Wan 模型生态再添一员。
+
+- **2025年9月23日** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) 发布!本模型由我们与淘天体验设计团队联合研发并开源。模型基于 Qwen-Image 构建,专为电商海报场景设计,支持精确的分区布局控制。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)。
+
+- **2025年9月9日** 我们的训练框架支持了多种训练模式,目前已适配 Qwen-Image。除标准 SFT 训练模式外,还支持 Direct Distill,请参考[我们的示例代码](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)。这项功能是实验性的,我们将会继续完善,以支持更全面的模型训练功能。
+
+- **2025年8月28日** 我们支持了Wan2.2-S2V,一个音频驱动的电影级视频生成模型。请参见[./examples/wanvideo/](./examples/wanvideo/)。
+
+- **2025年8月21日** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) 发布!相比于 V1 版本,训练数据集变为 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset),因此,生成的图像更符合 Qwen-Image 本身的图像分布和风格。 请参考[我们的示例代码](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)。
+
+- **2025年8月21日** 我们开源了 [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) 结构控制 LoRA 模型,采用 In Context 的技术路线,支持多种类别的结构控制条件,包括 canny, depth, lineart, softedge, normal, openpose。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)。
+
+- **2025年8月20日** 我们开源了 [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) 模型,提升了 Qwen-Image-Edit 对低分辨率图像输入的编辑效果。请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)。
+
+- **2025年8月19日** 🔥 Qwen-Image-Edit 开源,欢迎图像编辑模型新成员!
+
+- **2025年8月18日** 我们训练并开源了 Qwen-Image 的图像重绘 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)。
+
+- **2025年8月15日** 我们开源了 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) 数据集。这是一个使用 Qwen-Image 模型生成的图像数据集,共包含 160,000 张`1024 x 1024`图像。它包括通用、英文文本渲染和中文文本渲染子集。我们为每张图像提供了图像描述、实体和结构控制图像的标注。开发者可以使用这个数据集来训练 Qwen-Image 模型的 ControlNet 和 EliGen 等模型,我们旨在通过开源推动技术发展!
+
+- **2025年8月13日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)。
+
+- **2025年8月12日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)。
+
+- **2025年8月11日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA),沿用了与 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) 相同的训练流程,但模型结构修改为了 LoRA,因此能够更好地与其他开源生态模型兼容。
+
+- **2025年8月7日** 我们开源了 Qwen-Image 的实体控制 LoRA 模型 [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)。Qwen-Image-EliGen 能够实现实体级可控的文生图。技术细节请参见[论文](https://arxiv.org/abs/2501.01097)。训练数据集:[EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)。
+
+- **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。
+
+- **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员!
+
+- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。
+
+- **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。
+
+- **2025年7月11日** 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。
+ - 论文: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
+ - Github 仓库: https://github.com/modelscope/Nexus-Gen
+ - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
+ - 训练数据集: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
+ - 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
+
+
+### 更多
+
+- **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。
+
+- **2025年3月31日** 我们支持 InfiniteYou,一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。
+
+- **2025年3月25日** 我们的新项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源!它专注于稳定的模型部署,面向工业界,提供更好的工程支持、更高的计算性能和更稳定的功能。
+
+- **2025年3月13日** 我们支持 HunyuanVideo-I2V,即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
+
+- **2025年2月25日** 我们支持 Wan-Video,这是阿里巴巴开源的一系列最先进的视频合成模型。详见 [./examples/wanvideo/](./examples/wanvideo/)。
+
+- **2025年2月17日** 我们支持 [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)!先进的视频合成模型!详见 [./examples/stepvideo](./examples/stepvideo/)。
+
+- **2024年12月31日** 我们提出 EliGen,一种用于精确实体级别控制的文本到图像生成的新框架,并辅以修复融合管道,将其能力扩展到图像修复任务。EliGen 可以无缝集成现有的社区模型,如 IP-Adapter 和 In-Context LoRA,提升其通用性。更多详情,请见 [./examples/EntityControl](./examples/EntityControl/)。
+ - 论文: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
+ - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
+ - 在线体验: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
+ - 训练数据集: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
+
+- **2024年12月19日** 我们为 HunyuanVideo 实现了高级显存管理,使得在 24GB 显存下可以生成分辨率为 129x720x1280 的视频,或在仅 6GB 显存下生成分辨率为 129x512x384 的视频。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
+
+- **2024年12月18日** 我们提出 ArtAug,一种通过合成-理解交互来改进文生图模型的方法。我们以 LoRA 格式为 FLUX.1-dev 训练了一个 ArtAug 增强模块。该模型将 Qwen2-VL-72B 的美学理解融入 FLUX.1-dev,从而提升了生成图像的质量。
+ - 论文: https://arxiv.org/abs/2412.12888
+ - 示例: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
+ - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
+ - 演示: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (即将上线)
+
+- **2024年10月25日** 我们提供了广泛的 FLUX ControlNet 支持。该项目支持许多不同的 ControlNet 模型,并且可以自由组合,即使它们的结构不同。此外,ControlNet 模型兼容高分辨率优化和分区控制技术,能够实现非常强大的可控图像生成。详见 [`./examples/ControlNet/`](./examples/ControlNet/)。
+
+- **2024年10月8日** 我们发布了基于 CogVideoX-5B 和 ExVideo 的扩展 LoRA。您可以从 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 或 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 下载此模型。
+
+- **2024年8月22日** 本项目现已支持 CogVideoX-5B。详见 [此处](/examples/video_synthesis/)。我们为这个文生视频模型提供了几个有趣的功能,包括:
+ - 文本到视频
+ - 视频编辑
+ - 自我超分
+ - 视频插帧
+
+- **2024年8月22日** 我们实现了一个有趣的画笔功能,支持所有文生图模型。现在,您可以在 AI 的辅助下使用画笔创作惊艳的图像了!
+ - 在我们的 [WebUI](#usage-in-webui) 中使用它。
+
+- **2024年8月21日** DiffSynth-Studio 现已支持 FLUX。
+ - 启用 CFG 和高分辨率修复以提升视觉质量。详见 [此处](/examples/image_synthesis/README.md)
+ - LoRA、ControlNet 和其他附加模型将很快推出。
+
+- **2024年6月21日** 我们提出 ExVideo,一种旨在增强视频生成模型能力的后训练微调技术。我们将 Stable Video Diffusion 进行了扩展,实现了长达 128 帧的长视频生成。
+ - [项目页面](https://ecnu-cilab.github.io/ExVideoProjectPage/)
+ - 源代码已在此仓库中发布。详见 [`examples/ExVideo`](./examples/ExVideo/)。
+ - 模型已发布于 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) 和 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)。
+ - 技术报告已发布于 [arXiv](https://arxiv.org/abs/2406.14130)。
+ - 您可以在此 [演示](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) 中试用 ExVideo!
+
+- **2024年6月13日** DiffSynth Studio 已迁移至 ModelScope。开发团队也从“我”转变为“我们”。当然,我仍会参与后续的开发和维护工作。
+
+- **2024年1月29日** 我们提出 Diffutoon,这是一个出色的卡通着色解决方案。
+ - [项目页面](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
+ - 源代码已在此项目中发布。
+ - 技术报告(IJCAI 2024)已发布于 [arXiv](https://arxiv.org/abs/2401.16224)。
+
+- **2023年12月8日** 我们决定启动一个新项目,旨在释放扩散模型的潜力,尤其是在视频合成方面。该项目的开发工作正式开始。
+
+- **2023年11月15日** 我们提出 FastBlend,一种强大的视频去闪烁算法。
+ - sd-webui 扩展已发布于 [GitHub](https://github.com/Artiprocher/sd-webui-fastblend)。
+ - 演示视频已在 Bilibili 上展示,包含三个任务:
+ - [视频去闪烁](https://www.bilibili.com/video/BV1d94y1W7PE)
+ - [视频插帧](https://www.bilibili.com/video/BV1Lw411m71p)
+ - [图像驱动的视频渲染](https://www.bilibili.com/video/BV1RB4y1Z7LF)
+ - 技术报告已发布于 [arXiv](https://arxiv.org/abs/2311.09265)。
+ - 其他用户开发的非官方 ComfyUI 扩展已发布于 [GitHub](https://github.com/AInseven/ComfyUI-fastblend)。
+
+- **2023年10月1日** 我们发布了该项目的早期版本,名为 FastSDXL。这是构建一个扩散引擎的初步尝试。
+ - 源代码已发布于 [GitHub](https://github.com/Artiprocher/FastSDXL)。
+ - FastSDXL 包含一个可训练的 OLSS 调度器,以提高效率。
+ - OLSS 的原始仓库位于 [此处](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler)。
+ - 技术报告(CIKM 2023)已发布于 [arXiv](https://arxiv.org/abs/2305.14677)。
+ - 演示视频已发布于 [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj)。
+ - 由于 OLSS 需要额外训练,我们未在本项目中实现它。
+
+- **2023年8月29日** 我们提出 DiffSynth,一个视频合成框架。
+ - [项目页面](https://ecnu-cilab.github.io/DiffSynth.github.io/)。
+ - 源代码已发布在 [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth)。
+ - 技术报告(ECML PKDD 2024)已发布于 [arXiv](https://arxiv.org/abs/2308.03463)。
+
+
diff --git a/_temp_.py b/_temp_.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a010f7ba80d9ea4dffd57315b62a4200d771bed
--- /dev/null
+++ b/_temp_.py
@@ -0,0 +1,184 @@
+
+# import os
+# render2real_path = "/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/spotlight_sketch/epoch0"
+# sketch_enhance_body_path = "/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/spotlight_sketch/GT"
+# render2real_files = set(os.listdir(render2real_path))
+# sketch_enhance_body_files = set(os.listdir(sketch_enhance_body_path))
+# for file in render2real_files:
+# if file not in sketch_enhance_body_files:
+# print(f"Removing {file} from render2real_path")
+# os.remove(os.path.join(render2real_path, file))
+# for file in sketch_enhance_body_files:
+# if file not in render2real_files:
+# print(f"Removing {file} from sketch_enhance_body_path")
+# os.remove(os.path.join(sketch_enhance_body_path, file))
+
+# import os
+# input_path = "/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/mbti/Realistic"
+# for file in os.listdir(input_path):
+#     # strip the trailing "_Realistic" suffix from the file name
+# new_name = file.replace("_Realistic", "")
+# os.rename(os.path.join(input_path, file), os.path.join(input_path, new_name))
+
+# import os
+# for file in os.listdir("dataset/spotlight_sketch_cat/GT"):
+# with open("dataset/spotlight_sketch_cat/pairs_t2t.txt", "a") as f:
+#         # format: target image \t source image
+# f.write(f"GT/{file}\tepoch0/{file}\n")
+
+
+# import os
+# import json
+# from tqdm import tqdm
+# input_txt = "/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/spotlight_sketch_cat/spotlight_nano_comprehension_1203.txt"
+# with open(input_txt, "r") as f:
+# lines = f.readlines()
+# for i in tqdm(range(len(lines))):
+# data = json.loads(lines[i])
+# fig_id = f"{data['Image_Name']}.png"
+# del data["Image_Name"]
+# input_dir = "dataset/spotlight_sketch_cat/epoch0"
+# GT_dir = "dataset/spotlight_sketch_cat/GT"
+# for file in os.listdir(input_dir):
+# if fig_id in file:
+# with open("dataset/spotlight_sketch_cat/pairs_i2i.txt", "a") as f:
+#                 # format: target \t source \t prompt
+# f.write(f"{GT_dir}/{file}\t{input_dir}/{file}\t{data}\n")
+
+# Tile every six images in each folder into a 3-row x 2-column grid and save the results to another folder, pasting the original images directly instead of taking screenshots.
+import os
+from PIL import Image
+from tqdm import tqdm
+import numpy as np
+base_dirs = ["/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/the roses","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/nouvelle","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/legs","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/frankenstein"]
+# Core configuration parameters
+crop_size = (1920, 800)  # target center-crop size (width, height)
+resize_size = (477, 188)  # per-image size after downsampling (width, height)
+line_width = 6  # width of the black separator lines, in pixels
+target_merge_size = (960, 576)  # final merged canvas size (width, height)
+
+
+def center_crop_to_size(img, target_size):
+ """
+ 对图片进行CenterCrop到指定尺寸,不足部分用黑色像素填充
+ :param img: PIL Image对象
+ :param target_size: (target_w, target_h) 目标裁剪尺寸
+ :return: crop+补黑后的PIL Image
+ """
+ target_w, target_h = target_size
+ img_w, img_h = img.size
+
+    # Step 1: compute the center-aligned crop box
+    # horizontal crop
+ if img_w >= target_w:
+ left = (img_w - target_w) // 2
+ right = left + target_w
+ else:
+ left = 0
+ right = img_w
+    # vertical crop
+ if img_h >= target_h:
+ top = (img_h - target_h) // 2
+ bottom = top + target_h
+ else:
+ top = 0
+ bottom = img_h
+
+    # Step 2: perform the center crop
+ cropped = img.crop((left, top, right, bottom))
+
+    # Step 3: pad with black where the crop is smaller than the target size
+ if cropped.size != (target_w, target_h):
+        new_img = Image.new("RGB", (target_w, target_h), (0, 0, 0))  # black background
+ new_img.paste(cropped, ((target_w - cropped.width) // 2, (target_h - cropped.height) // 2))
+ cropped = new_img
+
+ return cropped
+for k in range(len(base_dirs)):
+ save_path = f"{base_dirs[k]}_dedup_cat"
+ os.makedirs(save_path, exist_ok=True)
+ input_path = f"{base_dirs[k]}_dedup"
+    # collect and sort the image files
+ files = [f for f in os.listdir(input_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+ files.sort()
+
+    # iterate over the files, merging six images at a time
+    for i in tqdm(range(0, len(files), 6), desc="merging images"):
+        # initialize the final merge canvas (target size 960x576, black background)
+ merged_image = np.zeros((target_merge_size[1], target_merge_size[0], 3), dtype=np.uint8)
+
+        # process the six images for this grid one by one
+        valid_imgs = []  # processed images for this grid
+ for j in range(6):
+ if i + j >= len(files):
+                # fewer than 6 images remain: fill this slot with a black placeholder
+ img_np = np.zeros((resize_size[1], resize_size[0], 3), dtype=np.uint8)
+ valid_imgs.append(img_np)
+ continue
+
+ img_path = os.path.join(input_path, files[i + j])
+ try:
+                # read the image (force RGB)
+ img = Image.open(img_path).convert("RGB")
+ img_w, img_h = img.size
+
+                # filter: skip (and report) images whose original width is below 1800 px
+                if img_w < 1800:
+                    print(f"Skipping {files[i+j]}: original width {img_w} < 1800")
+                    # fill this slot with a black placeholder
+ img_np = np.zeros((resize_size[1], resize_size[0], 3), dtype=np.uint8)
+ valid_imgs.append(img_np)
+ continue
+
+                # Step 1: center-crop to 1920x800, padding with black if needed
+ cropped_img = center_crop_to_size(img, crop_size)
+
+                # Step 2: downsample to 477x188 (LANCZOS resampling to preserve quality)
+ resized_img = cropped_img.resize(resize_size, resample=Image.LANCZOS)
+
+                # convert to a numpy array
+ img_np = np.array(resized_img)
+ valid_imgs.append(img_np)
+
+ except Exception as e:
+                print(f"Error while processing {files[i+j]}: {str(e)}")
+                # fill this slot with a black placeholder on error
+ img_np = np.zeros((resize_size[1], resize_size[0], 3), dtype=np.uint8)
+ valid_imgs.append(img_np)
+
+        # Step 3: compute each image's position on the merge canvas (3 rows x 2 columns + 6-px black lines)
+        # sanity-check the layout configuration
+        assert len(valid_imgs) == 6, "exactly 6 images are required per grid"
+        # total footprint of images plus black lines, which must fit inside 960x576
+        total_col = 2 * resize_size[0] + 1 * line_width  # 2 columns + 1 vertical line
+        total_row = 3 * resize_size[1] + 2 * line_width  # 3 rows + 2 horizontal lines
+        # offsets that center the grid on the canvas (keeping the final size at 960x576)
+ offset_x = (target_merge_size[0] - total_col) // 2
+ offset_y = (target_merge_size[1] - total_row) // 2
+
+        # place the images onto the merge canvas one by one
+        for idx, img_np in enumerate(valid_imgs):
+            row = idx // 2  # row 0/1/2
+            col = idx % 2  # column 0/1
+
+            # starting position of the current image (including black lines and the global offset)
+ x_start = offset_x + col * (resize_size[0] + line_width)
+ y_start = offset_y + row * (resize_size[1] + line_width)
+ x_end = x_start + resize_size[0]
+ y_end = y_start + resize_size[1]
+
+            # clamp to the canvas boundaries
+ x_end = min(x_end, target_merge_size[0])
+ y_end = min(y_end, target_merge_size[1])
+ x_start = max(x_start, 0)
+ y_start = max(y_start, 0)
+
+            # paste the image onto the canvas
+ merged_image[y_start:y_end, x_start:x_end, :] = img_np[:y_end-y_start, :x_end-x_start, :]
+
+        # Step 4: save the merged grid image
+ save_name = f'merged_{i//6}.png'
+ save_full_path = os.path.join(save_path, save_name)
+ Image.fromarray(merged_image).save(save_full_path)
+
+    print(f"All images processed. Results saved to: {save_path}")
\ No newline at end of file
diff --git a/apps/gradio/DiffSynth_Studio.py b/apps/gradio/DiffSynth_Studio.py
new file mode 100644
index 0000000000000000000000000000000000000000..d26549202db2540a3b411cc7ba21d4d643c9d21b
--- /dev/null
+++ b/apps/gradio/DiffSynth_Studio.py
@@ -0,0 +1,252 @@
+import gradio as gr
+from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
+import os, torch
+from PIL import Image
+import numpy as np
+
+
+config = {
+ "model_config": {
+ "Stable Diffusion": {
+ "model_folder": "models/stable_diffusion",
+ "pipeline_class": SDImagePipeline,
+ "default_parameters": {
+ "cfg_scale": 7.0,
+ "height": 512,
+ "width": 512,
+ }
+ },
+ "Stable Diffusion XL": {
+ "model_folder": "models/stable_diffusion_xl",
+ "pipeline_class": SDXLImagePipeline,
+ "default_parameters": {
+ "cfg_scale": 7.0,
+ }
+ },
+ "Stable Diffusion 3": {
+ "model_folder": "models/stable_diffusion_3",
+ "pipeline_class": SD3ImagePipeline,
+ "default_parameters": {
+ "cfg_scale": 7.0,
+ }
+ },
+ "Stable Diffusion XL Turbo": {
+ "model_folder": "models/stable_diffusion_xl_turbo",
+ "pipeline_class": SDXLImagePipeline,
+ "default_parameters": {
+ "negative_prompt": "",
+ "cfg_scale": 1.0,
+ "num_inference_steps": 1,
+ "height": 512,
+ "width": 512,
+ }
+ },
+ "Kolors": {
+ "model_folder": "models/kolors",
+ "pipeline_class": SDXLImagePipeline,
+ "default_parameters": {
+ "cfg_scale": 7.0,
+ }
+ },
+ "HunyuanDiT": {
+ "model_folder": "models/HunyuanDiT",
+ "pipeline_class": HunyuanDiTImagePipeline,
+ "default_parameters": {
+ "cfg_scale": 7.0,
+ }
+ },
+ "FLUX": {
+ "model_folder": "models/FLUX",
+ "pipeline_class": FluxImagePipeline,
+ "default_parameters": {
+ "cfg_scale": 1.0,
+ }
+ }
+ },
+ "max_num_painter_layers": 8,
+ "max_num_model_cache": 1,
+}
+
+
+def load_model_list(model_type):
+ if model_type is None:
+ return []
+ folder = config["model_config"][model_type]["model_folder"]
+ file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
+ if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
+ file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
+ file_list = sorted(file_list)
+ return file_list
+
+
+def load_model(model_type, model_path):
+ global model_dict
+ model_key = f"{model_type}:{model_path}"
+ if model_key in model_dict:
+ return model_dict[model_key]
+ model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
+ model_manager = ModelManager()
+ if model_type == "HunyuanDiT":
+ model_manager.load_models([
+ os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
+ os.path.join(model_path, "mt5/pytorch_model.bin"),
+ os.path.join(model_path, "model/pytorch_model_ema.pt"),
+ os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
+ ])
+ elif model_type == "Kolors":
+ model_manager.load_models([
+ os.path.join(model_path, "text_encoder"),
+ os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
+ os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
+ ])
+ elif model_type == "FLUX":
+ model_manager.torch_dtype = torch.bfloat16
+ file_list = [
+ os.path.join(model_path, "text_encoder/model.safetensors"),
+ os.path.join(model_path, "text_encoder_2"),
+ ]
+ for file_name in os.listdir(model_path):
+ if file_name.endswith(".safetensors"):
+ file_list.append(os.path.join(model_path, file_name))
+ model_manager.load_models(file_list)
+ else:
+ model_manager.load_model(model_path)
+ pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
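+    # Evict the oldest cached model (dict insertion order) until the new pipeline fits within
+    # max_num_model_cache, moving released weights to CPU and clearing the CUDA cache.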
+ while len(model_dict) + 1 > config["max_num_model_cache"]:
+ key = next(iter(model_dict.keys()))
+ model_manager_to_release, _ = model_dict[key]
+ model_manager_to_release.to("cpu")
+ del model_dict[key]
+ torch.cuda.empty_cache()
+ model_dict[model_key] = model_manager, pipe
+ return model_manager, pipe
+
+
+model_dict = {}
+
+with gr.Blocks() as app:
+ gr.Markdown("# DiffSynth-Studio Painter")
+ with gr.Row():
+ with gr.Column(scale=382, min_width=100):
+
+ with gr.Accordion(label="Model"):
+ model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
+ model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
+
+ @gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
+ def model_type_to_model_path(model_type):
+ return gr.Dropdown(choices=load_model_list(model_type))
+
+ with gr.Accordion(label="Prompt"):
+ prompt = gr.Textbox(label="Prompt", lines=3)
+ negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
+ cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
+ embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")
+
+ with gr.Accordion(label="Image"):
+ num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
+ height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
+ width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
+ with gr.Column():
+ use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
+ seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
+
+ @gr.on(
+ inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
+ outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
+ triggers=model_path.change
+ )
+ def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
+ load_model(model_type, model_path)
+ cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
+ embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
+ num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
+ height = config["model_config"][model_type]["default_parameters"].get("height", height)
+ width = config["model_config"][model_type]["default_parameters"].get("width", width)
+ return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
+
+
+ with gr.Column(scale=618, min_width=100):
+ with gr.Accordion(label="Painter"):
+ enable_local_prompt_list = []
+ local_prompt_list = []
+ mask_scale_list = []
+ canvas_list = []
+ for painter_layer_id in range(config["max_num_painter_layers"]):
+ with gr.Tab(label=f"Layer {painter_layer_id}"):
+ enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
+ local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
+ mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
+ canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
+ brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
+ label="Painter", key=f"canvas_{painter_layer_id}")
+ @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
+ def resize_canvas(height, width, canvas):
+ h, w = canvas["background"].shape[:2]
+ if h != height or width != w:
+ return np.ones((height, width, 3), dtype=np.uint8) * 255
+ else:
+ return canvas
+
+ enable_local_prompt_list.append(enable_local_prompt)
+ local_prompt_list.append(local_prompt)
+ mask_scale_list.append(mask_scale)
+ canvas_list.append(canvas)
+ with gr.Accordion(label="Results"):
+ run_button = gr.Button(value="Generate", variant="primary")
+ output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
+ with gr.Row():
+ with gr.Column():
+ output_to_painter_button = gr.Button(value="Set as painter's background")
+ with gr.Column():
+ output_to_input_button = gr.Button(value="Set as input image")
+ painter_background = gr.State(None)
+ input_background = gr.State(None)
+ @gr.on(
+ inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
+ outputs=[output_image],
+ triggers=run_button.click
+ )
+ def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
+ _, pipe = load_model(model_type, model_path)
+ input_params = {
+ "prompt": prompt,
+ "negative_prompt": negative_prompt,
+ "cfg_scale": cfg_scale,
+ "num_inference_steps": num_inference_steps,
+ "height": height,
+ "width": width,
+ "progress_bar_cmd": progress.tqdm,
+ }
+ if isinstance(pipe, FluxImagePipeline):
+ input_params["embedded_guidance"] = embedded_guidance
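+        # *args packs the painter widgets as four consecutive groups of max_num_painter_layers
+        # values each: enable flags, local prompts, mask scales, and canvas dicts.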
+ enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
+ args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
+ args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
+ args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
+ args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
+ )
+ local_prompts, masks, mask_scales = [], [], []
+ for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
+ enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
+ ):
+ if enable_local_prompt:
+ local_prompts.append(local_prompt)
+ masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
+ mask_scales.append(mask_scale)
+ input_params.update({
+ "local_prompts": local_prompts,
+ "masks": masks,
+ "mask_scales": mask_scales,
+ })
+ torch.manual_seed(seed)
+ image = pipe(**input_params)
+ return image
+
+ @gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
+ def send_output_to_painter_background(output_image, *canvas_list):
+ for canvas in canvas_list:
+ h, w = canvas["background"].shape[:2]
+ canvas["background"] = output_image.resize((w, h))
+ return tuple(canvas_list)
+app.launch()
diff --git a/apps/gradio/entity_level_control.py b/apps/gradio/entity_level_control.py
new file mode 100644
index 0000000000000000000000000000000000000000..58f4722d5a1b36b1746610bfb231727fc8b0a986
--- /dev/null
+++ b/apps/gradio/entity_level_control.py
@@ -0,0 +1,390 @@
+import os
+import torch
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+import random
+import json
+import gradio as gr
+from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
+from modelscope import dataset_snapshot_download
+
+
+dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/*")
+example_json = 'data/examples/eligen/entity_control/ui_examples.json'
+with open(example_json, 'r') as f:
+ examples = json.load(f)['examples']
+
+for idx in range(len(examples)):
+ example_id = examples[idx]['example_id']
+ entity_prompts = examples[idx]['local_prompt_list']
+ examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
+
+def create_canvas_data(background, masks):
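+    # Build the value dict expected by gr.ImageEditor: an RGBA background, one RGBA
+    # layer per mask (the mask goes into the alpha channel), and a composite preview.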
+ if background.shape[-1] == 3:
+ background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
+ layers = []
+ for mask in masks:
+ if mask is not None:
+ mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
+ layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
+ layer[..., -1] = mask_single_channel
+ layers.append(layer)
+ else:
+ layers.append(np.zeros_like(background))
+
+ composite = background.copy()
+ for layer in layers:
+ if layer.size > 0:
+ composite = np.where(layer[..., -1:] > 0, layer, composite)
+ return {
+ "background": background,
+ "layers": layers,
+ "composite": composite,
+ }
+
+def load_example(load_example_button):
+ example_idx = int(load_example_button.split()[-1]) - 1
+ example = examples[example_idx]
+ result = [
+ 50,
+ example["global_prompt"],
+ example["negative_prompt"],
+ example["seed"],
+ *example["local_prompt_list"],
+ ]
+ num_entities = len(example["local_prompt_list"])
+ result += [""] * (config["max_num_painter_layers"] - num_entities)
+ masks = []
+ for mask in example["mask_lists"]:
+ mask_single_channel = np.array(mask.convert("L"))
+ masks.append(mask_single_channel)
+ for _ in range(config["max_num_painter_layers"] - len(masks)):
+ blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
+ masks.append(blank_mask)
+ background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
+ canvas_data_list = []
+ for mask in masks:
+ canvas_data = create_canvas_data(background, [mask])
+ canvas_data_list.append(canvas_data)
+ result.extend(canvas_data_list)
+ return result
+
+def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
+ save_dir = os.path.join('workdirs/tmp_mask', random_dir)
+ print(f'save to {save_dir}')
+ os.makedirs(save_dir, exist_ok=True)
+ for i, mask in enumerate(masks):
+ save_path = os.path.join(save_dir, f'{i}.png')
+ mask.save(save_path)
+ sample = {
+ "global_prompt": global_prompt,
+ "mask_prompts": mask_prompts,
+ "seed": seed,
+ }
+ with open(os.path.join(save_dir, f"prompts.json"), 'w') as f:
+ json.dump(sample, f, indent=4)
+
+def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
+ # Create a blank image for overlays
+ overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
+ colors = [
+ (165, 238, 173, 80),
+ (76, 102, 221, 80),
+ (221, 160, 77, 80),
+ (204, 93, 71, 80),
+ (145, 187, 149, 80),
+ (134, 141, 172, 80),
+ (157, 137, 109, 80),
+ (153, 104, 95, 80),
+ (165, 238, 173, 80),
+ (76, 102, 221, 80),
+ (221, 160, 77, 80),
+ (204, 93, 71, 80),
+ (145, 187, 149, 80),
+ (134, 141, 172, 80),
+ (157, 137, 109, 80),
+ (153, 104, 95, 80),
+ ]
+ # Generate random colors for each mask
+ if use_random_colors:
+ colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
+ # Font settings
+ try:
+ font = ImageFont.truetype("arial", font_size) # Adjust as needed
+ except IOError:
+ font = ImageFont.load_default(font_size)
+ # Overlay each mask onto the overlay image
+ for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
+ if mask is None:
+ continue
+ # Convert mask to RGBA mode
+ mask_rgba = mask.convert('RGBA')
+ mask_data = mask_rgba.getdata()
+ new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
+ mask_rgba.putdata(new_data)
+ # Draw the mask prompt text on the mask
+ draw = ImageDraw.Draw(mask_rgba)
+ mask_bbox = mask.getbbox() # Get the bounding box of the mask
+ if mask_bbox is None:
+ continue
+ text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
+ draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
+ # Alpha composite the overlay with this mask
+ overlay = Image.alpha_composite(overlay, mask_rgba)
+ # Composite the overlay onto the original image
+ result = Image.alpha_composite(image.convert('RGBA'), overlay)
+ return result
+
+config = {
+ "model_config": {
+ "FLUX": {
+ "model_folder": "models/FLUX",
+ "pipeline_class": FluxImagePipeline,
+ "default_parameters": {
+ "cfg_scale": 3.0,
+ "embedded_guidance": 3.5,
+ "num_inference_steps": 30,
+ }
+ },
+ },
+ "max_num_painter_layers": 8,
+ "max_num_model_cache": 1,
+}
+
+model_dict = {}
+
+def load_model(model_type='FLUX', model_path='FLUX.1-dev'):
+ global model_dict
+ model_key = f"{model_type}:{model_path}"
+ if model_key in model_dict:
+ return model_dict[model_key]
+ model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
+ model_manager.load_lora(
+ download_customized_models(
+ model_id="DiffSynth-Studio/Eligen",
+ origin_file_path="model_bf16.safetensors",
+ local_dir="models/lora/entity_control",
+ ),
+ lora_alpha=1,
+ )
+ pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
+ model_dict[model_key] = model_manager, pipe
+ return model_manager, pipe
+
+
+with gr.Blocks() as app:
+ gr.Markdown(
+ """## EliGen: Entity-Level Controllable Text-to-Image Model
+ 1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
+ 2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
+ 3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
+ 4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
+ """
+ )
+
+ loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
+ main_interface = gr.Column(visible=False)
+
+ def initialize_model():
+ try:
+ load_model()
+ return {
+ loading_status: gr.update(value="Model loaded successfully!", visible=False),
+ main_interface: gr.update(visible=True),
+ }
+ except Exception as e:
+ print(f'Failed to load model with error: {e}')
+ return {
+ loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
+ main_interface: gr.update(visible=True),
+ }
+
+ app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
+
+ with main_interface:
+ with gr.Row():
+ local_prompt_list = []
+ canvas_list = []
+ random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
+ with gr.Column(scale=382, min_width=100):
+ model_type = gr.State('FLUX')
+ model_path = gr.State('FLUX.1-dev')
+ with gr.Accordion(label="Global prompt"):
+ prompt = gr.Textbox(label="Global Prompt", lines=3)
+ negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur,", lines=3)
+ with gr.Accordion(label="Inference Options", open=True):
+ seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
+ num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
+ cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=3.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
+ embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale")
+ height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
+ width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
+ with gr.Accordion(label="Inpaint Input Image", open=False):
+ input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
+ background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
+
+ with gr.Column():
+ reset_input_button = gr.Button(value="Reset Inpaint Input")
+ send_input_to_painter = gr.Button(value="Set as painter's background")
+ @gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
+ def reset_input_image(input_image):
+ return None
+
+ with gr.Column(scale=618, min_width=100):
+ with gr.Accordion(label="Entity Painter"):
+ for painter_layer_id in range(config["max_num_painter_layers"]):
+ with gr.Tab(label=f"Entity {painter_layer_id}"):
+ local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
+ canvas = gr.ImageEditor(
+ canvas_size=(512, 512),
+ sources=None,
+ layers=False,
+ interactive=True,
+ image_mode="RGBA",
+ brush=gr.Brush(
+ default_size=50,
+ default_color="#000000",
+ colors=["#000000"],
+ ),
+ label="Entity Mask Painter",
+ key=f"canvas_{painter_layer_id}",
+ width=width,
+ height=height,
+ )
+ @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
+ def resize_canvas(height, width, canvas):
+ h, w = canvas["background"].shape[:2]
+ if h != height or width != w:
+ return np.ones((height, width, 3), dtype=np.uint8) * 255
+ else:
+ return canvas
+ local_prompt_list.append(local_prompt)
+ canvas_list.append(canvas)
+ with gr.Accordion(label="Results"):
+ run_button = gr.Button(value="Generate", variant="primary")
+ output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
+ with gr.Row():
+ with gr.Column():
+ output_to_painter_button = gr.Button(value="Set as painter's background")
+ with gr.Column():
+ return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
+ output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
+ real_output = gr.State(None)
+ mask_out = gr.State(None)
+
+ @gr.on(
+ inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
+ outputs=[output_image, real_output, mask_out],
+ triggers=run_button.click
+ )
+ def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
+ _, pipe = load_model(model_type, model_path)
+ input_params = {
+ "prompt": prompt,
+ "negative_prompt": negative_prompt,
+ "cfg_scale": cfg_scale,
+ "num_inference_steps": num_inference_steps,
+ "height": height,
+ "width": width,
+ "progress_bar_cmd": progress.tqdm,
+ }
+ if isinstance(pipe, FluxImagePipeline):
+ input_params["embedded_guidance"] = embedded_guidance
+ if input_image is not None:
+ input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
+ input_params["enable_eligen_inpaint"] = True
+
+ local_prompt_list, canvas_list = (
+ args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
+ args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
+ )
+ local_prompts, masks = [], []
+ for local_prompt, canvas in zip(local_prompt_list, canvas_list):
+ if isinstance(local_prompt, str) and len(local_prompt) > 0:
+ local_prompts.append(local_prompt)
+ masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
+ entity_masks = None if len(masks) == 0 else masks
+ entity_prompts = None if len(local_prompts) == 0 else local_prompts
+ input_params.update({
+ "eligen_entity_prompts": entity_prompts,
+ "eligen_entity_masks": entity_masks,
+ })
+ torch.manual_seed(seed)
+ # save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
+ image = pipe(**input_params)
+ masks = [mask.resize(image.size) for mask in masks]
+ image_with_mask = visualize_masks(image, masks, local_prompts)
+
+ real_output = gr.State(image)
+ mask_out = gr.State(image_with_mask)
+
+ if return_with_mask:
+ return image_with_mask, real_output, mask_out
+ return image, real_output, mask_out
+
+ @gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
+ def send_input_to_painter_background(input_image, *canvas_list):
+ if input_image is None:
+ return tuple(canvas_list)
+ for canvas in canvas_list:
+ h, w = canvas["background"].shape[:2]
+ canvas["background"] = input_image.resize((w, h))
+ return tuple(canvas_list)
+ @gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
+ def send_output_to_painter_background(real_output, *canvas_list):
+ if real_output is None:
+ return tuple(canvas_list)
+ for canvas in canvas_list:
+ h, w = canvas["background"].shape[:2]
+ canvas["background"] = real_output.value.resize((w, h))
+ return tuple(canvas_list)
+ @gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
+ def show_output(return_with_mask, real_output, mask_out):
+ if return_with_mask:
+ return mask_out.value
+ else:
+ return real_output.value
+ @gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
+ def send_output_to_pipe_input(real_output):
+ return real_output.value
+
+ with gr.Column():
+ gr.Markdown("## Examples")
+ for i in range(0, len(examples), 2):
+ with gr.Row():
+ if i < len(examples):
+ example = examples[i]
+ with gr.Column():
+ example_image = gr.Image(
+ value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
+ label=example["description"],
+ interactive=False,
+ width=1024,
+ height=512
+ )
+ load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
+ load_example_button.click(
+ load_example,
+ inputs=[load_example_button],
+ outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
+ )
+
+ if i + 1 < len(examples):
+ example = examples[i + 1]
+ with gr.Column():
+ example_image = gr.Image(
+ value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
+ label=example["description"],
+ interactive=False,
+ width=1024,
+ height=512
+ )
+ load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
+ load_example_button.click(
+ load_example,
+ inputs=[load_example_button],
+ outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
+ )
+app.config["show_progress"] = "hidden"
+app.launch()
diff --git a/apps/gradio/qwen_image_eligen.py b/apps/gradio/qwen_image_eligen.py
new file mode 100644
index 0000000000000000000000000000000000000000..c224f01cad5e7fd41bb8102728a899a8ecb3b90f
--- /dev/null
+++ b/apps/gradio/qwen_image_eligen.py
@@ -0,0 +1,382 @@
+import os
+import torch
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+import random
+import json
+import gradio as gr
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+from modelscope import dataset_snapshot_download, snapshot_download
+
+# pip install pydantic==2.10.6
+# pip install gradio==5.4.0
+
+snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen", allow_file_pattern="model.safetensors")
+
+dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/*")
+example_json = 'data/examples/eligen/qwen-image/ui_examples.json'
+with open(example_json, 'r') as f:
+ examples = json.load(f)['examples']
+
+for idx in range(len(examples)):
+ example_id = examples[idx]['example_id']
+ entity_prompts = examples[idx]['local_prompt_list']
+ examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
+
+def create_canvas_data(background, masks):
+ if background.shape[-1] == 3:
+ background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
+ layers = []
+ for mask in masks:
+ if mask is not None:
+ mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
+ layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
+ layer[..., -1] = mask_single_channel
+ layers.append(layer)
+ else:
+ layers.append(np.zeros_like(background))
+
+ composite = background.copy()
+ for layer in layers:
+ if layer.size > 0:
+ composite = np.where(layer[..., -1:] > 0, layer, composite)
+ return {
+ "background": background,
+ "layers": layers,
+ "composite": composite,
+ }
+
+def load_example(load_example_button):
+ example_idx = int(load_example_button.split()[-1]) - 1
+ example = examples[example_idx]
+ result = [
+ 50,
+ example["global_prompt"],
+ example["negative_prompt"],
+ example["seed"],
+ *example["local_prompt_list"],
+ ]
+ num_entities = len(example["local_prompt_list"])
+ result += [""] * (config["max_num_painter_layers"] - num_entities)
+ masks = []
+ for mask in example["mask_lists"]:
+ mask_single_channel = np.array(mask.convert("L"))
+ masks.append(mask_single_channel)
+ for _ in range(config["max_num_painter_layers"] - len(masks)):
+ blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
+ masks.append(blank_mask)
+ background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
+ canvas_data_list = []
+ for mask in masks:
+ canvas_data = create_canvas_data(background, [mask])
+ canvas_data_list.append(canvas_data)
+ result.extend(canvas_data_list)
+ return result
+
+def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
+ save_dir = os.path.join('workdirs/tmp_mask', random_dir)
+ print(f'save to {save_dir}')
+ os.makedirs(save_dir, exist_ok=True)
+ for i, mask in enumerate(masks):
+ save_path = os.path.join(save_dir, f'{i}.png')
+ mask.save(save_path)
+ sample = {
+ "global_prompt": global_prompt,
+ "mask_prompts": mask_prompts,
+ "seed": seed,
+ }
+ with open(os.path.join(save_dir, f"prompts.json"), 'w', encoding='utf-8') as f:
+ json.dump(sample, f, ensure_ascii=False, indent=4)
+
+def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
+ # Create a blank image for overlays
+ overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
+ colors = [
+ (165, 238, 173, 80),
+ (76, 102, 221, 80),
+ (221, 160, 77, 80),
+ (204, 93, 71, 80),
+ (145, 187, 149, 80),
+ (134, 141, 172, 80),
+ (157, 137, 109, 80),
+ (153, 104, 95, 80),
+ (165, 238, 173, 80),
+ (76, 102, 221, 80),
+ (221, 160, 77, 80),
+ (204, 93, 71, 80),
+ (145, 187, 149, 80),
+ (134, 141, 172, 80),
+ (157, 137, 109, 80),
+ (153, 104, 95, 80),
+ ]
+ # Generate random colors for each mask
+ if use_random_colors:
+ colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
+ # Font settings
+ try:
+ font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
+ except IOError:
+ font = ImageFont.load_default(font_size)
+ # Overlay each mask onto the overlay image
+ for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
+ if mask is None:
+ continue
+ # Convert mask to RGBA mode
+ mask_rgba = mask.convert('RGBA')
+ mask_data = mask_rgba.getdata()
+ new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
+ mask_rgba.putdata(new_data)
+ # Draw the mask prompt text on the mask
+ draw = ImageDraw.Draw(mask_rgba)
+ mask_bbox = mask.getbbox() # Get the bounding box of the mask
+ if mask_bbox is None:
+ continue
+ text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
+ draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
+ # Alpha composite the overlay with this mask
+ overlay = Image.alpha_composite(overlay, mask_rgba)
+ # Composite the overlay onto the original image
+ result = Image.alpha_composite(image.convert('RGBA'), overlay)
+ return result
+
+config = {
+ "max_num_painter_layers": 8,
+ "max_num_model_cache": 1,
+}
+
+model_dict = {}
+
+def load_model(model_type='qwen-image'):
+ global model_dict
+ model_key = f"{model_type}"
+ if model_key in model_dict:
+ return model_dict[model_key]
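+    # Build the Qwen-Image pipeline from the released checkpoints and attach the
+    # EliGen entity-control LoRA downloaded above.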
+ pipe = QwenImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
+ )
+ pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors")
+ model_dict[model_key] = pipe
+ return pipe
+
+load_model('qwen-image')
+
+with gr.Blocks() as app:
+ gr.Markdown(
+ """## EliGen: Entity-Level Controllable Text-to-Image Model
+ 1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
+ 2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
+ 3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
+ 4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
+ """
+ )
+
+ loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
+ main_interface = gr.Column(visible=False)
+
+ def initialize_model():
+ try:
+ load_model('qwen-image')
+ return {
+ loading_status: gr.update(value="Model loaded successfully!", visible=False),
+ main_interface: gr.update(visible=True),
+ }
+ except Exception as e:
+ print(f'Failed to load model with error: {e}')
+ return {
+ loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
+ main_interface: gr.update(visible=True),
+ }
+
+ app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
+
+ with main_interface:
+ with gr.Row():
+ local_prompt_list = []
+ canvas_list = []
+ random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
+ with gr.Column(scale=382, min_width=100):
+ model_type = gr.State('qwen-image')
+ with gr.Accordion(label="Global prompt"):
+ prompt = gr.Textbox(label="Global Prompt", lines=3)
+ negative_prompt = gr.Textbox(label="Negative prompt", value="", lines=3)
+ with gr.Accordion(label="Inference Options", open=True):
+ seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
+ num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
+ cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=4.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
+ height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
+ width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
+ with gr.Accordion(label="Inpaint Input Image", open=False, visible=False):
+ input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
+ background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
+
+ with gr.Column():
+ reset_input_button = gr.Button(value="Reset Inpaint Input")
+ send_input_to_painter = gr.Button(value="Set as painter's background")
+ @gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
+ def reset_input_image(input_image):
+ return None
+
+ with gr.Column(scale=618, min_width=100):
+ with gr.Accordion(label="Entity Painter"):
+ for painter_layer_id in range(config["max_num_painter_layers"]):
+ with gr.Tab(label=f"Entity {painter_layer_id}"):
+ local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
+ canvas = gr.ImageEditor(
+ canvas_size=(1024, 1024),
+ sources=None,
+ layers=False,
+ interactive=True,
+ image_mode="RGBA",
+ brush=gr.Brush(
+ default_size=50,
+ default_color="#000000",
+ colors=["#000000"],
+ ),
+ label="Entity Mask Painter",
+ key=f"canvas_{painter_layer_id}",
+ width=width,
+ height=height,
+ )
+ @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
+ def resize_canvas(height, width, canvas):
+ if canvas is None or canvas["background"] is None:
+ return np.ones((height, width, 3), dtype=np.uint8) * 255
+ h, w = canvas["background"].shape[:2]
+ if h != height or width != w:
+ return np.ones((height, width, 3), dtype=np.uint8) * 255
+ else:
+ return canvas
+ local_prompt_list.append(local_prompt)
+ canvas_list.append(canvas)
+ with gr.Accordion(label="Results"):
+ run_button = gr.Button(value="Generate", variant="primary")
+ output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
+ with gr.Row():
+ with gr.Column():
+ output_to_painter_button = gr.Button(value="Set as painter's background")
+ with gr.Column():
+ return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
+ output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
+ real_output = gr.State(None)
+ mask_out = gr.State(None)
+
+ @gr.on(
+ inputs=[model_type, prompt, negative_prompt, cfg_scale, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
+ outputs=[output_image, real_output, mask_out],
+ triggers=run_button.click
+ )
+ def generate_image(model_type, prompt, negative_prompt, cfg_scale, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
+ pipe = load_model(model_type)
+ input_params = {
+ "prompt": prompt,
+ "negative_prompt": negative_prompt,
+ "cfg_scale": cfg_scale,
+ "num_inference_steps": num_inference_steps,
+ "height": height,
+ "width": width,
+ "progress_bar_cmd": progress.tqdm,
+ }
+ # if input_image is not None:
+ # input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
+ # input_params["enable_eligen_inpaint"] = True
+
+ local_prompt_list, canvas_list = (
+ args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
+ args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
+ )
+ local_prompts, masks = [], []
+ for local_prompt, canvas in zip(local_prompt_list, canvas_list):
+ if isinstance(local_prompt, str) and len(local_prompt) > 0:
+ local_prompts.append(local_prompt)
+ masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
+ entity_prompts = None if len(local_prompts) == 0 else local_prompts
+ entity_masks = None if len(masks) == 0 or entity_prompts is None else masks
+ input_params.update({
+ "eligen_entity_prompts": entity_prompts,
+ "eligen_entity_masks": entity_masks,
+ })
+ torch.manual_seed(seed)
+ save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
+ image = pipe(**input_params)
+ masks = [mask.resize(image.size) for mask in masks]
+ image_with_mask = visualize_masks(image, masks, local_prompts)
+
+ real_output = gr.State(image)
+ mask_out = gr.State(image_with_mask)
+
+ if return_with_mask:
+ return image_with_mask, real_output, mask_out
+ return image, real_output, mask_out
+
+ @gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
+ def send_input_to_painter_background(input_image, *canvas_list):
+ if input_image is None:
+ return tuple(canvas_list)
+ for canvas in canvas_list:
+ h, w = canvas["background"].shape[:2]
+ canvas["background"] = input_image.resize((w, h))
+ return tuple(canvas_list)
+ @gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
+ def send_output_to_painter_background(real_output, *canvas_list):
+ if real_output is None:
+ return tuple(canvas_list)
+ for canvas in canvas_list:
+ h, w = canvas["background"].shape[:2]
+ canvas["background"] = real_output.value.resize((w, h))
+ return tuple(canvas_list)
+ @gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
+ def show_output(return_with_mask, real_output, mask_out):
+ if return_with_mask:
+ return mask_out.value
+ else:
+ return real_output.value
+ @gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
+ def send_output_to_pipe_input(real_output):
+ return real_output.value
+
+ with gr.Column():
+ gr.Markdown("## Examples")
+ for i in range(0, len(examples), 2):
+ with gr.Row():
+ if i < len(examples):
+ example = examples[i]
+ with gr.Column():
+ example_image = gr.Image(
+ value=f"data/examples/eligen/qwen-image/example_{example['example_id']}/example_image.png",
+ label=example["description"],
+ interactive=False,
+ width=1024,
+ height=512
+ )
+ load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
+ load_example_button.click(
+ load_example,
+ inputs=[load_example_button],
+ outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
+ )
+
+ if i + 1 < len(examples):
+ example = examples[i + 1]
+ with gr.Column():
+ example_image = gr.Image(
+ value=f"data/examples/eligen/qwen-image/example_{example['example_id']}/example_image.png",
+ label=example["description"],
+ interactive=False,
+ width=1024,
+ height=512
+ )
+ load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
+ load_example_button.click(
+ load_example,
+ inputs=[load_example_button],
+ outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
+ )
+app.config["show_progress"] = "hidden"
+app.launch(share=False)
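
The painter UI above relies on the decorator form of gr.on to bind a single handler to several component events at once. For reference, a minimal self-contained sketch of that wiring, assuming Gradio 4.x; the component names here are illustrative and not taken from the app:

import gradio as gr

with gr.Blocks() as demo:
    # Illustrative components; any input/output components are wired the same way.
    height = gr.Slider(256, 2048, value=1024, step=64, label="Height")
    width = gr.Slider(256, 2048, value=1024, step=64, label="Width")
    info = gr.Textbox(label="Resolution")

    # One handler, several triggers: fires whenever either slider changes.
    @gr.on(inputs=[height, width], outputs=info, triggers=[height.change, width.change])
    def report(height, width):
        return f"{width} x {height}"

demo.launch()
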
diff --git a/apps/streamlit/DiffSynth_Studio.py b/apps/streamlit/DiffSynth_Studio.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfd38565da745df53111a94b77d4354a2d475042
--- /dev/null
+++ b/apps/streamlit/DiffSynth_Studio.py
@@ -0,0 +1,15 @@
+# Set web page format
+import streamlit as st
+st.set_page_config(layout="wide")
+# Disable virtual VRAM on Windows systems
+import torch
+torch.cuda.set_per_process_memory_fraction(0.999, 0)
+
+
+st.markdown("""
+# DiffSynth Studio
+
+[Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
+
+Welcome to DiffSynth Studio.
+""")
diff --git a/apps/streamlit/pages/1_Image_Creator.py b/apps/streamlit/pages/1_Image_Creator.py
new file mode 100644
index 0000000000000000000000000000000000000000..732d2195ce45061739c8cbe46623a181f6a30c08
--- /dev/null
+++ b/apps/streamlit/pages/1_Image_Creator.py
@@ -0,0 +1,362 @@
+import torch, os, io, json, time
+import numpy as np
+from PIL import Image
+import streamlit as st
+st.set_page_config(layout="wide")
+from streamlit_drawable_canvas import st_canvas
+from diffsynth.models import ModelManager
+from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
+from diffsynth.data.video import crop_and_resize
+
+
+config = {
+ "Stable Diffusion": {
+ "model_folder": "models/stable_diffusion",
+ "pipeline_class": SDImagePipeline,
+ "fixed_parameters": {}
+ },
+ "Stable Diffusion XL": {
+ "model_folder": "models/stable_diffusion_xl",
+ "pipeline_class": SDXLImagePipeline,
+ "fixed_parameters": {}
+ },
+ "Stable Diffusion 3": {
+ "model_folder": "models/stable_diffusion_3",
+ "pipeline_class": SD3ImagePipeline,
+ "fixed_parameters": {}
+ },
+ "Stable Diffusion XL Turbo": {
+ "model_folder": "models/stable_diffusion_xl_turbo",
+ "pipeline_class": SDXLImagePipeline,
+ "fixed_parameters": {
+ "negative_prompt": "",
+ "cfg_scale": 1.0,
+ "num_inference_steps": 1,
+ "height": 512,
+ "width": 512,
+ }
+ },
+ "Kolors": {
+ "model_folder": "models/kolors",
+ "pipeline_class": SDXLImagePipeline,
+ "fixed_parameters": {}
+ },
+ "HunyuanDiT": {
+ "model_folder": "models/HunyuanDiT",
+ "pipeline_class": HunyuanDiTImagePipeline,
+ "fixed_parameters": {
+ "height": 1024,
+ "width": 1024,
+ }
+ },
+ "FLUX": {
+ "model_folder": "models/FLUX",
+ "pipeline_class": FluxImagePipeline,
+ "fixed_parameters": {
+ "cfg_scale": 1.0,
+ }
+ }
+}
+
+
+def load_model_list(model_type):
+ folder = config[model_type]["model_folder"]
+ file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
+ if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
+ file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
+ file_list = sorted(file_list)
+ return file_list
+
+
+def release_model():
+ if "model_manager" in st.session_state:
+ st.session_state["model_manager"].to("cpu")
+ del st.session_state["loaded_model_path"]
+ del st.session_state["model_manager"]
+ del st.session_state["pipeline"]
+ torch.cuda.empty_cache()
+
+
+def load_model(model_type, model_path):
+ model_manager = ModelManager()
+ if model_type == "HunyuanDiT":
+ model_manager.load_models([
+ os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
+ os.path.join(model_path, "mt5/pytorch_model.bin"),
+ os.path.join(model_path, "model/pytorch_model_ema.pt"),
+ os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
+ ])
+ elif model_type == "Kolors":
+ model_manager.load_models([
+ os.path.join(model_path, "text_encoder"),
+ os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
+ os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
+ ])
+ elif model_type == "FLUX":
+ model_manager.torch_dtype = torch.bfloat16
+ file_list = [
+ os.path.join(model_path, "text_encoder/model.safetensors"),
+ os.path.join(model_path, "text_encoder_2"),
+ ]
+ for file_name in os.listdir(model_path):
+ if file_name.endswith(".safetensors"):
+ file_list.append(os.path.join(model_path, file_name))
+ model_manager.load_models(file_list)
+ else:
+ model_manager.load_model(model_path)
+ pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
+ st.session_state.loaded_model_path = model_path
+ st.session_state.model_manager = model_manager
+ st.session_state.pipeline = pipeline
+ return model_manager, pipeline
+
+
+def use_output_image_as_input(update=True):
+ # Search for input image
+ output_image_id = 0
+ selected_output_image = None
+ while True:
+ if f"use_output_as_input_{output_image_id}" not in st.session_state:
+ break
+ if st.session_state[f"use_output_as_input_{output_image_id}"]:
+ selected_output_image = st.session_state["output_images"][output_image_id]
+ break
+ output_image_id += 1
+ if update and selected_output_image is not None:
+ st.session_state["input_image"] = selected_output_image
+ return selected_output_image is not None
+
+
+def apply_stroke_to_image(stroke_image, image):
+ image = np.array(image.convert("RGB")).astype(np.float32)
+ height, width, _ = image.shape
+
+ stroke_image = np.array(Image.fromarray(stroke_image).resize((width, height))).astype(np.float32)
+ weight = stroke_image[:, :, -1:] / 255
+ stroke_image = stroke_image[:, :, :-1]
+
+ image = stroke_image * weight + image * (1 - weight)
+ image = np.clip(image, 0, 255).astype(np.uint8)
+ image = Image.fromarray(image)
+ return image
+
+
+@st.cache_data
+def image2bits(image):
+ image_byte = io.BytesIO()
+ image.save(image_byte, format="PNG")
+ image_byte = image_byte.getvalue()
+ return image_byte
+
+
+def show_output_image(image):
+ st.image(image, use_column_width="always")
+ st.button("Use it as input image", key=f"use_output_as_input_{image_id}")
+ st.download_button("Download", data=image2bits(image), file_name="image.png", mime="image/png", key=f"download_output_{image_id}")
+
+
+column_input, column_output = st.columns(2)
+with st.sidebar:
+ # Select a model
+ with st.expander("Model", expanded=True):
+ model_type = st.selectbox("Model type", [model_type_ for model_type_ in config])
+ fixed_parameters = config[model_type]["fixed_parameters"]
+ model_path_list = ["None"] + load_model_list(model_type)
+ model_path = st.selectbox("Model path", model_path_list)
+
+ # Load the model
+ if model_path == "None":
+ # No models are selected. Release VRAM.
+ st.markdown("No models are selected.")
+ release_model()
+ else:
+ # A model is selected.
+ model_path = os.path.join(config[model_type]["model_folder"], model_path)
+ if st.session_state.get("loaded_model_path", "") != model_path:
+ # The loaded model is not the selected model. Reload it.
+ st.markdown(f"Loading model at {model_path}.")
+ st.markdown("Please wait a moment...")
+ release_model()
+ model_manager, pipeline = load_model(model_type, model_path)
+ st.markdown("Done.")
+ else:
+            # The loaded model is already the selected model. Fetch it from `st.session_state`.
+ st.markdown(f"Loading model at {model_path}.")
+ st.markdown("Please wait a moment...")
+ model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
+ st.markdown("Done.")
+
+ # Show parameters
+ with st.expander("Prompt", expanded=True):
+ prompt = st.text_area("Positive prompt")
+ if "negative_prompt" in fixed_parameters:
+ negative_prompt = fixed_parameters["negative_prompt"]
+ else:
+ negative_prompt = st.text_area("Negative prompt")
+ if "cfg_scale" in fixed_parameters:
+ cfg_scale = fixed_parameters["cfg_scale"]
+ else:
+ cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.5)
+ with st.expander("Image", expanded=True):
+ if "num_inference_steps" in fixed_parameters:
+ num_inference_steps = fixed_parameters["num_inference_steps"]
+ else:
+ num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20)
+ if "height" in fixed_parameters:
+ height = fixed_parameters["height"]
+ else:
+ height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
+ if "width" in fixed_parameters:
+ width = fixed_parameters["width"]
+ else:
+ width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
+ num_images = st.number_input("Number of images", value=2)
+ use_fixed_seed = st.checkbox("Use fixed seed", value=False)
+ if use_fixed_seed:
+ seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
+
+ # Other fixed parameters
+ denoising_strength = 1.0
+ repetition = 1
+
+
+# Show input image
+with column_input:
+ with st.expander("Input image (Optional)", expanded=True):
+ with st.container(border=True):
+ column_white_board, column_upload_image = st.columns([1, 2])
+ with column_white_board:
+ create_white_board = st.button("Create white board")
+ delete_input_image = st.button("Delete input image")
+ with column_upload_image:
+ upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image")
+
+ if upload_image is not None:
+ st.session_state["input_image"] = crop_and_resize(Image.open(upload_image), height, width)
+ elif create_white_board:
+ st.session_state["input_image"] = Image.fromarray(np.ones((height, width, 3), dtype=np.uint8) * 255)
+ else:
+ use_output_image_as_input()
+
+ if delete_input_image and "input_image" in st.session_state:
+ del st.session_state.input_image
+ if delete_input_image and "upload_image" in st.session_state:
+ del st.session_state.upload_image
+
+ input_image = st.session_state.get("input_image", None)
+ if input_image is not None:
+ with st.container(border=True):
+ column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1])
+ with column_drawing_mode:
+ drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1)
+ with column_color_1:
+ stroke_color = st.color_picker("Stroke color")
+ with column_color_2:
+ fill_color = st.color_picker("Fill color")
+ stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
+ with st.container(border=True):
+ denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
+ repetition = st.slider("Repetition", min_value=1, max_value=8, value=1)
+ with st.container(border=True):
+ input_width, input_height = input_image.size
+ canvas_result = st_canvas(
+ fill_color=fill_color,
+ stroke_width=stroke_width,
+ stroke_color=stroke_color,
+ background_color="rgba(255, 255, 255, 0)",
+ background_image=input_image,
+ update_streamlit=True,
+ height=int(512 / input_width * input_height),
+ width=512,
+ drawing_mode=drawing_mode,
+ key="canvas"
+ )
+
+ num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0)
+ local_prompts, masks, mask_scales = [], [], []
+ white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
+ painter_layers_json_data = []
+ for painter_tab_id in range(num_painter_layer):
+ with st.expander(f"Painter layer {painter_tab_id}", expanded=True):
+ enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True)
+ local_prompt = st.text_area(f"Prompt {painter_tab_id}")
+ mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0)
+ stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100)
+ canvas_result_local = st_canvas(
+ fill_color="#000000",
+ stroke_width=stroke_width,
+ stroke_color="#000000",
+ background_color="rgba(255, 255, 255, 0)",
+ background_image=white_board,
+ update_streamlit=True,
+ height=512,
+ width=512,
+ drawing_mode="freedraw",
+ key=f"canvas_{painter_tab_id}"
+ )
+ if canvas_result_local.json_data is not None:
+ painter_layers_json_data.append(canvas_result_local.json_data.copy())
+ painter_layers_json_data[-1]["prompt"] = local_prompt
+ if enable_local_prompt:
+ local_prompts.append(local_prompt)
+ if canvas_result_local.image_data is not None:
+ mask = apply_stroke_to_image(canvas_result_local.image_data, white_board)
+ else:
+ mask = white_board
+ mask = Image.fromarray(255 - np.array(mask))
+ masks.append(mask)
+ mask_scales.append(mask_scale)
+ save_painter_layers = st.button("Save painter layers")
+ if save_painter_layers:
+ os.makedirs("data/painter_layers", exist_ok=True)
+ json_file_path = f"data/painter_layers/{time.time_ns()}.json"
+ with open(json_file_path, "w") as f:
+ json.dump(painter_layers_json_data, f, indent=4)
+ st.markdown(f"Painter layers are saved in {json_file_path}.")
+
+
+with column_output:
+ run_button = st.button("Generate image", type="primary")
+ auto_update = st.checkbox("Auto update", value=False)
+ num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2)
+ image_columns = st.columns(num_image_columns)
+
+ # Run
+ if (run_button or auto_update) and model_path != "None":
+
+ if input_image is not None:
+ input_image = input_image.resize((width, height))
+ if canvas_result.image_data is not None:
+ input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
+
+ output_images = []
+ for image_id in range(num_images * repetition):
+ if use_fixed_seed:
+ torch.manual_seed(seed + image_id)
+ else:
+ torch.manual_seed(np.random.randint(0, 10**9))
+ if image_id >= num_images:
+ input_image = output_images[image_id - num_images]
+ with image_columns[image_id % num_image_columns]:
+ progress_bar_st = st.progress(0.0)
+ image = pipeline(
+ prompt, negative_prompt=negative_prompt,
+ local_prompts=local_prompts, masks=masks, mask_scales=mask_scales,
+ cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
+ height=height, width=width,
+ input_image=input_image, denoising_strength=denoising_strength,
+ progress_bar_st=progress_bar_st
+ )
+ output_images.append(image)
+ progress_bar_st.progress(1.0)
+ show_output_image(image)
+ st.session_state["output_images"] = output_images
+
+ elif "output_images" in st.session_state:
+ for image_id in range(len(st.session_state.output_images)):
+ with image_columns[image_id % num_image_columns]:
+ image = st.session_state.output_images[image_id]
+ progress_bar = st.progress(1.0)
+ show_output_image(image)
+ if "upload_image" in st.session_state and use_output_image_as_input(update=False):
+ st.markdown("If you want to use an output image as input image, please delete the uploaded image manually.")
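
apply_stroke_to_image above composites the RGBA stroke layer returned by st_canvas over the RGB input, using the stroke's alpha channel as a per-pixel weight. A minimal sketch of that compositing step in isolation (NumPy only; the array shapes are assumptions that match what st_canvas typically returns):

import numpy as np

def composite_stroke(stroke_rgba: np.ndarray, base_rgb: np.ndarray) -> np.ndarray:
    # stroke_rgba: (H, W, 4) uint8 drawing layer; base_rgb: (H, W, 3) uint8 background.
    alpha = stroke_rgba[:, :, 3:4].astype(np.float32) / 255.0  # stroke opacity in [0, 1]
    stroke = stroke_rgba[:, :, :3].astype(np.float32)
    blended = stroke * alpha + base_rgb.astype(np.float32) * (1.0 - alpha)
    return np.clip(blended, 0, 255).astype(np.uint8)
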
diff --git a/apps/streamlit/pages/2_Video_Creator.py b/apps/streamlit/pages/2_Video_Creator.py
new file mode 100644
index 0000000000000000000000000000000000000000..87480726d3d32f48a2648266ce84d14e81e42315
--- /dev/null
+++ b/apps/streamlit/pages/2_Video_Creator.py
@@ -0,0 +1,197 @@
+import streamlit as st
+st.set_page_config(layout="wide")
+from diffsynth import SDVideoPipelineRunner
+import os
+import numpy as np
+
+
+def load_model_list(folder):
+ file_list = os.listdir(folder)
+ file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
+ file_list = sorted(file_list)
+ return file_list
+
+
+def match_processor_id(model_name, supported_processor_id_list):
+ sorted_processor_id = [i[1] for i in sorted([(-len(i), i) for i in supported_processor_id_list])]
+ for processor_id in sorted_processor_id:
+ if processor_id in model_name:
+ return supported_processor_id_list.index(processor_id) + 1
+ return 0
+
+
+config = {
+ "models": {
+ "model_list": [],
+ "textual_inversion_folder": "models/textual_inversion",
+ "device": "cuda",
+ "lora_alphas": [],
+ "controlnet_units": []
+ },
+ "data": {
+ "input_frames": None,
+ "controlnet_frames": [],
+ "output_folder": "output",
+ "fps": 60
+ },
+ "pipeline": {
+ "seed": 0,
+ "pipeline_inputs": {}
+ }
+}
+
+
+with st.expander("Model", expanded=True):
+ stable_diffusion_ckpt = st.selectbox("Stable Diffusion", ["None"] + load_model_list("models/stable_diffusion"))
+ if stable_diffusion_ckpt != "None":
+ config["models"]["model_list"].append(os.path.join("models/stable_diffusion", stable_diffusion_ckpt))
+ animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
+ if animatediff_ckpt != "None":
+ config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
+ column_lora, column_lora_alpha = st.columns([2, 1])
+ with column_lora:
+ sd_lora_ckpt = st.selectbox("LoRA", ["None"] + load_model_list("models/lora"))
+ with column_lora_alpha:
+ lora_alpha = st.slider("LoRA Alpha", min_value=-4.0, max_value=4.0, value=1.0, step=0.1)
+ if sd_lora_ckpt != "None":
+ config["models"]["model_list"].append(os.path.join("models/lora", sd_lora_ckpt))
+ config["models"]["lora_alphas"].append(lora_alpha)
+
+
+with st.expander("Data", expanded=True):
+ with st.container(border=True):
+ input_video = st.text_input("Input Video File Path (e.g., data/your_video.mp4)", value="")
+ column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
+ with column_height:
+ height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
+ with column_width:
+ width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
+ with column_start_frame_index:
+ start_frame_id = st.number_input("Start Frame id", value=0)
+ with column_end_frame_index:
+ end_frame_id = st.number_input("End Frame id", value=16)
+ if input_video != "":
+ config["data"]["input_frames"] = {
+ "video_file": input_video,
+ "image_folder": None,
+ "height": height,
+ "width": width,
+ "start_frame_id": start_frame_id,
+ "end_frame_id": end_frame_id
+ }
+ with st.container(border=True):
+ output_video = st.text_input("Output Video File Path (e.g., data/a_folder_to_save_something)", value="output")
+ fps = st.number_input("FPS", value=60)
+ config["data"]["output_folder"] = output_video
+ config["data"]["fps"] = fps
+
+
+with st.expander("ControlNet Units", expanded=True):
+ supported_processor_id_list = ["canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"]
+ controlnet_units = st.tabs(["ControlNet Unit 0", "ControlNet Unit 1", "ControlNet Unit 2"])
+ for controlnet_id in range(len(controlnet_units)):
+ with controlnet_units[controlnet_id]:
+ controlnet_ckpt = st.selectbox("ControlNet", ["None"] + load_model_list("models/ControlNet"),
+ key=f"controlnet_ckpt_{controlnet_id}")
+ processor_id = st.selectbox("Processor", ["None"] + supported_processor_id_list,
+ index=match_processor_id(controlnet_ckpt, supported_processor_id_list),
+ disabled=controlnet_ckpt == "None", key=f"processor_id_{controlnet_id}")
+ controlnet_scale = st.slider("Scale", min_value=0.0, max_value=1.0, step=0.01, value=0.5,
+ disabled=controlnet_ckpt == "None", key=f"controlnet_scale_{controlnet_id}")
+ use_input_video_as_controlnet_input = st.checkbox("Use input video as ControlNet input", value=True,
+ disabled=controlnet_ckpt == "None",
+ key=f"use_input_video_as_controlnet_input_{controlnet_id}")
+ if not use_input_video_as_controlnet_input:
+ controlnet_input_video = st.text_input("ControlNet Input Video File Path", value="",
+ disabled=controlnet_ckpt == "None", key=f"controlnet_input_video_{controlnet_id}")
+ column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
+ with column_height:
+ height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
+ disabled=controlnet_ckpt == "None", key=f"controlnet_height_{controlnet_id}")
+ with column_width:
+ width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
+ disabled=controlnet_ckpt == "None", key=f"controlnet_width_{controlnet_id}")
+ with column_start_frame_index:
+ start_frame_id = st.number_input("Start Frame id", value=0,
+ disabled=controlnet_ckpt == "None", key=f"controlnet_start_frame_id_{controlnet_id}")
+ with column_end_frame_index:
+ end_frame_id = st.number_input("End Frame id", value=16,
+ disabled=controlnet_ckpt == "None", key=f"controlnet_end_frame_id_{controlnet_id}")
+ if input_video != "":
+ config["data"]["input_video"] = {
+ "video_file": input_video,
+ "image_folder": None,
+ "height": height,
+ "width": width,
+ "start_frame_id": start_frame_id,
+ "end_frame_id": end_frame_id
+ }
+ if controlnet_ckpt != "None":
+ config["models"]["model_list"].append(os.path.join("models/ControlNet", controlnet_ckpt))
+ config["models"]["controlnet_units"].append({
+ "processor_id": processor_id,
+ "model_path": os.path.join("models/ControlNet", controlnet_ckpt),
+ "scale": controlnet_scale,
+ })
+ if use_input_video_as_controlnet_input:
+ config["data"]["controlnet_frames"].append(config["data"]["input_frames"])
+ else:
+ config["data"]["controlnet_frames"].append({
+                    "video_file": controlnet_input_video,
+ "image_folder": None,
+ "height": height,
+ "width": width,
+ "start_frame_id": start_frame_id,
+ "end_frame_id": end_frame_id
+ })
+
+
+with st.container(border=True):
+ with st.expander("Seed", expanded=True):
+ use_fixed_seed = st.checkbox("Use fixed seed", value=False)
+ if use_fixed_seed:
+ seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
+ else:
+ seed = np.random.randint(0, 10**9)
+ with st.expander("Textual Guidance", expanded=True):
+ prompt = st.text_area("Positive prompt")
+ negative_prompt = st.text_area("Negative prompt")
+ column_cfg_scale, column_clip_skip = st.columns(2)
+ with column_cfg_scale:
+ cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.0)
+ with column_clip_skip:
+ clip_skip = st.slider("Clip Skip", min_value=1, max_value=4, value=1)
+ with st.expander("Denoising", expanded=True):
+ column_num_inference_steps, column_denoising_strength = st.columns(2)
+ with column_num_inference_steps:
+ num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=10)
+ with column_denoising_strength:
+ denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=1.0)
+ with st.expander("Efficiency", expanded=False):
+ animatediff_batch_size = st.slider("Animatediff batch size (sliding window size)", min_value=1, max_value=32, value=16, step=1)
+ animatediff_stride = st.slider("Animatediff stride",
+ min_value=1,
+ max_value=max(2, animatediff_batch_size),
+ value=max(1, animatediff_batch_size // 2),
+ step=1)
+ unet_batch_size = st.slider("UNet batch size", min_value=1, max_value=32, value=1, step=1)
+ controlnet_batch_size = st.slider("ControlNet batch size", min_value=1, max_value=32, value=1, step=1)
+ cross_frame_attention = st.checkbox("Enable Cross-Frame Attention", value=False)
+ config["pipeline"]["seed"] = seed
+ config["pipeline"]["pipeline_inputs"] = {
+ "prompt": prompt,
+ "negative_prompt": negative_prompt,
+ "cfg_scale": cfg_scale,
+ "clip_skip": clip_skip,
+ "denoising_strength": denoising_strength,
+ "num_inference_steps": num_inference_steps,
+ "animatediff_batch_size": animatediff_batch_size,
+ "animatediff_stride": animatediff_stride,
+ "unet_batch_size": unet_batch_size,
+ "controlnet_batch_size": controlnet_batch_size,
+ "cross_frame_attention": cross_frame_attention,
+ }
+
+run_button = st.button("☢️Run☢️", type="primary")
+if run_button:
+ SDVideoPipelineRunner(in_streamlit=True).run(config)
diff --git a/deal1.py b/deal1.py
new file mode 100644
index 0000000000000000000000000000000000000000..6891c20ca74a428d31af5fad77bf0f6fc083f556
--- /dev/null
+++ b/deal1.py
@@ -0,0 +1,82 @@
+import cv2
+import os
+import sys
+from datetime import timedelta
+
+def extract_frames_per_second(video_path, output_dir, interval_seconds=3):
+    """
+    Extract one frame every interval_seconds from a video and save it to the given directory
+    (frames are located by timestamp, so the saved timestamps increase monotonically).
+    :param video_path: path to the video file
+    :param output_dir: directory where the extracted frames are saved
+    :param interval_seconds: number of seconds between saved frames (default: 3)
+    """
+    # Skip if the output directory already exists and is not empty, to avoid processing the same video twice
+    if os.path.exists(output_dir) and os.listdir(output_dir):
+        print(f"Output directory already exists and is not empty, skipping extraction: {os.path.abspath(output_dir)}")
+        return
+    os.makedirs(output_dir, exist_ok=True)
+    print(f"Frame output directory: {os.path.abspath(output_dir)}")
+
+    # Open the video file
+    cap = cv2.VideoCapture(video_path)
+    if not cap.isOpened():
+        raise ValueError(f"Cannot open video file: {video_path}")
+
+    # Read basic video metadata
+    fps = cap.get(cv2.CAP_PROP_FPS)  # frame rate (frames per second)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))  # total number of frames
+    duration = total_frames / fps if fps > 0 else 0  # total duration in seconds
+    print(f"Video info: fps={fps:.2f} | total frames={total_frames} | duration={timedelta(seconds=duration)}")
+
+    if fps <= 0:
+        raise ValueError("Cannot determine the video frame rate; the file may be corrupted or in an unsupported format")
+
+    saved_count = 0  # number of frames saved so far
+
+    try:
+        t = 0.0  # current position in seconds
+        # Read frames by seeking to timestamps, which avoids timestamp drift from frame-count rounding or skipped reads
+        while t <= duration:
+            # Seek to the target position in milliseconds (more reliable for grabbing the frame at a given time)
+            cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
+            ret, frame = cap.read()
+            if not ret:
+                # Seeking or reading failed; stop the loop
+                break
+
+            # Include the timestamp in the file name so frames are saved in chronological order
+            frame_filename = f"{saved_count:06d}_{t:.2f}s.jpg"
+            frame_path = os.path.join(output_dir, frame_filename)
+
+            # Save the frame
+            cv2.imwrite(frame_path, frame)
+            saved_count += 1
+
+            # Report progress every 10 saved frames to avoid flooding the console
+            if saved_count % 10 == 0:
+                progress = (t / duration) * 100 if duration > 0 else 0
+                print(f"Progress: {progress:.1f}% | saved {saved_count} frames | time: {t:.2f}s")
+
+            t += interval_seconds
+
+    except Exception as e:
+        raise RuntimeError(f"Error while extracting frames: {str(e)}")
+    finally:
+        # Release video resources
+        cap.release()
+        cv2.destroyAllWindows()
+
+    # Report the final result
+    print(f"\nExtraction finished! Saved {saved_count} frames to: {os.path.abspath(output_dir)}")
+
+output_dirs = ["/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/no other choice","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/the roses","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/nouvelle","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/legs","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/frankenstein"]
+input_dirs = ["/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/films/어쩔수가없다 NO OTHER CHOICE, 2025.1080p.WEB-DL.H264.AAC.mp4","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/films/The.Roses.2025.2160p.WEB-DL.DDP5.1.Atmos.SDR.H265-AOC/The.Roses.2025.2160p.WEB-DL.DDP5.1.Atmos.SDR.H265-AOC.mkv","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/films/NOUVELLE.VAGUE.2025.2160p.NF.WEB-DL.DDP.5.1.H.265-CHDWEB[PianYuan]/NOUVELLE.VAGUE.2025.2160p.NF.WEB-DL.DDP.5.1.H.265-CHDWEB.mkv","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/films/If.I.Had.Legs.Id.Kick.You.2025.1080p.iT.WEB-DL.DDP5.1.Atmos.H264-BTM/If.I.Had.Legs.Id.Kick.You.2025.1080p.iT.WEB-DL[Ben The Men].mkv","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/films/Frankenstein.2025.1080p.NF.WEB-DL.DDP5.1.Atmos.H.264-FLUX/Frankenstein.2025.1080p.NF.WEB-DL.DDP5.1.Atmos.H.264-FLUX.mkv"]
+if __name__ == "__main__":
+    # Run frame extraction (save one frame every 3 seconds)
+ for i in range(len(output_dirs)):
+ try:
+ extract_frames_per_second(video_path=input_dirs[i],output_dir=output_dirs[i],interval_seconds=3)
+ except Exception as e:
+            print(f"Unexpected error: {str(e)}", file=sys.stderr)
+ sys.exit(1)
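
deal1.py seeks by timestamp (CAP_PROP_POS_MSEC) instead of decoding every frame, which keeps the saved timestamps monotonic even for variable-frame-rate sources. A minimal sketch of that seek on its own, assuming an OpenCV build with FFmpeg support; the file names below are hypothetical:

import cv2

def frame_at(video_path: str, t_seconds: float):
    # Return the frame closest to t_seconds, or None if seeking or reading fails.
    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_MSEC, t_seconds * 1000)
        ok, frame = cap.read()
        return frame if ok else None
    finally:
        cap.release()

frame = frame_at("input.mp4", 12.0)  # hypothetical path and timestamp
if frame is not None:
    cv2.imwrite("frame_12.00s.jpg", frame)
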
diff --git a/deal2.py b/deal2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9111184f5cfbcd1f646cdbeb300327e452560ad1
--- /dev/null
+++ b/deal2.py
@@ -0,0 +1,127 @@
+import cv2
+import numpy as np
+import os
+from pathlib import Path
+import imagehash
+from PIL import Image
+
+def calculate_phash(image_path):
+    """
+    Compute the perceptual hash (pHash) of an image.
+    :param image_path: path to the image
+    :return: perceptual hash (an imagehash.ImageHash object), or None on failure
+    """
+    try:
+        # Read the image with PIL (supports more formats) and convert it to grayscale
+        img = Image.open(image_path).convert("L")
+        # Compute the pHash; a smaller hash_size gives a shorter, faster hash (default 8 -> 64-bit hash)
+        phash = imagehash.phash(img, hash_size=8)
+        return phash
+    except Exception as e:
+        print(f"Failed to compute hash: {image_path}, error: {e}")
+        return None
+
+def calculate_clarity(image_path):
+    """
+    Compute an image sharpness score using the variance of the Laplacian.
+    :param image_path: path to the image
+    :return: sharpness score (Laplacian variance), or 0 if the image cannot be read
+    """
+ img = cv2.imread(image_path)
+ if img is None:
+ return 0
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ laplacian = cv2.Laplacian(gray, cv2.CV_64F)
+ clarity_score = np.var(laplacian)
+ return clarity_score
+
+def calculate_hamming_distance(hash1, hash2):
+    """
+    Compute the Hamming distance between two hashes.
+    :param hash1, hash2: perceptual hashes (imagehash.ImageHash objects)
+    :return: Hamming distance (smaller means more similar)
+    """
+    if hash1 is None or hash2 is None:
+        return float("inf") # hash computation failed; treat as not similar
+    return hash1 - hash2
+
+def process_duplicate_frames(input_dir, output_dir, similarity_threshold=10):
+    """
+    Deduplicate video frame captures: drop similar frames and keep the sharper image.
+    :param input_dir: folder containing the original images
+    :param output_dir: folder where the deduplicated images are saved
+    :param similarity_threshold: similarity threshold (frames with Hamming distance <= this value count as similar)
+    """
+    # Create the output directory
+    Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+    # 1. Collect all images in the folder and sort them by file name (preserves the capture order)
+    img_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"]
+    img_paths = [
+        os.path.join(input_dir, f) for f in os.listdir(input_dir)
+        if Path(f).suffix.lower() in img_extensions
+    ]
+    # Sort by file name (important: keeps the frames in chronological order)
+    img_paths.sort(key=lambda x: os.path.basename(x))
+
+    if len(img_paths) == 0:
+        print("No images found in the folder!")
+        return
+
+    # 2. Initialization: keep the first image as the reference, then walk through the rest
+    saved_img_path = img_paths[0]  # path of the saved reference image
+    saved_phash = calculate_phash(saved_img_path)  # hash of the reference image
+    saved_clarity = calculate_clarity(saved_img_path)  # sharpness of the reference image
+
+    # Save the first image
+    save_name = os.path.basename(saved_img_path)
+    cv2.imwrite(os.path.join(output_dir, save_name), cv2.imread(saved_img_path))
+    print(f"Initial save: {save_name}, sharpness: {saved_clarity:.2f}")
+
+    # 3. Compare each remaining image against the reference, one by one
+    for current_img_path in img_paths[1:]:
+        current_name = os.path.basename(current_img_path)
+        current_phash = calculate_phash(current_img_path)
+        current_clarity = calculate_clarity(current_img_path)
+
+        # Hamming distance to the reference image
+        hamming_dist = calculate_hamming_distance(saved_phash, current_phash)
+        print(f"\nComparing: {saved_img_path.split('/')[-1]} vs {current_name}")
+        print(f"Hamming distance: {hamming_dist}, current image sharpness: {current_clarity:.2f}")
+
+        if hamming_dist <= similarity_threshold:
+            # Similar frames: keep the sharper one
+            if current_clarity > saved_clarity:
+                # Current image is sharper: delete the old reference and save the current image as the new reference
+                os.remove(os.path.join(output_dir, os.path.basename(saved_img_path)))
+                cv2.imwrite(os.path.join(output_dir, current_name), cv2.imread(current_img_path))
+                print(f"Replaced: {current_name} is sharper and replaces the previous reference image")
+                # Update the reference
+                saved_img_path = current_img_path
+                saved_phash = current_phash
+                saved_clarity = current_clarity
+            else:
+                # Current image is blurrier: skip it and keep the reference
+                print(f"Skipped: {current_name} is blurrier; keeping the reference image")
+        else:
+            # Not similar: save the current image and make it the new reference
+            cv2.imwrite(os.path.join(output_dir, current_name), cv2.imread(current_img_path))
+            print(f"Saved: {current_name} as the new reference (not similar to the previous one)")
+            # Update the reference
+            saved_img_path = current_img_path
+            saved_phash = current_phash
+            saved_clarity = current_clarity
+
+    print(f"\nDone! Deduplicated images saved in: {output_dir}")
+    print(f"Original image count: {len(img_paths)}, count after deduplication: {len(os.listdir(output_dir))}")
+
+# Entry point
+if __name__ == "__main__":
+ input_dirs = ["/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/no other choice","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/the roses","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/nouvelle","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/legs","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/frankenstein"]
+    # Configuration (adjust as needed)
+ for i in range(len(input_dirs)):
+
+        SIMILARITY_THRESHOLD = 10 # similarity threshold (Hamming distance); adjust as needed
+
+        # Run deduplication
+ process_duplicate_frames(input_dirs[i], f"{input_dirs[i]}_dedup", SIMILARITY_THRESHOLD)
\ No newline at end of file
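
deal2.py treats two frames as duplicates when their perceptual hashes are close: subtracting two imagehash.ImageHash objects yields the Hamming distance, and a small distance means the images are visually similar. A minimal sketch of that check in isolation, assuming the imagehash and Pillow packages; the paths are hypothetical:

from PIL import Image
import imagehash

def is_similar(path_a: str, path_b: str, threshold: int = 10) -> bool:
    # hash_size=8 gives a 64-bit pHash; '-' on ImageHash objects is the Hamming distance.
    hash_a = imagehash.phash(Image.open(path_a).convert("L"), hash_size=8)
    hash_b = imagehash.phash(Image.open(path_b).convert("L"), hash_size=8)
    return (hash_a - hash_b) <= threshold

print(is_similar("frames/000001_3.00s.jpg", "frames/000002_6.00s.jpg"))  # hypothetical paths
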
diff --git a/diffsynth.egg-info/PKG-INFO b/diffsynth.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..622ef73a6d298d3ec050718d3db8d0013e551fdd
--- /dev/null
+++ b/diffsynth.egg-info/PKG-INFO
@@ -0,0 +1,31 @@
+Metadata-Version: 2.4
+Name: diffsynth
+Version: 1.1.9
+Summary: Enjoy the magic of Diffusion models!
+Author: Artiprocher
+Classifier: Programming Language :: Python :: 3
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Operating System :: OS Independent
+Requires-Python: >=3.6
+License-File: LICENSE
+Requires-Dist: torch>=2.0.0
+Requires-Dist: torchvision
+Requires-Dist: transformers
+Requires-Dist: imageio
+Requires-Dist: imageio[ffmpeg]
+Requires-Dist: safetensors
+Requires-Dist: einops
+Requires-Dist: sentencepiece
+Requires-Dist: protobuf
+Requires-Dist: modelscope
+Requires-Dist: ftfy
+Requires-Dist: pynvml
+Requires-Dist: pandas
+Requires-Dist: accelerate
+Requires-Dist: peft
+Dynamic: author
+Dynamic: classifier
+Dynamic: license-file
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
diff --git a/diffsynth.egg-info/SOURCES.txt b/diffsynth.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9325f9079aaa0a7b212732c58fc83791a089e7d1
--- /dev/null
+++ b/diffsynth.egg-info/SOURCES.txt
@@ -0,0 +1,247 @@
+LICENSE
+README.md
+setup.py
+diffsynth/__init__.py
+diffsynth.egg-info/PKG-INFO
+diffsynth.egg-info/SOURCES.txt
+diffsynth.egg-info/dependency_links.txt
+diffsynth.egg-info/requires.txt
+diffsynth.egg-info/top_level.txt
+diffsynth/configs/__init__.py
+diffsynth/configs/model_config.py
+diffsynth/controlnets/__init__.py
+diffsynth/controlnets/controlnet_unit.py
+diffsynth/controlnets/processors.py
+diffsynth/data/__init__.py
+diffsynth/data/simple_text_image.py
+diffsynth/data/video.py
+diffsynth/distributed/__init__.py
+diffsynth/distributed/xdit_context_parallel.py
+diffsynth/extensions/__init__.py
+diffsynth/extensions/ESRGAN/__init__.py
+diffsynth/extensions/FastBlend/__init__.py
+diffsynth/extensions/FastBlend/api.py
+diffsynth/extensions/FastBlend/cupy_kernels.py
+diffsynth/extensions/FastBlend/data.py
+diffsynth/extensions/FastBlend/patch_match.py
+diffsynth/extensions/FastBlend/runners/__init__.py
+diffsynth/extensions/FastBlend/runners/accurate.py
+diffsynth/extensions/FastBlend/runners/balanced.py
+diffsynth/extensions/FastBlend/runners/fast.py
+diffsynth/extensions/FastBlend/runners/interpolation.py
+diffsynth/extensions/ImageQualityMetric/__init__.py
+diffsynth/extensions/ImageQualityMetric/aesthetic.py
+diffsynth/extensions/ImageQualityMetric/clip.py
+diffsynth/extensions/ImageQualityMetric/config.py
+diffsynth/extensions/ImageQualityMetric/hps.py
+diffsynth/extensions/ImageQualityMetric/imagereward.py
+diffsynth/extensions/ImageQualityMetric/mps.py
+diffsynth/extensions/ImageQualityMetric/pickscore.py
+diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
+diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
+diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py
+diffsynth/extensions/ImageQualityMetric/BLIP/med.py
+diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
+diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py
+diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
+diffsynth/extensions/ImageQualityMetric/open_clip/constants.py
+diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
+diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py
+diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py
+diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
+diffsynth/extensions/ImageQualityMetric/open_clip/loss.py
+diffsynth/extensions/ImageQualityMetric/open_clip/model.py
+diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py
+diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
+diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
+diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py
+diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
+diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
+diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
+diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
+diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
+diffsynth/extensions/ImageQualityMetric/open_clip/version.py
+diffsynth/extensions/ImageQualityMetric/trainer/__init__.py
+diffsynth/extensions/ImageQualityMetric/trainer/models/__init__.py
+diffsynth/extensions/ImageQualityMetric/trainer/models/base_model.py
+diffsynth/extensions/ImageQualityMetric/trainer/models/clip_model.py
+diffsynth/extensions/ImageQualityMetric/trainer/models/cross_modeling.py
+diffsynth/extensions/RIFE/__init__.py
+diffsynth/lora/__init__.py
+diffsynth/lora/flux_lora.py
+diffsynth/models/__init__.py
+diffsynth/models/attention.py
+diffsynth/models/cog_dit.py
+diffsynth/models/cog_vae.py
+diffsynth/models/downloader.py
+diffsynth/models/flux_controlnet.py
+diffsynth/models/flux_dit.py
+diffsynth/models/flux_infiniteyou.py
+diffsynth/models/flux_ipadapter.py
+diffsynth/models/flux_lora_encoder.py
+diffsynth/models/flux_text_encoder.py
+diffsynth/models/flux_vae.py
+diffsynth/models/flux_value_control.py
+diffsynth/models/hunyuan_dit.py
+diffsynth/models/hunyuan_dit_text_encoder.py
+diffsynth/models/hunyuan_video_dit.py
+diffsynth/models/hunyuan_video_text_encoder.py
+diffsynth/models/hunyuan_video_vae_decoder.py
+diffsynth/models/hunyuan_video_vae_encoder.py
+diffsynth/models/kolors_text_encoder.py
+diffsynth/models/longcat_video_dit.py
+diffsynth/models/lora.py
+diffsynth/models/model_manager.py
+diffsynth/models/nexus_gen.py
+diffsynth/models/nexus_gen_ar_model.py
+diffsynth/models/nexus_gen_projector.py
+diffsynth/models/omnigen.py
+diffsynth/models/qwen_image_controlnet.py
+diffsynth/models/qwen_image_dit.py
+diffsynth/models/qwen_image_text_encoder.py
+diffsynth/models/qwen_image_vae.py
+diffsynth/models/qwenvl.py
+diffsynth/models/sd3_dit.py
+diffsynth/models/sd3_text_encoder.py
+diffsynth/models/sd3_vae_decoder.py
+diffsynth/models/sd3_vae_encoder.py
+diffsynth/models/sd_controlnet.py
+diffsynth/models/sd_ipadapter.py
+diffsynth/models/sd_motion.py
+diffsynth/models/sd_text_encoder.py
+diffsynth/models/sd_unet.py
+diffsynth/models/sd_vae_decoder.py
+diffsynth/models/sd_vae_encoder.py
+diffsynth/models/sdxl_controlnet.py
+diffsynth/models/sdxl_ipadapter.py
+diffsynth/models/sdxl_motion.py
+diffsynth/models/sdxl_text_encoder.py
+diffsynth/models/sdxl_unet.py
+diffsynth/models/sdxl_vae_decoder.py
+diffsynth/models/sdxl_vae_encoder.py
+diffsynth/models/step1x_connector.py
+diffsynth/models/stepvideo_dit.py
+diffsynth/models/stepvideo_text_encoder.py
+diffsynth/models/stepvideo_vae.py
+diffsynth/models/svd_image_encoder.py
+diffsynth/models/svd_unet.py
+diffsynth/models/svd_vae_decoder.py
+diffsynth/models/svd_vae_encoder.py
+diffsynth/models/tiler.py
+diffsynth/models/utils.py
+diffsynth/models/wan_video_animate_adapter.py
+diffsynth/models/wan_video_camera_controller.py
+diffsynth/models/wan_video_dit.py
+diffsynth/models/wan_video_dit_s2v.py
+diffsynth/models/wan_video_image_encoder.py
+diffsynth/models/wan_video_mot.py
+diffsynth/models/wan_video_motion_controller.py
+diffsynth/models/wan_video_text_encoder.py
+diffsynth/models/wan_video_vace.py
+diffsynth/models/wan_video_vae.py
+diffsynth/models/wav2vec.py
+diffsynth/pipelines/__init__.py
+diffsynth/pipelines/base.py
+diffsynth/pipelines/cog_video.py
+diffsynth/pipelines/dancer.py
+diffsynth/pipelines/flux_image.py
+diffsynth/pipelines/flux_image_new.py
+diffsynth/pipelines/hunyuan_image.py
+diffsynth/pipelines/hunyuan_video.py
+diffsynth/pipelines/omnigen_image.py
+diffsynth/pipelines/pipeline_runner.py
+diffsynth/pipelines/qwen_image.py
+diffsynth/pipelines/sd3_image.py
+diffsynth/pipelines/sd_image.py
+diffsynth/pipelines/sd_video.py
+diffsynth/pipelines/sdxl_image.py
+diffsynth/pipelines/sdxl_video.py
+diffsynth/pipelines/step_video.py
+diffsynth/pipelines/svd_video.py
+diffsynth/pipelines/wan_video.py
+diffsynth/pipelines/wan_video_new.py
+diffsynth/processors/FastBlend.py
+diffsynth/processors/PILEditor.py
+diffsynth/processors/RIFE.py
+diffsynth/processors/__init__.py
+diffsynth/processors/base.py
+diffsynth/processors/sequencial_processor.py
+diffsynth/prompters/__init__.py
+diffsynth/prompters/base_prompter.py
+diffsynth/prompters/cog_prompter.py
+diffsynth/prompters/flux_prompter.py
+diffsynth/prompters/hunyuan_dit_prompter.py
+diffsynth/prompters/hunyuan_video_prompter.py
+diffsynth/prompters/kolors_prompter.py
+diffsynth/prompters/omnigen_prompter.py
+diffsynth/prompters/omost.py
+diffsynth/prompters/prompt_refiners.py
+diffsynth/prompters/sd3_prompter.py
+diffsynth/prompters/sd_prompter.py
+diffsynth/prompters/sdxl_prompter.py
+diffsynth/prompters/stepvideo_prompter.py
+diffsynth/prompters/wan_prompter.py
+diffsynth/schedulers/__init__.py
+diffsynth/schedulers/continuous_ode.py
+diffsynth/schedulers/ddim.py
+diffsynth/schedulers/flow_match.py
+diffsynth/tokenizer_configs/__init__.py
+diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json
+diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json
+diffsynth/tokenizer_configs/cog/tokenizer/spiece.model
+diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json
+diffsynth/tokenizer_configs/flux/tokenizer_1/merges.txt
+diffsynth/tokenizer_configs/flux/tokenizer_1/special_tokens_map.json
+diffsynth/tokenizer_configs/flux/tokenizer_1/tokenizer_config.json
+diffsynth/tokenizer_configs/flux/tokenizer_1/vocab.json
+diffsynth/tokenizer_configs/flux/tokenizer_2/special_tokens_map.json
+diffsynth/tokenizer_configs/flux/tokenizer_2/spiece.model
+diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer.json
+diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer_config.json
+diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json
+diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json
+diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt
+diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt
+diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json
+diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json
+diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model
+diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json
+diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/merges.txt
+diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/special_tokens_map.json
+diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/tokenizer_config.json
+diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/vocab.json
+diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/preprocessor_config.json
+diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/special_tokens_map.json
+diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json
+diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer_config.json
+diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model
+diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json
+diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt
+diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt
+diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json
+diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json
+diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json
+diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt
+diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json
+diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json
+diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json
+diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt
+diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json
+diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json
+diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json
+diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json
+diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model
+diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json
+diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json
+diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt
+diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json
+diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json
+diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json
+diffsynth/trainers/__init__.py
+diffsynth/trainers/text_to_image.py
+diffsynth/trainers/unified_dataset.py
+diffsynth/trainers/utils.py
+diffsynth/utils/__init__.py
+diffsynth/vram_management/__init__.py
+diffsynth/vram_management/gradient_checkpointing.py
+diffsynth/vram_management/layers.py
\ No newline at end of file
diff --git a/diffsynth.egg-info/dependency_links.txt b/diffsynth.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/diffsynth.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/diffsynth.egg-info/requires.txt b/diffsynth.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..dd8aab031586615077ccf65d0036adeddfd23833
--- /dev/null
+++ b/diffsynth.egg-info/requires.txt
@@ -0,0 +1,15 @@
+torch>=2.0.0
+torchvision
+transformers
+imageio
+imageio[ffmpeg]
+safetensors
+einops
+sentencepiece
+protobuf
+modelscope
+ftfy
+pynvml
+pandas
+accelerate
+peft
diff --git a/diffsynth.egg-info/top_level.txt b/diffsynth.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a4c845e23f3cd1bc9c56115f1c9a0b1a0c6b68bc
--- /dev/null
+++ b/diffsynth.egg-info/top_level.txt
@@ -0,0 +1 @@
+diffsynth
diff --git a/diffsynth/__init__.py b/diffsynth/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae0a45c2e2dc61f8f16354feb1b0c481776b523f
--- /dev/null
+++ b/diffsynth/__init__.py
@@ -0,0 +1,6 @@
+from .data import *
+from .models import *
+from .prompters import *
+from .schedulers import *
+from .pipelines import *
+from .controlnets import *
diff --git a/diffsynth/configs/__init__.py b/diffsynth/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..743ffba5b37ee3a812f1253c88e34eb531129601
--- /dev/null
+++ b/diffsynth/configs/model_config.py
@@ -0,0 +1,859 @@
+from typing_extensions import Literal, TypeAlias
+
+from ..models.sd_text_encoder import SDTextEncoder
+from ..models.sd_unet import SDUNet
+from ..models.sd_vae_encoder import SDVAEEncoder
+from ..models.sd_vae_decoder import SDVAEDecoder
+
+from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
+from ..models.sdxl_unet import SDXLUNet
+from ..models.sdxl_vae_decoder import SDXLVAEDecoder
+from ..models.sdxl_vae_encoder import SDXLVAEEncoder
+
+from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
+from ..models.sd3_dit import SD3DiT
+from ..models.sd3_vae_decoder import SD3VAEDecoder
+from ..models.sd3_vae_encoder import SD3VAEEncoder
+
+from ..models.sd_controlnet import SDControlNet
+from ..models.sdxl_controlnet import SDXLControlNetUnion
+
+from ..models.sd_motion import SDMotionModel
+from ..models.sdxl_motion import SDXLMotionModel
+
+from ..models.svd_image_encoder import SVDImageEncoder
+from ..models.svd_unet import SVDUNet
+from ..models.svd_vae_decoder import SVDVAEDecoder
+from ..models.svd_vae_encoder import SVDVAEEncoder
+
+from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
+from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
+
+from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
+from ..models.hunyuan_dit import HunyuanDiT
+
+from ..models.flux_dit import FluxDiT
+from ..models.flux_text_encoder import FluxTextEncoder2
+from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
+from ..models.flux_controlnet import FluxControlNet
+from ..models.flux_ipadapter import FluxIpAdapter
+from ..models.flux_infiniteyou import InfiniteYouImageProjector
+
+from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
+from ..models.cog_dit import CogDiT
+
+from ..models.omnigen import OmniGenTransformer
+
+from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
+from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
+
+from ..extensions.RIFE import IFNet
+from ..extensions.ESRGAN import RRDBNet
+
+from ..models.hunyuan_video_dit import HunyuanVideoDiT
+
+from ..models.stepvideo_vae import StepVideoVAE
+from ..models.stepvideo_dit import StepVideoModel
+
+from ..models.wan_video_dit import WanModel
+from ..models.wan_video_dit_s2v import WanS2VModel
+from ..models.wan_video_text_encoder import WanTextEncoder
+from ..models.wan_video_image_encoder import WanImageEncoder
+from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38
+from ..models.wan_video_motion_controller import WanMotionControllerModel
+from ..models.wan_video_vace import VaceWanModel
+from ..models.wav2vec import WanS2VAudioEncoder
+from ..models.wan_video_animate_adapter import WanAnimateAdapter
+from ..models.wan_video_mot import MotWanModel
+
+from ..models.step1x_connector import Qwen2Connector
+
+from ..models.flux_value_control import SingleValueEncoder
+
+from ..lora.flux_lora import FluxLoraPatcher
+from ..models.flux_lora_encoder import FluxLoRAEncoder
+
+from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger
+from ..models.nexus_gen import NexusGenAutoregressiveModel
+
+from ..models.qwen_image_dit import QwenImageDiT
+from ..models.qwen_image_text_encoder import QwenImageTextEncoder
+from ..models.qwen_image_vae import QwenImageVAE
+from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
+
+from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
+
+model_loader_configs = [
+ # These configs are provided for detecting model type automatically.
+ # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
+ (None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
+ (None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
+ (None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
+ (None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
+ (None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
+ (None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
+ (None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
+ (None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
+ (None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
+ (None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
+ (None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
+ (None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
+ (None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
+ (None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
+ (None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
+ (None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
+ (None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
+ (None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
+ (None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
+ (None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
+ (None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
+ (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
+ (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
+ (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
+ (None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
+ (None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
+ (None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
+ (None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
+ (None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
+ (None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
+ (None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
+ (None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
+ (None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
+ (None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
+ (None, "0629116fce1472503a66992f96f3eb1a", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
+ (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
+ (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
+ (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
+ (None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
+ (None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+ (None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+ (None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+ (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+ (None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+ (None, "43ad5aaa27dd4ee01b832ed16773fa52", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+ (None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
+ (None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
+ (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
+ (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
+ (None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
+ (None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
+ (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
+ (None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
+ (None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
+ (None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
+ (None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
+ (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "5ec04e02b42d2580483ad69f4e76346a", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "5f90e66a0672219f12d9a626c8c21f61", ["wan_video_dit", "wan_video_vap"], [WanModel,MotWanModel], "diffusers"),
+ (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
+ (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
+ (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
+ (None, "966cffdcc52f9c46c391768b27637614", ["wan_video_dit"], [WanS2VModel], "civitai"),
+ (None, "8b27900f680d7251ce44e2dc8ae1ffef", ["wan_video_dit"], [LongCatVideoTransformer3DModel], "civitai"),
+ (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
+ (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
+ (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
+ (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
+ (None, "e1de6c02cdac79f8b739f4d3698cd216", ["wan_video_vae"], [WanVideoVAE38], "civitai"),
+ (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
+ (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
+ (None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"),
+ (None, "77c2e4dd2440269eb33bfaa0d004f6ab", ["flux_lora_encoder"], [FluxLoRAEncoder], "civitai"),
+ (None, "3e6c61b0f9471135fc9c6d6a98e98b6d", ["flux_dit", "nexus_gen_generation_adapter"], [FluxDiT, NexusGenAdapter], "civitai"),
+ (None, "63c969fd37cce769a90aa781fbff5f81", ["flux_dit", "nexus_gen_editing_adapter"], [FluxDiT, NexusGenImageEmbeddingMerger], "civitai"),
+ (None, "2bd19e845116e4f875a0a048e27fc219", ["nexus_gen_llm"], [NexusGenAutoregressiveModel], "civitai"),
+ (None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
+ (None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
+ (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
+ (None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
+ (None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
+ (None, "06be60f3a4526586d8431cd038a71486", ["wans2v_audio_encoder"], [WanS2VAudioEncoder], "civitai"),
+ (None, "31fa352acb8a1b1d33cd8764273d80a2", ["wan_video_dit", "wan_video_animate_adapter"], [WanModel, WanAnimateAdapter], "civitai"),
+]
+huggingface_model_loader_configs = [
+ # These configs are provided for detecting model type automatically.
+ # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
+ ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
+ ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
+ ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
+ ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
+ # ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
+ ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
+ ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
+ ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
+ ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
+ ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
+ ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
+ ("Qwen2_5_VLForConditionalGeneration", "diffsynth.models.qwenvl", "qwenvl", "Qwen25VL_7b_Embedder"),
+]
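+
+# A minimal sketch (not part of the original code) of how an entry above could be resolved to a
+# concrete class: import the configured library and, when `redirected_architecture` is set, use it
+# in place of the architecture name declared in the HuggingFace config.
+def _example_resolve_architecture(architecture, library, redirected_architecture=None):
+    import importlib
+    module = importlib.import_module(library)
+    return getattr(module, redirected_architecture or architecture)
+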
+patch_model_loader_configs = [
+ # These configs are provided for detecting model type automatically.
+ # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
+ ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
+]
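+
+# A minimal sketch (not part of the original code): a loader could match a precomputed
+# state-dict-keys hash against the table above and forward `extra_kwargs` to the model
+# constructor (e.g. add_positional_conv=128 for the SVDUNet entry). Computing the hash itself
+# is outside this sketch.
+def _example_match_patch_config(state_dict_keys_hash_with_shape):
+    for keys_hash, model_names, model_classes, extra_kwargs in patch_model_loader_configs:
+        if keys_hash == state_dict_keys_hash_with_shape:
+            return model_names, model_classes, extra_kwargs
+    return None
+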
+
+preset_models_on_huggingface = {
+ "HunyuanDiT": [
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
+ ],
+ "stable-video-diffusion-img2vid-xt": [
+ ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
+ ],
+ "ExVideo-SVD-128f-v1": [
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
+ ],
+ # Stable Diffusion
+ "StableDiffusion_v15": [
+ ("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
+ ],
+ "DreamShaper_8": [
+ ("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
+ ],
+ # Textual Inversion
+ "TextualInversion_VeryBadImageNegative_v1.3": [
+ ("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
+ ],
+ # Stable Diffusion XL
+ "StableDiffusionXL_v1": [
+ ("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
+ ],
+ "BluePencilXL_v200": [
+ ("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
+ ],
+ "StableDiffusionXL_Turbo": [
+ ("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
+ ],
+ # Stable Diffusion 3
+ "StableDiffusion3": [
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
+ ],
+ "StableDiffusion3_without_T5": [
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
+ ],
+ # ControlNet
+ "ControlNet_v11f1p_sd15_depth": [
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
+ ],
+ "ControlNet_v11p_sd15_softedge": [
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
+ ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
+ ],
+ "ControlNet_v11f1e_sd15_tile": [
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
+ ],
+ "ControlNet_v11p_sd15_lineart": [
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
+ ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
+ ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
+ ],
+ "ControlNet_union_sdxl_promax": [
+ ("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
+ ],
+ # AnimateDiff
+ "AnimateDiff_v2": [
+ ("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
+ ],
+ "AnimateDiff_xl_beta": [
+ ("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
+ ],
+
+ # Qwen Prompt
+ "QwenPrompt": [
+ ("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ],
+ # Beautiful Prompt
+ "BeautifulPrompt": [
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ],
+ # Omost prompt
+ "OmostPrompt":[
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ],
+ # Translator
+ "opus-mt-zh-en": [
+ ("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
+ ],
+ # IP-Adapter
+ "IP-Adapter-SD": [
+ ("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
+ ("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
+ ],
+ "IP-Adapter-SDXL": [
+ ("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
+ ("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
+ ],
+ "SDXL-vae-fp16-fix": [
+ ("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
+ ],
+ # Kolors
+ "Kolors": [
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
+ ],
+ # FLUX
+ "FLUX.1-dev": [
+ ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
+ ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
+ ],
+ "InstantX/FLUX.1-dev-IP-Adapter": {
+ "file_list": [
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
+ ("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
+ ("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
+ ],
+ "load_path": [
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
+ ],
+ },
+ # RIFE
+ "RIFE": [
+ ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
+ ],
+ # CogVideo
+ "CogVideoX-5B": [
+ ("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
+ ],
+ # Stable Diffusion 3.5
+ "StableDiffusion3.5-large": [
+ ("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ],
+}
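+
+# A minimal usage sketch (not part of the original code): every preset entry is either a list of
+# (repo_id, path_in_repo, local_dir) tuples or a dict wrapping such a list under "file_list".
+# huggingface_hub.hf_hub_download is used here only as one possible download backend; the actual
+# loader in this repo may differ.
+def _example_download_preset_from_huggingface(preset_name):
+    from huggingface_hub import hf_hub_download
+    entry = preset_models_on_huggingface[preset_name]
+    file_list = entry["file_list"] if isinstance(entry, dict) else entry
+    for repo_id, path_in_repo, local_dir in file_list:
+        hf_hub_download(repo_id=repo_id, filename=path_in_repo, local_dir=local_dir)
+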
+preset_models_on_modelscope = {
+ # Hunyuan DiT
+ "HunyuanDiT": [
+ ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
+ ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
+ ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
+ ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
+ ],
+ # Stable Video Diffusion
+ "stable-video-diffusion-img2vid-xt": [
+ ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
+ ],
+ # ExVideo
+ "ExVideo-SVD-128f-v1": [
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
+ ],
+ "ExVideo-CogVideoX-LoRA-129f-v1": [
+ ("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
+ ],
+ # Stable Diffusion
+ "StableDiffusion_v15": [
+ ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
+ ],
+ "DreamShaper_8": [
+ ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
+ ],
+ "AingDiffusion_v12": [
+ ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
+ ],
+ "Flat2DAnimerge_v45Sharp": [
+ ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
+ ],
+ # Textual Inversion
+ "TextualInversion_VeryBadImageNegative_v1.3": [
+ ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
+ ],
+ # Stable Diffusion XL
+ "StableDiffusionXL_v1": [
+ ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
+ ],
+ "BluePencilXL_v200": [
+ ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
+ ],
+ "StableDiffusionXL_Turbo": [
+ ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
+ ],
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
+ ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
+ ],
+ # Stable Diffusion 3
+ "StableDiffusion3": [
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
+ ],
+ "StableDiffusion3_without_T5": [
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
+ ],
+ # ControlNet
+ "ControlNet_v11f1p_sd15_depth": [
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
+ ],
+ "ControlNet_v11p_sd15_softedge": [
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
+ ],
+ "ControlNet_v11f1e_sd15_tile": [
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
+ ],
+ "ControlNet_v11p_sd15_lineart": [
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
+ ],
+ "ControlNet_union_sdxl_promax": [
+ ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
+ ],
+ "Annotators:Depth": [
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
+ ],
+ "Annotators:Softedge": [
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
+ ],
+ "Annotators:Lineart": [
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
+ ],
+ "Annotators:Normal": [
+ ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
+ ],
+ "Annotators:Openpose": [
+ ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
+ ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
+ ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
+ ],
+ # AnimateDiff
+ "AnimateDiff_v2": [
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
+ ],
+ "AnimateDiff_xl_beta": [
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
+ ],
+ # RIFE
+ "RIFE": [
+ ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
+ ],
+ # Qwen Prompt
+ "QwenPrompt": {
+ "file_list": [
+ ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ],
+ "load_path": [
+ "models/QwenPrompt/qwen2-1.5b-instruct",
+ ],
+ },
+ # Beautiful Prompt
+ "BeautifulPrompt": {
+ "file_list": [
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ],
+ "load_path": [
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
+ ],
+ },
+ # Omost prompt
+ "OmostPrompt": {
+ "file_list": [
+ ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ],
+ "load_path": [
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
+ ],
+ },
+ # Translator
+ "opus-mt-zh-en": {
+ "file_list": [
+ ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
+ ],
+ "load_path": [
+ "models/translator/opus-mt-zh-en",
+ ],
+ },
+ # IP-Adapter
+ "IP-Adapter-SD": [
+ ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
+ ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
+ ],
+ "IP-Adapter-SDXL": [
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
+ ],
+ # Kolors
+ "Kolors": {
+ "file_list": [
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
+ ],
+ "load_path": [
+ "models/kolors/Kolors/text_encoder",
+ "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
+ "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
+ ],
+ },
+ "SDXL-vae-fp16-fix": [
+ ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
+ ],
+ # FLUX
+ "FLUX.1-dev": {
+ "file_list": [
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
+ ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
+ ],
+ "load_path": [
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
+ "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
+ ],
+ },
+ "FLUX.1-schnell": {
+ "file_list": [
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
+ ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
+ ],
+ "load_path": [
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
+ "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
+ ],
+ },
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
+ ("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
+ ],
+ "jasperai/Flux.1-dev-Controlnet-Depth": [
+ ("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
+ ],
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
+ ("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
+ ],
+ "jasperai/Flux.1-dev-Controlnet-Upscaler": [
+ ("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
+ ],
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
+ ],
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
+ ],
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
+ ],
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
+ ],
+ "InstantX/FLUX.1-dev-IP-Adapter": {
+ "file_list": [
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
+ ("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
+ ("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
+ ],
+ "load_path": [
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
+ ],
+ },
+ "InfiniteYou":{
+ "file_list":[
+ ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
+ ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
+ ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+ ],
+ "load_path":[
+ [
+ "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
+ "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
+ ],
+ "models/InfiniteYou/image_proj_model.bin",
+ ],
+ },
+ # ESRGAN
+ "ESRGAN_x4": [
+ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
+ ],
+ # RIFE
+ "RIFE": [
+ ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
+ ],
+ # Omnigen
+ "OmniGen-v1": {
+ "file_list": [
+ ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
+ ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
+ ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
+ ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
+ ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
+ ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
+ ],
+ "load_path": [
+ "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
+ "models/OmniGen/OmniGen-v1/model.safetensors",
+ ]
+ },
+ # CogVideo
+ "CogVideoX-5B": {
+ "file_list": [
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
+ ],
+ "load_path": [
+ "models/CogVideo/CogVideoX-5b/text_encoder",
+ "models/CogVideo/CogVideoX-5b/transformer",
+ "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
+ ],
+ },
+ # Stable Diffusion 3.5
+ "StableDiffusion3.5-large": [
+ ("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ],
+ "StableDiffusion3.5-medium": [
+ ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ],
+ "StableDiffusion3.5-large-turbo": [
+ ("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ],
+ "HunyuanVideo":{
+ "file_list": [
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
+ ],
+ "load_path": [
+ "models/HunyuanVideo/text_encoder/model.safetensors",
+ "models/HunyuanVideo/text_encoder_2",
+ "models/HunyuanVideo/vae/pytorch_model.pt",
+ "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
+ ],
+ },
+ "HunyuanVideoI2V":{
+ "file_list": [
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
+ ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
+ ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
+ ],
+ "load_path": [
+ "models/HunyuanVideoI2V/text_encoder/model.safetensors",
+ "models/HunyuanVideoI2V/text_encoder_2",
+ "models/HunyuanVideoI2V/vae/pytorch_model.pt",
+ "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
+ ],
+ },
+ "HunyuanVideo-fp8":{
+ "file_list": [
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
+ ("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
+ ],
+ "load_path": [
+ "models/HunyuanVideo/text_encoder/model.safetensors",
+ "models/HunyuanVideo/text_encoder_2",
+ "models/HunyuanVideo/vae/pytorch_model.pt",
+ "models/HunyuanVideo/transformers/model.fp8.safetensors"
+ ],
+ },
+}
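+
+# A minimal sketch (not part of the original code): ModelScope presets use the same two shapes as
+# the HuggingFace table above, so one normalization step yields the file list plus the optional
+# "load_path" hints.
+def _example_normalize_preset(entry):
+    if isinstance(entry, dict):
+        return entry["file_list"], entry.get("load_path", [])
+    return entry, []
+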
+Preset_model_id: TypeAlias = Literal[
+ "HunyuanDiT",
+ "stable-video-diffusion-img2vid-xt",
+ "ExVideo-SVD-128f-v1",
+ "ExVideo-CogVideoX-LoRA-129f-v1",
+ "StableDiffusion_v15",
+ "DreamShaper_8",
+ "AingDiffusion_v12",
+ "Flat2DAnimerge_v45Sharp",
+ "TextualInversion_VeryBadImageNegative_v1.3",
+ "StableDiffusionXL_v1",
+ "BluePencilXL_v200",
+ "StableDiffusionXL_Turbo",
+ "ControlNet_v11f1p_sd15_depth",
+ "ControlNet_v11p_sd15_softedge",
+ "ControlNet_v11f1e_sd15_tile",
+ "ControlNet_v11p_sd15_lineart",
+ "AnimateDiff_v2",
+ "AnimateDiff_xl_beta",
+ "RIFE",
+ "BeautifulPrompt",
+ "opus-mt-zh-en",
+ "IP-Adapter-SD",
+ "IP-Adapter-SDXL",
+ "StableDiffusion3",
+ "StableDiffusion3_without_T5",
+ "Kolors",
+ "SDXL-vae-fp16-fix",
+ "ControlNet_union_sdxl_promax",
+ "FLUX.1-dev",
+ "FLUX.1-schnell",
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
+ "jasperai/Flux.1-dev-Controlnet-Depth",
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
+ "jasperai/Flux.1-dev-Controlnet-Upscaler",
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
+ "InstantX/FLUX.1-dev-IP-Adapter",
+ "InfiniteYou",
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
+ "QwenPrompt",
+ "OmostPrompt",
+ "ESRGAN_x4",
+ "RIFE",
+ "OmniGen-v1",
+ "CogVideoX-5B",
+ "Annotators:Depth",
+ "Annotators:Softedge",
+ "Annotators:Lineart",
+ "Annotators:Normal",
+ "Annotators:Openpose",
+ "StableDiffusion3.5-large",
+ "StableDiffusion3.5-medium",
+ "HunyuanVideo",
+ "HunyuanVideo-fp8",
+ "HunyuanVideoI2V",
+]
diff --git a/diffsynth/controlnets/__init__.py b/diffsynth/controlnets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3e15add6ab116bf261804b8c83c86ff4d61c41b
--- /dev/null
+++ b/diffsynth/controlnets/__init__.py
@@ -0,0 +1,2 @@
+from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
+from .processors import Annotator
diff --git a/diffsynth/controlnets/controlnet_unit.py b/diffsynth/controlnets/controlnet_unit.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdb4829483d208ec0295d1b5a8f82681b4251ea4
--- /dev/null
+++ b/diffsynth/controlnets/controlnet_unit.py
@@ -0,0 +1,91 @@
+import torch
+import numpy as np
+from .processors import Processor_id
+
+
+class ControlNetConfigUnit:
+ def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
+ self.processor_id = processor_id
+ self.model_path = model_path
+ self.scale = scale
+ self.skip_processor = skip_processor
+
+
+class ControlNetUnit:
+ def __init__(self, processor, model, scale=1.0):
+ self.processor = processor
+ self.model = model
+ self.scale = scale
+
+
+class MultiControlNetManager:
+ def __init__(self, controlnet_units=[]):
+ self.processors = [unit.processor for unit in controlnet_units]
+ self.models = [unit.model for unit in controlnet_units]
+ self.scales = [unit.scale for unit in controlnet_units]
+
+ def cpu(self):
+ for model in self.models:
+ model.cpu()
+
+ def to(self, device):
+ for model in self.models:
+ model.to(device)
+ for processor in self.processors:
+ processor.to(device)
+
+ def process_image(self, image, processor_id=None):
+ if processor_id is None:
+ processed_image = [processor(image) for processor in self.processors]
+ else:
+ processed_image = [self.processors[processor_id](image)]
+ processed_image = torch.concat([
+ torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
+ for image_ in processed_image
+ ], dim=0)
+ return processed_image
+
+ def __call__(
+ self,
+ sample, timestep, encoder_hidden_states, conditionings,
+ tiled=False, tile_size=64, tile_stride=32, **kwargs
+ ):
+ res_stack = None
+ for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
+ res_stack_ = model(
+ sample, timestep, encoder_hidden_states, conditioning, **kwargs,
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
+ processor_id=processor.processor_id
+ )
+ res_stack_ = [res * scale for res in res_stack_]
+ if res_stack is None:
+ res_stack = res_stack_
+ else:
+ res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
+ return res_stack
+
+
+class FluxMultiControlNetManager(MultiControlNetManager):
+ def __init__(self, controlnet_units=[]):
+ super().__init__(controlnet_units=controlnet_units)
+
+ def process_image(self, image, processor_id=None):
+ if processor_id is None:
+ processed_image = [processor(image) for processor in self.processors]
+ else:
+ processed_image = [self.processors[processor_id](image)]
+ return processed_image
+
+ def __call__(self, conditionings, **kwargs):
+ res_stack, single_res_stack = None, None
+ for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
+ res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
+ res_stack_ = [res * scale for res in res_stack_]
+ single_res_stack_ = [res * scale for res in single_res_stack_]
+ if res_stack is None:
+ res_stack = res_stack_
+ single_res_stack = single_res_stack_
+ else:
+ res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
+ single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
+ return res_stack, single_res_stack
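+
+
+# A minimal usage sketch (not part of the original module): wire one depth-conditioned ControlNet
+# into a manager. `controlnet_model` and `image` are assumed to be supplied by the caller, and the
+# depth Annotator expects its weights under models/Annotators.
+def _example_build_manager(controlnet_model, image):
+    from .processors import Annotator
+    unit = ControlNetUnit(processor=Annotator("depth"), model=controlnet_model, scale=0.8)
+    manager = MultiControlNetManager([unit])
+    conditioning = manager.process_image(image)  # tensor of shape (num_units, C, H, W)
+    return manager, conditioning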
diff --git a/diffsynth/controlnets/processors.py b/diffsynth/controlnets/processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..06553e06d1c6d09f5a3deecfd4ea5604c5dd4352
--- /dev/null
+++ b/diffsynth/controlnets/processors.py
@@ -0,0 +1,62 @@
+from typing_extensions import Literal, TypeAlias
+
+
+Processor_id: TypeAlias = Literal[
+ "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
+]
+
+class Annotator:
+ def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
+ if not skip_processor:
+ if processor_id == "canny":
+ from controlnet_aux.processor import CannyDetector
+ self.processor = CannyDetector()
+ elif processor_id == "depth":
+ from controlnet_aux.processor import MidasDetector
+ self.processor = MidasDetector.from_pretrained(model_path).to(device)
+ elif processor_id == "softedge":
+ from controlnet_aux.processor import HEDdetector
+ self.processor = HEDdetector.from_pretrained(model_path).to(device)
+ elif processor_id == "lineart":
+ from controlnet_aux.processor import LineartDetector
+ self.processor = LineartDetector.from_pretrained(model_path).to(device)
+ elif processor_id == "lineart_anime":
+ from controlnet_aux.processor import LineartAnimeDetector
+ self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
+ elif processor_id == "openpose":
+ from controlnet_aux.processor import OpenposeDetector
+ self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
+ elif processor_id == "normal":
+ from controlnet_aux.processor import NormalBaeDetector
+ self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
+ elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
+ self.processor = None
+ else:
+ raise ValueError(f"Unsupported processor_id: {processor_id}")
+ else:
+ self.processor = None
+
+ self.processor_id = processor_id
+ self.detect_resolution = detect_resolution
+
+    def to(self, device):
+        if hasattr(self.processor, "model") and hasattr(self.processor.model, "to"):
+            self.processor.model.to(device)
+
+ def __call__(self, image, mask=None):
+ width, height = image.size
+ if self.processor_id == "openpose":
+ kwargs = {
+ "include_body": True,
+ "include_hand": True,
+ "include_face": True
+ }
+ else:
+ kwargs = {}
+ if self.processor is not None:
+ detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
+ image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
+ image = image.resize((width, height))
+ return image
+
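+
+# A minimal usage sketch (not part of the original module): the "canny" processor needs no
+# pretrained weights, so it is the cheapest way to exercise the Annotator interface.
+def _example_canny_annotation(image):
+    annotator = Annotator("canny")
+    return annotator(image)  # a PIL image resized back to the input resolution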
diff --git a/diffsynth/data/__init__.py b/diffsynth/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..749c03fbc421e30b5190703587e57af8412aab73
--- /dev/null
+++ b/diffsynth/data/__init__.py
@@ -0,0 +1 @@
+from .video import VideoData, save_video, save_frames, merge_video_audio, save_video_with_audio
diff --git a/diffsynth/data/simple_text_image.py b/diffsynth/data/simple_text_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a9525e3c8a4d21418c1464fe11fc621450fd0d8
--- /dev/null
+++ b/diffsynth/data/simple_text_image.py
@@ -0,0 +1,41 @@
+import torch, os, torchvision
+from torchvision import transforms
+import pandas as pd
+from PIL import Image
+
+
+
+class TextImageDataset(torch.utils.data.Dataset):
+ def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
+ self.steps_per_epoch = steps_per_epoch
+ metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
+ self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
+ self.text = metadata["text"].to_list()
+ self.height = height
+ self.width = width
+ self.image_processor = transforms.Compose(
+ [
+ transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
+ transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+
+ def __getitem__(self, index):
+ data_id = torch.randint(0, len(self.path), (1,))[0]
+ data_id = (data_id + index) % len(self.path) # For fixed seed.
+ text = self.text[data_id]
+ image = Image.open(self.path[data_id]).convert("RGB")
+ target_height, target_width = self.height, self.width
+ width, height = image.size
+ scale = max(target_width / width, target_height / height)
+        shape = [round(height * scale), round(width * scale)]
+        image = torchvision.transforms.functional.resize(image, shape, interpolation=transforms.InterpolationMode.BILINEAR)
+ image = self.image_processor(image)
+ return {"text": text, "image": image}
+
+
+ def __len__(self):
+ return self.steps_per_epoch
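+
+
+# A minimal usage sketch (not part of the original module): "path/to/dataset" is a placeholder for
+# a folder containing train/metadata.csv with `file_name` and `text` columns.
+def _example_dataloader():
+    dataset = TextImageDataset("path/to/dataset", steps_per_epoch=1000, height=512, width=512)
+    return torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)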
diff --git a/diffsynth/data/video.py b/diffsynth/data/video.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6b9daa41bea9d36e012d52a1d280d1cf8d92850
--- /dev/null
+++ b/diffsynth/data/video.py
@@ -0,0 +1,217 @@
+import imageio, os
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+import subprocess
+import shutil
+
+
+class LowMemoryVideo:
+ def __init__(self, file_name):
+ self.reader = imageio.get_reader(file_name)
+
+ def __len__(self):
+ return self.reader.count_frames()
+
+ def __getitem__(self, item):
+ return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
+
+ def __del__(self):
+ self.reader.close()
+
+
+def split_file_name(file_name):
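+    # Split a file name into alternating text/number tokens so that numbered frames sort in
+    # natural order, e.g. "frame2.png" before "frame10.png".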
+ result = []
+ number = -1
+ for i in file_name:
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
+ if number == -1:
+ number = 0
+ number = number*10 + ord(i) - ord("0")
+ else:
+ if number != -1:
+ result.append(number)
+ number = -1
+ result.append(i)
+ if number != -1:
+ result.append(number)
+ result = tuple(result)
+ return result
+
+
+def search_for_images(folder):
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
+ file_list = [i[1] for i in sorted(file_list)]
+ file_list = [os.path.join(folder, i) for i in file_list]
+ return file_list
+
+
+class LowMemoryImageFolder:
+ def __init__(self, folder, file_list=None):
+ if file_list is None:
+ self.file_list = search_for_images(folder)
+ else:
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
+
+ def __len__(self):
+ return len(self.file_list)
+
+ def __getitem__(self, item):
+ return Image.open(self.file_list[item]).convert("RGB")
+
+ def __del__(self):
+ pass
+
+
+def crop_and_resize(image, height, width):
+    # Center-crop the image to the target aspect ratio, then resize to (width, height).
+    image = np.array(image)
+    image_height, image_width, _ = image.shape
+    if image_height / image_width < height / width:
+        cropped_width = int(image_height / height * width)
+        left = (image_width - cropped_width) // 2
+        image = image[:, left: left + cropped_width]
+        image = Image.fromarray(image).resize((width, height))
+    else:
+        cropped_height = int(image_width / width * height)
+        top = (image_height - cropped_height) // 2
+        image = image[top: top + cropped_height, :]
+        image = Image.fromarray(image).resize((width, height))
+    return image
+
+
+class VideoData:
+ def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
+ if video_file is not None:
+ self.data_type = "video"
+ self.data = LowMemoryVideo(video_file, **kwargs)
+ elif image_folder is not None:
+ self.data_type = "images"
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
+ else:
+ raise ValueError("Cannot open video or image folder")
+ self.length = None
+ self.set_shape(height, width)
+
+ def raw_data(self):
+ frames = []
+ for i in range(self.__len__()):
+ frames.append(self.__getitem__(i))
+ return frames
+
+ def set_length(self, length):
+ self.length = length
+
+ def set_shape(self, height, width):
+ self.height = height
+ self.width = width
+
+ def __len__(self):
+ if self.length is None:
+ return len(self.data)
+ else:
+ return self.length
+
+    def shape(self):
+        if self.height is not None and self.width is not None:
+            return self.height, self.width
+        else:
+            # Frames are PIL images, so use .size (width, height) rather than .shape.
+            width, height = self.__getitem__(0).size
+            return height, width
+
+ def __getitem__(self, item):
+ frame = self.data.__getitem__(item)
+ width, height = frame.size
+ if self.height is not None and self.width is not None:
+ if self.height != height or self.width != width:
+ frame = crop_and_resize(frame, self.height, self.width)
+ return frame
+
+ def __del__(self):
+ pass
+
+ def save_images(self, folder):
+ os.makedirs(folder, exist_ok=True)
+ for i in tqdm(range(self.__len__()), desc="Saving images"):
+ frame = self.__getitem__(i)
+ frame.save(os.path.join(folder, f"{i}.png"))
+
+
+def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
+ writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
+ for frame in tqdm(frames, desc="Saving video"):
+ frame = np.array(frame)
+ writer.append_data(frame)
+ writer.close()
+
+def save_frames(frames, save_path):
+ os.makedirs(save_path, exist_ok=True)
+ for i, frame in enumerate(tqdm(frames, desc="Saving images")):
+ frame.save(os.path.join(save_path, f"{i}.png"))
+
+
+def merge_video_audio(video_path: str, audio_path: str):
+ # TODO: may need an in-Python implementation to avoid the subprocess dependency
+ """
+ Merge the video and audio into a new video, with the duration set to the shorter of the two,
+ and overwrite the original video file.
+
+ Parameters:
+ video_path (str): Path to the original video file
+ audio_path (str): Path to the audio file
+ """
+
+ # check
+ if not os.path.exists(video_path):
+ raise FileNotFoundError(f"video file {video_path} does not exist")
+ if not os.path.exists(audio_path):
+ raise FileNotFoundError(f"audio file {audio_path} does not exist")
+
+ base, ext = os.path.splitext(video_path)
+ temp_output = f"{base}_temp{ext}"
+
+ try:
+ # create ffmpeg command
+ command = [
+ 'ffmpeg',
+ '-y', # overwrite
+ '-i',
+ video_path,
+ '-i',
+ audio_path,
+ '-c:v',
+ 'copy', # copy video stream
+ '-c:a',
+ 'aac', # use AAC audio encoder
+ '-b:a',
+ '192k', # set audio bitrate (optional)
+ '-map',
+ '0:v:0', # select the first video stream
+ '-map',
+ '1:a:0', # select the first audio stream
+ '-shortest', # choose the shortest duration
+ temp_output
+ ]
+
+ # execute the command
+ result = subprocess.run(
+ command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+
+ # check result
+ if result.returncode != 0:
+ error_msg = f"FFmpeg execute failed: {result.stderr}"
+ print(error_msg)
+ raise RuntimeError(error_msg)
+
+ shutil.move(temp_output, video_path)
+ print(f"Merge completed, saved to {video_path}")
+
+ except Exception as e:
+ if os.path.exists(temp_output):
+ os.remove(temp_output)
+ print(f"merge_video_audio failed with error: {e}")
+
+
+def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None):
+ save_video(frames, save_path, fps, quality, ffmpeg_params)
+ merge_video_audio(save_path, audio_path)
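For reference, a minimal usage sketch of the helpers added in `diffsynth/data/video.py`; the file names are placeholders, and `ffmpeg` must be available on `PATH` for the audio-muxing step.

```python
from diffsynth.data.video import VideoData, save_video, save_video_with_audio

# Read a video lazily; each frame is center-cropped and resized to 512x512 on access.
video = VideoData(video_file="input.mp4", height=512, width=512)
frames = [video[i] for i in range(len(video))]   # list of PIL.Image frames

# Re-encode the frames, then mux an audio track with ffmpeg.
save_video(frames, "output.mp4", fps=16)
save_video_with_audio(frames, "output_with_audio.mp4", "voice.wav", fps=16)
```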
diff --git a/diffsynth/distributed/__init__.py b/diffsynth/distributed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..4887e2f16a87fa3a28e8057de41c6537be83f8ed
--- /dev/null
+++ b/diffsynth/distributed/xdit_context_parallel.py
@@ -0,0 +1,131 @@
+import torch
+from typing import Optional
+from einops import rearrange
+from xfuser.core.distributed import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size,
+ get_sp_group)
+from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+
+def sinusoidal_embedding_1d(dim, position):
+ sinusoid = torch.outer(position.type(torch.float64), torch.pow(
+ 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+ return x.to(position.dtype)
+
+def pad_freqs(original_tensor, target_len):
+ seq_len, s1, s2 = original_tensor.shape
+ pad_size = target_len - seq_len
+ padding_tensor = torch.ones(
+ pad_size,
+ s1,
+ s2,
+ dtype=original_tensor.dtype,
+ device=original_tensor.device)
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
+ return padded_tensor
+
+def rope_apply(x, freqs, num_heads):
+ x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
+ s_per_rank = x.shape[1]
+
+ x_out = torch.view_as_complex(x.to(torch.float64).reshape(
+ x.shape[0], x.shape[1], x.shape[2], -1, 2))
+
+ sp_size = get_sequence_parallel_world_size()
+ sp_rank = get_sequence_parallel_rank()
+ freqs = pad_freqs(freqs, s_per_rank * sp_size)
+ freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
+
+ x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
+ return x_out.to(x.dtype)
+
+def usp_dit_forward(self,
+ x: torch.Tensor,
+ timestep: torch.Tensor,
+ context: torch.Tensor,
+ clip_feature: Optional[torch.Tensor] = None,
+ y: Optional[torch.Tensor] = None,
+ use_gradient_checkpointing: bool = False,
+ use_gradient_checkpointing_offload: bool = False,
+ **kwargs,
+ ):
+ t = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, timestep))
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
+ context = self.text_embedding(context)
+
+ if self.has_image_input:
+ x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
+ clip_embedding = self.img_emb(clip_feature)
+ context = torch.cat([clip_embedding, context], dim=1)
+
+ x, (f, h, w) = self.patchify(x)
+
+ freqs = torch.cat([
+ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+
+ # Context Parallel
+ chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
+ pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
+ chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
+ x = chunks[get_sequence_parallel_rank()]
+
+ for block in self.blocks:
+ if self.training and use_gradient_checkpointing:
+ if use_gradient_checkpointing_offload:
+ with torch.autograd.graph.save_on_cpu():
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x, context, t_mod, freqs,
+ use_reentrant=False,
+ )
+ else:
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x, context, t_mod, freqs,
+ use_reentrant=False,
+ )
+ else:
+ x = block(x, context, t_mod, freqs)
+
+ x = self.head(x, t)
+
+ # Context Parallel
+ x = get_sp_group().all_gather(x, dim=1)
+ x = x[:, :-pad_shape] if pad_shape > 0 else x
+
+ # unpatchify
+ x = self.unpatchify(x, (f, h, w))
+ return x
+
+
+def usp_attn_forward(self, x, freqs):
+ q = self.norm_q(self.q(x))
+ k = self.norm_k(self.k(x))
+ v = self.v(x)
+
+ q = rope_apply(q, freqs, self.num_heads)
+ k = rope_apply(k, freqs, self.num_heads)
+ q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
+ k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
+ v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
+
+ x = xFuserLongContextAttention()(
+ None,
+ query=q,
+ key=k,
+ value=v,
+ )
+ x = x.flatten(2)
+
+ del q, k, v
+ torch.cuda.empty_cache()
+ return self.o(x)
\ No newline at end of file
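The context-parallel section of `usp_dit_forward` splits the token sequence across sequence-parallel ranks, pads every chunk to the length of the first chunk, and trims the padding after the all-gather. The following standalone sketch reproduces that pad/gather round trip with a plain `torch.cat` standing in for `get_sp_group().all_gather`, so it runs without xfuser or multiple GPUs.

```python
import torch

def split_pad_gather(x: torch.Tensor, world_size: int) -> torch.Tensor:
    # Split the sequence dimension; only the last chunk can be shorter.
    chunks = torch.chunk(x, world_size, dim=1)
    pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
    chunks = [torch.nn.functional.pad(c, (0, 0, 0, chunks[0].shape[1] - c.shape[1]), value=0) for c in chunks]
    # Each rank would now run its transformer blocks on chunks[rank]; here we
    # skip straight to the gather, with concatenation standing in for all_gather.
    x = torch.cat(chunks, dim=1)
    return x[:, :-pad_shape] if pad_shape > 0 else x

x = torch.randn(1, 10, 8)                       # (batch, tokens, dim), 10 tokens on 4 ranks
assert torch.equal(split_pad_gather(x, 4), x)   # the round trip preserves the sequence
```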
diff --git a/diffsynth/extensions/ESRGAN/__init__.py b/diffsynth/extensions/ESRGAN/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..94aff4c6fe8d75ff65e30d672dbe3e38a0d919c3
--- /dev/null
+++ b/diffsynth/extensions/ESRGAN/__init__.py
@@ -0,0 +1,137 @@
+import torch
+from einops import repeat
+from PIL import Image
+import numpy as np
+
+
+class ResidualDenseBlock(torch.nn.Module):
+
+ def __init__(self, num_feat=64, num_grow_ch=32):
+ super(ResidualDenseBlock, self).__init__()
+ self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+ self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ x1 = self.lrelu(self.conv1(x))
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ return x5 * 0.2 + x
+
+
+class RRDB(torch.nn.Module):
+
+ def __init__(self, num_feat, num_grow_ch=32):
+ super(RRDB, self).__init__()
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+ def forward(self, x):
+ out = self.rdb1(x)
+ out = self.rdb2(out)
+ out = self.rdb3(out)
+ return out * 0.2 + x
+
+
+class RRDBNet(torch.nn.Module):
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
+ super(RRDBNet, self).__init__()
+ self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
+ self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ # upsample
+ self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ feat = x
+ feat = self.conv_first(feat)
+ body_feat = self.conv_body(self.body(feat))
+ feat = feat + body_feat
+ # upsample
+ feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
+ feat = self.lrelu(self.conv_up1(feat))
+ feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
+ feat = self.lrelu(self.conv_up2(feat))
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+ return out
+
+ @staticmethod
+ def state_dict_converter():
+ return RRDBNetStateDictConverter()
+
+
+class RRDBNetStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ return state_dict, {"upcast_to_float32": True}
+
+ def from_civitai(self, state_dict):
+ return state_dict, {"upcast_to_float32": True}
+
+
+class ESRGAN(torch.nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.model = model
+
+ @staticmethod
+ def from_model_manager(model_manager):
+ return ESRGAN(model_manager.fetch_model("esrgan"))
+
+ def process_image(self, image):
+ image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
+ return image
+
+ def process_images(self, images):
+ images = [self.process_image(image) for image in images]
+ images = torch.stack(images)
+ return images
+
+ def decode_images(self, images):
+ images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
+ images = [Image.fromarray(image) for image in images]
+ return images
+
+ @torch.no_grad()
+ def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
+ if not isinstance(images, list):
+ images = [images]
+ is_single_image = True
+ else:
+ is_single_image = False
+
+ # Preprocess
+ input_tensor = self.process_images(images)
+
+ # Interpolate
+ output_tensor = []
+ for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
+ batch_input_tensor = batch_input_tensor.to(
+ device=self.model.conv_first.weight.device,
+ dtype=self.model.conv_first.weight.dtype)
+ batch_output_tensor = self.model(batch_input_tensor)
+ output_tensor.append(batch_output_tensor.cpu())
+
+ # Output
+ output_tensor = torch.concat(output_tensor, dim=0)
+
+ # To images
+ output_images = self.decode_images(output_tensor)
+ if is_single_image:
+ output_images = output_images[0]
+ return output_images
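A minimal sketch of the `ESRGAN` wrapper's interface; a randomly initialized `RRDBNet` is used here only to exercise the API, whereas real usage would load ESRGAN weights (for example through the model manager, as `from_model_manager` does).

```python
from PIL import Image
from diffsynth.extensions.ESRGAN import ESRGAN, RRDBNet

upscaler = ESRGAN(RRDBNet().eval())
image = Image.new("RGB", (64, 64), "gray")   # placeholder for a real low-resolution image
upscaled = upscaler.upscale(image)           # a single input returns a single PIL image
print(upscaled.size)                         # (256, 256): two 2x upsampling stages
```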
diff --git a/diffsynth/extensions/FastBlend/__init__.py b/diffsynth/extensions/FastBlend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bf812c2085082bfa82658dd249ebca89e9fb465
--- /dev/null
+++ b/diffsynth/extensions/FastBlend/__init__.py
@@ -0,0 +1,63 @@
+from .runners.fast import TableManager, PyramidPatchMatcher
+from PIL import Image
+import numpy as np
+import cupy as cp
+
+
+class FastBlendSmoother:
+ def __init__(self):
+ self.batch_size = 8
+ self.window_size = 64
+ self.ebsynth_config = {
+ "minimum_patch_size": 5,
+ "threads_per_block": 8,
+ "num_iter": 5,
+ "gpu_id": 0,
+ "guide_weight": 10.0,
+ "initialize": "identity",
+ "tracking_window_size": 0,
+ }
+
+ @staticmethod
+ def from_model_manager(model_manager):
+ # TODO: fetch GPU ID from model_manager
+ return FastBlendSmoother()
+
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
+ frames_guide = [np.array(frame) for frame in frames_guide]
+ frames_style = [np.array(frame) for frame in frames_style]
+ table_manager = TableManager()
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ **ebsynth_config
+ )
+ # left part
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
+ # right part
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
+ # merge
+ frames = []
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
+ weight_m = -1
+ weight = weight_l + weight_m + weight_r
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
+ frames.append(frame)
+ frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
+ return frames
+
+ def __call__(self, rendered_frames, original_frames=None, **kwargs):
+ frames = self.run(
+ original_frames, rendered_frames,
+ self.batch_size, self.window_size, self.ebsynth_config
+ )
+ mempool = cp.get_default_memory_pool()
+ pinned_mempool = cp.get_default_pinned_memory_pool()
+ mempool.free_all_blocks()
+ pinned_mempool.free_all_blocks()
+ return frames
\ No newline at end of file
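A minimal usage sketch of `FastBlendSmoother` (requires cupy and a CUDA GPU; the video paths are placeholders): it smooths a stylized frame sequence using the original frames as the motion guide. The two videos must share the same resolution and length.

```python
from diffsynth.data.video import VideoData
from diffsynth.extensions.FastBlend import FastBlendSmoother

guide_frames = VideoData(video_file="original.mp4").raw_data()   # motion guide
styled_frames = VideoData(video_file="stylized.mp4").raw_data()  # frames to be smoothed

smoother = FastBlendSmoother()
smoother.window_size = 32   # blend each frame with up to 32 neighbours on each side
smooth_frames = smoother(styled_frames, original_frames=guide_frames)
```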
diff --git a/diffsynth/extensions/FastBlend/api.py b/diffsynth/extensions/FastBlend/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..2db24330e375ed62065af54613b6ab956c9c64cf
--- /dev/null
+++ b/diffsynth/extensions/FastBlend/api.py
@@ -0,0 +1,397 @@
+from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
+from .data import VideoData, get_video_fps, save_video, search_for_images
+import os
+import gradio as gr
+
+
+def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
+ frames_guide = VideoData(video_guide, video_guide_folder)
+ frames_style = VideoData(video_style, video_style_folder)
+ message = ""
+ if len(frames_guide) < len(frames_style):
+ message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
+ frames_style.set_length(len(frames_guide))
+ elif len(frames_guide) > len(frames_style):
+ message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
+ frames_guide.set_length(len(frames_style))
+ height_guide, width_guide = frames_guide.shape()
+ height_style, width_style = frames_style.shape()
+ if height_guide != height_style or width_guide != width_style:
+ message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
+ frames_style.set_shape(height_guide, width_guide)
+ return frames_guide, frames_style, message
+
+
+def smooth_video(
+ video_guide,
+ video_guide_folder,
+ video_style,
+ video_style_folder,
+ mode,
+ window_size,
+ batch_size,
+ tracking_window_size,
+ output_path,
+ fps,
+ minimum_patch_size,
+ num_iter,
+ guide_weight,
+ initialize,
+ progress = None,
+):
+ # input
+ frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
+ if len(message) > 0:
+ print(message)
+ # output
+ if output_path == "":
+ if video_style is None:
+ output_path = os.path.join(video_style_folder, "output")
+ else:
+ output_path = os.path.join(os.path.split(video_style)[0], "output")
+ os.makedirs(output_path, exist_ok=True)
+ print("No valid output_path. Your video will be saved here:", output_path)
+ elif not os.path.exists(output_path):
+ os.makedirs(output_path, exist_ok=True)
+ print("Your video will be saved here:", output_path)
+ frames_path = os.path.join(output_path, "frames")
+ video_path = os.path.join(output_path, "video.mp4")
+ os.makedirs(frames_path, exist_ok=True)
+ # process
+ if mode == "Fast" or mode == "Balanced":
+ tracking_window_size = 0
+ ebsynth_config = {
+ "minimum_patch_size": minimum_patch_size,
+ "threads_per_block": 8,
+ "num_iter": num_iter,
+ "gpu_id": 0,
+ "guide_weight": guide_weight,
+ "initialize": initialize,
+ "tracking_window_size": tracking_window_size,
+ }
+ if mode == "Fast":
+ FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
+ elif mode == "Balanced":
+ BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
+ elif mode == "Accurate":
+ AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
+ # output
+ try:
+ fps = int(fps)
+ except (TypeError, ValueError):
+ fps = get_video_fps(video_style) if video_style is not None else 30
+ print("Fps:", fps)
+ print("Saving video...")
+ video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
+ print("Success!")
+ print("Your frames are here:", frames_path)
+ print("Your video is here:", video_path)
+ return output_path, fps, video_path
+
+
+class KeyFrameMatcher:
+ def __init__(self):
+ pass
+
+ def extract_number_from_filename(self, file_name):
+ result = []
+ number = -1
+ for i in file_name:
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
+ if number == -1:
+ number = 0
+ number = number*10 + ord(i) - ord("0")
+ else:
+ if number != -1:
+ result.append(number)
+ number = -1
+ if number != -1:
+ result.append(number)
+ result = tuple(result)
+ return result
+
+ def extract_number_from_filenames(self, file_names):
+ numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
+ min_length = min(len(i) for i in numbers)
+ for i in range(min_length-1, -1, -1):
+ if len(set(number[i] for number in numbers))==len(file_names):
+ return [number[i] for number in numbers]
+ return list(range(len(file_names)))
+
+ def match_using_filename(self, file_names_a, file_names_b):
+ file_names_b_set = set(file_names_b)
+ matched_file_name = []
+ for file_name in file_names_a:
+ if file_name not in file_names_b_set:
+ matched_file_name.append(None)
+ else:
+ matched_file_name.append(file_name)
+ return matched_file_name
+
+ def match_using_numbers(self, file_names_a, file_names_b):
+ numbers_a = self.extract_number_from_filenames(file_names_a)
+ numbers_b = self.extract_number_from_filenames(file_names_b)
+ numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
+ matched_file_name = []
+ for number in numbers_a:
+ if number in numbers_b_dict:
+ matched_file_name.append(numbers_b_dict[number])
+ else:
+ matched_file_name.append(None)
+ return matched_file_name
+
+ def match_filenames(self, file_names_a, file_names_b):
+ matched_file_name = self.match_using_filename(file_names_a, file_names_b)
+ if sum([i is not None for i in matched_file_name]) > 0:
+ return matched_file_name
+ matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
+ return matched_file_name
+
+
+def detect_frames(frames_path, keyframes_path):
+ if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
+ return "Please input the directory of guide video and rendered frames"
+ elif not os.path.exists(frames_path):
+ return "Please input the directory of guide video"
+ elif not os.path.exists(keyframes_path):
+ return "Please input the directory of rendered frames"
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
+ if len(frames)==0:
+ return f"No images detected in {frames_path}"
+ if len(keyframes)==0:
+ return f"No images detected in {keyframes_path}"
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
+ max_filename_length = max([len(i) for i in frames])
+ if sum([i is not None for i in matched_keyframes])==0:
+ message = ""
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
+ message += frame + " " * (max_filename_length - len(frame) + 1)
+ message += "--> No matched keyframes\n"
+ else:
+ message = ""
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
+ message += frame + " " * (max_filename_length - len(frame) + 1)
+ if matched_keyframe is None:
+ message += "--> [to be rendered]\n"
+ else:
+ message += f"--> {matched_keyframe}\n"
+ return message
+
+
+def check_input_for_interpolating(frames_path, keyframes_path):
+ # search for images
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
+ # match frames
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
+ file_list = [file_name for file_name in matched_keyframes if file_name is not None]
+ index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
+ frames_guide = VideoData(None, frames_path)
+ frames_style = VideoData(None, keyframes_path, file_list=file_list)
+ # match shape
+ message = ""
+ height_guide, width_guide = frames_guide.shape()
+ height_style, width_style = frames_style.shape()
+ if height_guide != height_style or width_guide != width_style:
+ message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
+ frames_style.set_shape(height_guide, width_guide)
+ return frames_guide, frames_style, index_style, message
+
+
+def interpolate_video(
+ frames_path,
+ keyframes_path,
+ output_path,
+ fps,
+ batch_size,
+ tracking_window_size,
+ minimum_patch_size,
+ num_iter,
+ guide_weight,
+ initialize,
+ progress = None,
+):
+ # input
+ frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
+ if len(message) > 0:
+ print(message)
+ # output
+ if output_path == "":
+ output_path = os.path.join(keyframes_path, "output")
+ os.makedirs(output_path, exist_ok=True)
+ print("No valid output_path. Your video will be saved here:", output_path)
+ elif not os.path.exists(output_path):
+ os.makedirs(output_path, exist_ok=True)
+ print("Your video will be saved here:", output_path)
+ output_frames_path = os.path.join(output_path, "frames")
+ output_video_path = os.path.join(output_path, "video.mp4")
+ os.makedirs(output_frames_path, exist_ok=True)
+ # process
+ ebsynth_config = {
+ "minimum_patch_size": minimum_patch_size,
+ "threads_per_block": 8,
+ "num_iter": num_iter,
+ "gpu_id": 0,
+ "guide_weight": guide_weight,
+ "initialize": initialize,
+ "tracking_window_size": tracking_window_size
+ }
+ if len(index_style)==1:
+ InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
+ else:
+ InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
+ try:
+ fps = int(fps)
+ except (TypeError, ValueError):
+ fps = 30
+ print("Fps:", fps)
+ print("Saving video...")
+ video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
+ print("Success!")
+ print("Your frames are here:", output_frames_path)
+ print("Your video is here:", video_path)
+ return output_path, fps, video_path
+
+
+def on_ui_tabs():
+ with gr.Blocks(analytics_enabled=False) as ui_component:
+ with gr.Tab("Blend"):
+ gr.Markdown("""
+# Blend
+
+Given a guide video and a style video, this algorithm makes the style video temporally consistent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see an example. Note that this extension doesn't support long videos; please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution; please use a larger `Minimum patch size` for higher resolutions.
+ """)
+ with gr.Row():
+ with gr.Column():
+ with gr.Tab("Guide video"):
+ video_guide = gr.Video(label="Guide video")
+ with gr.Tab("Guide video (images format)"):
+ video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
+ with gr.Column():
+ with gr.Tab("Style video"):
+ video_style = gr.Video(label="Style video")
+ with gr.Tab("Style video (images format)"):
+ video_style_folder = gr.Textbox(label="Style video (images format)", value="")
+ with gr.Column():
+ output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
+ fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
+ video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
+ btn = gr.Button(value="Blend")
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("# Settings")
+ mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
+ window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
+ batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
+ tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
+ gr.Markdown("## Advanced Settings")
+ minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
+ num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
+ guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
+ initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
+ with gr.Column():
+ gr.Markdown("""
+# Reference
+
+* Output directory: the directory to save the video.
+* Inference mode
+
+|Mode|Time|Memory|Quality|Frame by frame output|Description|
+|-|-|-|-|-|-|
+|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires a large amount of RAM but is fast.|
+|Balanced|■■|■|■■|Yes|Blend the frames naively.|
+|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is best.|
+
+* Sliding window size: our algorithm blends the frames in a sliding window. If the size is n, each frame is blended with the previous n frames and the next n frames. A large sliding window can make the video fluent but sometimes blurry.
+* Batch size: a larger batch size makes the program faster but requires more VRAM.
+* Tracking window size (only for accurate mode): the size of the window in which our algorithm tracks moving objects. Empirically, 1 is enough.
+* Advanced settings
+ * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
+ * Number of iterations: the number of iterations of patch matching. (Default: 5)
+ * Guide weight: a parameter that determines how strongly the motion features of the guide video are applied to the style video. (Default: 10)
+ * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
+ """)
+ btn.click(
+ smooth_video,
+ inputs=[
+ video_guide,
+ video_guide_folder,
+ video_style,
+ video_style_folder,
+ mode,
+ window_size,
+ batch_size,
+ tracking_window_size,
+ output_path,
+ fps,
+ minimum_patch_size,
+ num_iter,
+ guide_weight,
+ initialize
+ ],
+ outputs=[output_path, fps, video_output]
+ )
+ with gr.Tab("Interpolate"):
+ gr.Markdown("""
+# Interpolate
+
+Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see an example. The algorithm is experimental and has only been tested at 512*512 resolution.
+ """)
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ with gr.Column():
+ video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
+ with gr.Column():
+ rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
+ with gr.Row():
+ detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
+ video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
+ rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
+ with gr.Column():
+ output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
+ fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
+ video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
+ btn_ = gr.Button(value="Interpolate")
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("# Settings")
+ batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
+ tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
+ gr.Markdown("## Advanced Settings")
+ minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
+ num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
+ guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
+ initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
+ with gr.Column():
+ gr.Markdown("""
+# Reference
+
+* Output directory: the directory to save the video.
+* Batch size: a larger batch size makes the program faster but requires more VRAM.
+* Tracking window size: the size of the window in which our algorithm tracks moving objects. Empirically, 1 is enough.
+* Advanced settings
+ * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
+ * Number of iterations: the number of iterations of patch matching. (Default: 5)
+ * Guide weight: a parameter that determines how strongly the motion features of the guide video are applied to the style video. (Default: 10)
+ * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
+ """)
+ btn_.click(
+ interpolate_video,
+ inputs=[
+ video_guide_folder_,
+ rendered_keyframes_,
+ output_path_,
+ fps_,
+ batch_size_,
+ tracking_window_size_,
+ minimum_patch_size_,
+ num_iter_,
+ guide_weight_,
+ initialize_,
+ ],
+ outputs=[output_path_, fps_, video_output_]
+ )
+
+ return [(ui_component, "FastBlend", "FastBlend_ui")]
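The Gradio UI defined by `on_ui_tabs` follows the sd-webui extension convention (it returns `(component, title, elem_id)` tuples), but the returned `gr.Blocks` can also be launched on its own; a minimal sketch:

```python
from diffsynth.extensions.FastBlend.api import on_ui_tabs

ui_component, title, elem_id = on_ui_tabs()[0]
ui_component.launch()   # serves the FastBlend Blend/Interpolate tabs locally
```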
diff --git a/diffsynth/extensions/FastBlend/cupy_kernels.py b/diffsynth/extensions/FastBlend/cupy_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..70e2790a2c67a2dd537f4188b38ebfc785f1fb34
--- /dev/null
+++ b/diffsynth/extensions/FastBlend/cupy_kernels.py
@@ -0,0 +1,119 @@
+import cupy as cp
+
+remapping_kernel = cp.RawKernel(r'''
+extern "C" __global__
+void remap(
+ const int height,
+ const int width,
+ const int channel,
+ const int patch_size,
+ const int pad_size,
+ const float* source_style,
+ const int* nnf,
+ float* target_style
+) {
+ const int r = (patch_size - 1) / 2;
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
+ if (x >= height or y >= width) return;
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
+ const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
+ const int min_px = x < r ? -x : -r;
+ const int max_px = x + r > height - 1 ? height - 1 - x : r;
+ const int min_py = y < r ? -y : -r;
+ const int max_py = y + r > width - 1 ? width - 1 - y : r;
+ int num = 0;
+ for (int px = min_px; px <= max_px; px++){
+ for (int py = min_py; py <= max_py; py++){
+ const int nid = (x + px) * width + y + py;
+ const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
+ const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
+ if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
+ const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
+ num++;
+ for (int c = 0; c < channel; c++){
+ target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
+ }
+ }
+ }
+ for (int c = 0; c < channel; c++){
+ target_style[z + pid * channel + c] /= num;
+ }
+}
+''', 'remap')
+
+
+patch_error_kernel = cp.RawKernel(r'''
+extern "C" __global__
+void patch_error(
+ const int height,
+ const int width,
+ const int channel,
+ const int patch_size,
+ const int pad_size,
+ const float* source,
+ const int* nnf,
+ const float* target,
+ float* error
+) {
+ const int r = (patch_size - 1) / 2;
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
+ if (x >= height or y >= width) return;
+ const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
+ const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
+ float e = 0;
+ for (int px = -r; px <= r; px++){
+ for (int py = -r; py <= r; py++){
+ const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
+ const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
+ for (int c = 0; c < channel; c++){
+ const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
+ e += diff * diff;
+ }
+ }
+ }
+ error[blockIdx.z * height * width + x * width + y] = e;
+}
+''', 'patch_error')
+
+
+pairwise_patch_error_kernel = cp.RawKernel(r'''
+extern "C" __global__
+void pairwise_patch_error(
+ const int height,
+ const int width,
+ const int channel,
+ const int patch_size,
+ const int pad_size,
+ const float* source_a,
+ const int* nnf_a,
+ const float* source_b,
+ const int* nnf_b,
+ float* error
+) {
+ const int r = (patch_size - 1) / 2;
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
+ if (x >= height or y >= width) return;
+ const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
+ const int x_a = nnf_a[z_nnf + 0];
+ const int y_a = nnf_a[z_nnf + 1];
+ const int x_b = nnf_b[z_nnf + 0];
+ const int y_b = nnf_b[z_nnf + 1];
+ float e = 0;
+ for (int px = -r; px <= r; px++){
+ for (int py = -r; py <= r; py++){
+ const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
+ const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
+ for (int c = 0; c < channel; c++){
+ const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
+ e += diff * diff;
+ }
+ }
+ }
+ error[blockIdx.z * height * width + x * width + y] = e;
+}
+''', 'pairwise_patch_error')
diff --git a/diffsynth/extensions/FastBlend/data.py b/diffsynth/extensions/FastBlend/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcaddd77de9eaf208cd083dd522e5eaa6b58f783
--- /dev/null
+++ b/diffsynth/extensions/FastBlend/data.py
@@ -0,0 +1,146 @@
+import imageio, os
+import numpy as np
+from PIL import Image
+
+
+def read_video(file_name):
+ reader = imageio.get_reader(file_name)
+ video = []
+ for frame in reader:
+ frame = np.array(frame)
+ video.append(frame)
+ reader.close()
+ return video
+
+
+def get_video_fps(file_name):
+ reader = imageio.get_reader(file_name)
+ fps = reader.get_meta_data()["fps"]
+ reader.close()
+ return fps
+
+
+def save_video(frames_path, video_path, num_frames, fps):
+ writer = imageio.get_writer(video_path, fps=fps, quality=9)
+ for i in range(num_frames):
+ frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
+ writer.append_data(frame)
+ writer.close()
+ return video_path
+
+
+class LowMemoryVideo:
+ def __init__(self, file_name):
+ self.reader = imageio.get_reader(file_name)
+
+ def __len__(self):
+ return self.reader.count_frames()
+
+ def __getitem__(self, item):
+ return np.array(self.reader.get_data(item))
+
+ def __del__(self):
+ self.reader.close()
+
+
+def split_file_name(file_name):
+ result = []
+ number = -1
+ for i in file_name:
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
+ if number == -1:
+ number = 0
+ number = number*10 + ord(i) - ord("0")
+ else:
+ if number != -1:
+ result.append(number)
+ number = -1
+ result.append(i)
+ if number != -1:
+ result.append(number)
+ result = tuple(result)
+ return result
+
+
+def search_for_images(folder):
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
+ file_list = [i[1] for i in sorted(file_list)]
+ file_list = [os.path.join(folder, i) for i in file_list]
+ return file_list
+
+
+def read_images(folder):
+ file_list = search_for_images(folder)
+ frames = [np.array(Image.open(i)) for i in file_list]
+ return frames
+
+
+class LowMemoryImageFolder:
+ def __init__(self, folder, file_list=None):
+ if file_list is None:
+ self.file_list = search_for_images(folder)
+ else:
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
+
+ def __len__(self):
+ return len(self.file_list)
+
+ def __getitem__(self, item):
+ return np.array(Image.open(self.file_list[item]))
+
+ def __del__(self):
+ pass
+
+
+class VideoData:
+ def __init__(self, video_file, image_folder, **kwargs):
+ if video_file is not None:
+ self.data_type = "video"
+ self.data = LowMemoryVideo(video_file, **kwargs)
+ elif image_folder is not None:
+ self.data_type = "images"
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
+ else:
+ raise ValueError("Cannot open video or image folder")
+ self.length = None
+ self.height = None
+ self.width = None
+
+ def raw_data(self):
+ frames = []
+ for i in range(self.__len__()):
+ frames.append(self.__getitem__(i))
+ return frames
+
+ def set_length(self, length):
+ self.length = length
+
+ def set_shape(self, height, width):
+ self.height = height
+ self.width = width
+
+ def __len__(self):
+ if self.length is None:
+ return len(self.data)
+ else:
+ return self.length
+
+ def shape(self):
+ if self.height is not None and self.width is not None:
+ return self.height, self.width
+ else:
+ height, width, _ = self.__getitem__(0).shape
+ return height, width
+
+ def __getitem__(self, item):
+ frame = self.data.__getitem__(item)
+ height, width, _ = frame.shape
+ if self.height is not None and self.width is not None:
+ if self.height != height or self.width != width:
+ frame = Image.fromarray(frame).resize((self.width, self.height))
+ frame = np.array(frame)
+ return frame
+
+ def __del__(self):
+ pass
diff --git a/diffsynth/extensions/FastBlend/patch_match.py b/diffsynth/extensions/FastBlend/patch_match.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ba60036af43aa89e4811223af8f9f63b87ab6e1
--- /dev/null
+++ b/diffsynth/extensions/FastBlend/patch_match.py
@@ -0,0 +1,299 @@
+from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
+import numpy as np
+import cupy as cp
+import cv2
+import torch
+import torch.nn.functional as F
+
+class PatchMatcher:
+ def __init__(
+ self, height, width, channel, minimum_patch_size,
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
+ random_search_steps=3, random_search_range=4,
+ use_mean_target_style=False, use_pairwise_patch_error=False,
+ tracking_window_size=0
+ ):
+ self.height = height
+ self.width = width
+ self.channel = channel
+ self.minimum_patch_size = minimum_patch_size
+ self.threads_per_block = threads_per_block
+ self.num_iter = num_iter
+ self.gpu_id = gpu_id
+ self.guide_weight = guide_weight
+ self.random_search_steps = random_search_steps
+ self.random_search_range = random_search_range
+ self.use_mean_target_style = use_mean_target_style
+ self.use_pairwise_patch_error = use_pairwise_patch_error
+ self.tracking_window_size = tracking_window_size
+
+ self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
+ self.pad_size = self.patch_size_list[0] // 2
+ self.grid = (
+ (height + threads_per_block - 1) // threads_per_block,
+ (width + threads_per_block - 1) // threads_per_block
+ )
+ self.block = (threads_per_block, threads_per_block)
+
+ def pad_image(self, image):
+ return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
+
+ def unpad_image(self, image):
+ return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
+
+ def apply_nnf_to_image(self, nnf, source):
+ batch_size = source.shape[0]
+ target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
+ remapping_kernel(
+ self.grid + (batch_size,),
+ self.block,
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
+ )
+ return target
+
+ def get_patch_error(self, source, nnf, target):
+ batch_size = source.shape[0]
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
+ patch_error_kernel(
+ self.grid + (batch_size,),
+ self.block,
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
+ )
+ return error
+
+ def get_pairwise_patch_error(self, source, nnf):
+ batch_size = source.shape[0]//2
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
+ source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
+ source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
+ pairwise_patch_error_kernel(
+ self.grid + (batch_size,),
+ self.block,
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
+ )
+ error = error.repeat(2, axis=0)
+ return error
+
+ def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
+ error_guide = self.get_patch_error(source_guide, nnf, target_guide)
+ if self.use_mean_target_style:
+ target_style = self.apply_nnf_to_image(nnf, source_style)
+ target_style = target_style.mean(axis=0, keepdims=True)
+ target_style = target_style.repeat(source_guide.shape[0], axis=0)
+ if self.use_pairwise_patch_error:
+ error_style = self.get_pairwise_patch_error(source_style, nnf)
+ else:
+ error_style = self.get_patch_error(source_style, nnf, target_style)
+ error = error_guide * self.guide_weight + error_style
+ return error
+
+ def clamp_bound(self, nnf):
+ nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
+ nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
+ return nnf
+
+ def random_step(self, nnf, r):
+ batch_size = nnf.shape[0]
+ step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
+ upd_nnf = self.clamp_bound(nnf + step)
+ return upd_nnf
+
+ def neighbor_step(self, nnf, d):
+ if d==0:
+ upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
+ upd_nnf[:, :, :, 0] += 1
+ elif d==1:
+ upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
+ upd_nnf[:, :, :, 1] += 1
+ elif d==2:
+ upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
+ upd_nnf[:, :, :, 0] -= 1
+ elif d==3:
+ upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
+ upd_nnf[:, :, :, 1] -= 1
+ upd_nnf = self.clamp_bound(upd_nnf)
+ return upd_nnf
+
+ def shift_nnf(self, nnf, d):
+ if d>0:
+ d = min(nnf.shape[0], d)
+ upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
+ else:
+ d = max(-nnf.shape[0], d)
+ upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
+ return upd_nnf
+
+ def track_step(self, nnf, d):
+ if self.use_pairwise_patch_error:
+ upd_nnf = cp.zeros_like(nnf)
+ upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
+ upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
+ else:
+ upd_nnf = self.shift_nnf(nnf, d)
+ return upd_nnf
+
+ def C(self, n, m):
+ # not used
+ c = 1
+ for i in range(1, n+1):
+ c *= i
+ for i in range(1, m+1):
+ c //= i
+ for i in range(1, n-m+1):
+ c //= i
+ return c
+
+ def bezier_step(self, nnf, r):
+ # not used
+ n = r * 2 - 1
+ upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
+ for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
+ if d>0:
+ ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
+ elif d<0:
+ ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
+ upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
+ upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
+ return upd_nnf
+
+ def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
+ upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
+ upd_idx = (upd_err < err)
+ nnf[upd_idx] = upd_nnf[upd_idx]
+ err[upd_idx] = upd_err[upd_idx]
+ return nnf, err
+
+ def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
+ for d in cp.random.permutation(4):
+ upd_nnf = self.neighbor_step(nnf, d)
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
+ return nnf, err
+
+ def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
+ for i in range(self.random_search_steps):
+ upd_nnf = self.random_step(nnf, self.random_search_range)
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
+ return nnf, err
+
+ def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
+ for d in range(1, self.tracking_window_size + 1):
+ upd_nnf = self.track_step(nnf, d)
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
+ upd_nnf = self.track_step(nnf, -d)
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
+ return nnf, err
+
+ def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
+ nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
+ nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
+ nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
+ return nnf, err
+
+ def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
+ with cp.cuda.Device(self.gpu_id):
+ source_guide = self.pad_image(source_guide)
+ target_guide = self.pad_image(target_guide)
+ source_style = self.pad_image(source_style)
+ for it in range(self.num_iter):
+ self.patch_size = self.patch_size_list[it]
+ target_style = self.apply_nnf_to_image(nnf, source_style)
+ err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
+ nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
+ target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
+ return nnf, target_style
+
+
+class PyramidPatchMatcher:
+ def __init__(
+ self, image_height, image_width, channel, minimum_patch_size,
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
+ use_mean_target_style=False, use_pairwise_patch_error=False,
+ tracking_window_size=0,
+ initialize="identity"
+ ):
+ maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
+ self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
+ self.pyramid_heights = []
+ self.pyramid_widths = []
+ self.patch_matchers = []
+ self.minimum_patch_size = minimum_patch_size
+ self.num_iter = num_iter
+ self.gpu_id = gpu_id
+ self.initialize = initialize
+ for level in range(self.pyramid_level):
+ height = image_height//(2**(self.pyramid_level - 1 - level))
+ width = image_width//(2**(self.pyramid_level - 1 - level))
+ self.pyramid_heights.append(height)
+ self.pyramid_widths.append(width)
+ self.patch_matchers.append(PatchMatcher(
+ height, width, channel, minimum_patch_size=minimum_patch_size,
+ threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
+ use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
+ tracking_window_size=tracking_window_size
+ ))
+
+ def resample_image(self, images, level):
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
+ images_torch = torch.as_tensor(images, device='cuda', dtype=torch.float32)
+ images_torch = images_torch.permute(0, 3, 1, 2)
+ images_resample = F.interpolate(images_torch, size=(height, width), mode='area', align_corners=None)
+ images_resample = images_resample.permute(0, 2, 3, 1).contiguous()
+ return cp.asarray(images_resample)
+
+ def initialize_nnf(self, batch_size):
+ if self.initialize == "random":
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
+ nnf = cp.stack([
+ cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
+ cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
+ ], axis=3)
+ elif self.initialize == "identity":
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
+ nnf = cp.stack([
+ cp.repeat(cp.arange(height), width).reshape(height, width),
+ cp.tile(cp.arange(width), height).reshape(height, width)
+ ], axis=2)
+ nnf = cp.stack([nnf] * batch_size)
+ else:
+ raise NotImplementedError()
+ return nnf
+
+ def update_nnf(self, nnf, level):
+ # upscale
+ nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
+ nnf[:, 1::2, :, 0] += 1
+ nnf[:, :, 1::2, 1] += 1
+ # check if scale is 2
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
+ if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
+ nnf_torch = torch.as_tensor(nnf, device='cuda', dtype=torch.float32)
+ nnf_torch = nnf_torch.permute(0, 3, 1, 2)
+ nnf_resized = F.interpolate(nnf_torch, size=(height, width), mode='bilinear', align_corners=False)
+ nnf_resized = nnf_resized.permute(0, 2, 3, 1)
+ nnf = cp.asarray(nnf_resized).astype(cp.int32)
+ nnf = self.patch_matchers[level].clamp_bound(nnf)
+ return nnf
+
+ def apply_nnf_to_image(self, nnf, image):
+ with cp.cuda.Device(self.gpu_id):
+ image = self.patch_matchers[-1].pad_image(image)
+ image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
+ return image
+
+ def estimate_nnf(self, source_guide, target_guide, source_style):
+ with cp.cuda.Device(self.gpu_id):
+ if not isinstance(source_guide, cp.ndarray):
+ source_guide = cp.array(source_guide, dtype=cp.float32)
+ if not isinstance(target_guide, cp.ndarray):
+ target_guide = cp.array(target_guide, dtype=cp.float32)
+ if not isinstance(source_style, cp.ndarray):
+ source_style = cp.array(source_style, dtype=cp.float32)
+ for level in range(self.pyramid_level):
+ nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
+ source_guide_ = self.resample_image(source_guide, level)
+ target_guide_ = self.resample_image(target_guide, level)
+ source_style_ = self.resample_image(source_style, level)
+ nnf, target_style = self.patch_matchers[level].estimate_nnf(
+ source_guide_, target_guide_, source_style_, nnf
+ )
+ return nnf.get(), target_style.get()
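A minimal sketch of how the runners drive `PyramidPatchMatcher`: given batched guide/style frames as `(batch, height, width, 3)` float arrays, `estimate_nnf` returns the nearest-neighbour field and the style frames remapped onto the target guide. It requires cupy and a CUDA GPU; the random arrays below are placeholders for real frames.

```python
import numpy as np
from diffsynth.extensions.FastBlend.patch_match import PyramidPatchMatcher

height, width = 512, 512
source_guide = np.random.randint(0, 256, (1, height, width, 3)).astype(np.float32)
target_guide = np.random.randint(0, 256, (1, height, width, 3)).astype(np.float32)
source_style = np.random.randint(0, 256, (1, height, width, 3)).astype(np.float32)

matcher = PyramidPatchMatcher(image_height=height, image_width=width, channel=3, minimum_patch_size=5)
nnf, target_style = matcher.estimate_nnf(source_guide, target_guide, source_style)
print(nnf.shape, target_style.shape)   # (1, 512, 512, 2) (1, 512, 512, 3)
```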
diff --git a/diffsynth/extensions/FastBlend/runners/__init__.py b/diffsynth/extensions/FastBlend/runners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..078382729690d282436411661693ce22f3dcc033
--- /dev/null
+++ b/diffsynth/extensions/FastBlend/runners/__init__.py
@@ -0,0 +1,4 @@
+from .accurate import AccurateModeRunner
+from .fast import FastModeRunner
+from .balanced import BalancedModeRunner
+from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
diff --git a/diffsynth/extensions/FastBlend/runners/accurate.py b/diffsynth/extensions/FastBlend/runners/accurate.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e4a47f1981ebc1ec9a034a814dfc1130955c2e1
--- /dev/null
+++ b/diffsynth/extensions/FastBlend/runners/accurate.py
@@ -0,0 +1,35 @@
+from ..patch_match import PyramidPatchMatcher
+import os
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class AccurateModeRunner:
+ def __init__(self):
+ pass
+
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ use_mean_target_style=True,
+ **ebsynth_config
+ )
+ # run
+ n = len(frames_style)
+ for target in tqdm(range(n), desc=desc):
+ l, r = max(target - window_size, 0), min(target + window_size + 1, n)
+ remapped_frames = []
+ for i in range(l, r, batch_size):
+ j = min(i + batch_size, r)
+ source_guide = np.stack([frames_guide[source] for source in range(i, j)])
+ target_guide = np.stack([frames_guide[target]] * (j - i))
+ source_style = np.stack([frames_style[source] for source in range(i, j)])
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ remapped_frames.append(target_style)
+ frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
+ frame = frame.clip(0, 255).astype("uint8")
+ if save_path is not None:
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
\ No newline at end of file
diff --git a/diffsynth/extensions/FastBlend/runners/balanced.py b/diffsynth/extensions/FastBlend/runners/balanced.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c9a2bb7e438b49c89d0786e858ccf03302fab35
--- /dev/null
+++ b/diffsynth/extensions/FastBlend/runners/balanced.py
@@ -0,0 +1,46 @@
+from ..patch_match import PyramidPatchMatcher
+import os
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class BalancedModeRunner:
+ def __init__(self):
+ pass
+
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ **ebsynth_config
+ )
+ # tasks
+ n = len(frames_style)
+ tasks = []
+ for target in range(n):
+ for source in range(target - window_size, target + window_size + 1):
+ if source >= 0 and source < n and source != target:
+ tasks.append((source, target))
+ # run
+ frames = [(None, 1) for i in range(n)]
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
+ source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
+ target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
+ source_style = np.stack([frames_style[source] for source, target in tasks_batch])
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ for (source, target), result in zip(tasks_batch, target_style):
+ frame, weight = frames[target]
+ if frame is None:
+ frame = frames_style[target]
+ frames[target] = (
+ frame * (weight / (weight + 1)) + result / (weight + 1),
+ weight + 1
+ )
+ if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
+ frame = frame.clip(0, 255).astype("uint8")
+ if save_path is not None:
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
+ frames[target] = (None, 1)
diff --git a/diffsynth/extensions/FastBlend/runners/fast.py b/diffsynth/extensions/FastBlend/runners/fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ba5731475ab875929b14181e0c22f4fd466c591
--- /dev/null
+++ b/diffsynth/extensions/FastBlend/runners/fast.py
@@ -0,0 +1,141 @@
+from ..patch_match import PyramidPatchMatcher
+import functools, os
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class TableManager:
+ def __init__(self):
+ pass
+
+ def task_list(self, n):
+ tasks = []
+ max_level = 1
+ while (1<<max_level)<=n:
+ max_level += 1
+ # for each source frame i, enumerate remapping targets j by filling in the zero bits of i from low to high
+ for i in range(n):
+ j = i
+ for level in range(max_level):
+ if i&(1<<level):
+ continue
+ j |= 1<<level
+ if j>=n:
+ break
+ meta_data = {
+ "source": i,
+ "target": j,
+ "level": level + 1
+ }
+ tasks.append(meta_data)
+ tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
+ return tasks
+
+ def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
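+ # remapping_table[i] holds one (frame, weight) entry per level: level 0 is the original stylized
+ # frame i; higher levels average the frames remapped onto i from the sources chosen by task_list.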
+ n = len(frames_guide)
+ tasks = self.task_list(n)
+ remapping_table = [[(frames_style[i], 1)] for i in range(n)]
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
+ source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ for task, result in zip(tasks_batch, target_style):
+ target, level = task["target"], task["level"]
+ if len(remapping_table[target])==level:
+ remapping_table[target].append((result, 1))
+ else:
+ frame, weight = remapping_table[target][level]
+ remapping_table[target][level] = (
+ frame * (weight / (weight + 1)) + result / (weight + 1),
+ weight + 1
+ )
+ return remapping_table
+
+ def remapping_table_to_blending_table(self, table):
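+ # Turn per-level entries into cumulative blends: level j becomes the average of levels j-1 and j
+ # with their weights summed, so deeper levels aggregate more remapped frames.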
+ for i in range(len(table)):
+ for j in range(1, len(table[i])):
+ frame_1, weight_1 = table[i][j-1]
+ frame_2, weight_2 = table[i][j]
+ frame = (frame_1 + frame_2) / 2
+ weight = weight_1 + weight_2
+ table[i][j] = (frame, weight)
+ return table
+
+ def tree_query(self, leftbound, rightbound):
+ node_list = []
+ node_index = rightbound
+ while node_index>=leftbound:
+ node_level = 0
+ # take the largest node ending at node_index whose covered range stays within leftbound
+ while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
+ node_level += 1
+ node_list.append((node_index, node_level))
+ node_index -= 1<<node_level
+ return node_list
+ # first frame
+ if index_style[0]>0:
+ tasks = []
+ for m in range(index_style[0]):
+ tasks.append((index_style[0], m, index_style[0]))
+ task_group.append(tasks)
+ # middle frames
+ for l, r in zip(index_style[:-1], index_style[1:]):
+ tasks = []
+ for m in range(l, r):
+ tasks.append((l, m, r))
+ task_group.append(tasks)
+ # last frame
+ tasks = []
+ for m in range(index_style[-1], n):
+ tasks.append((index_style[-1], m, index_style[-1]))
+ task_group.append(tasks)
+ return task_group
+
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ use_mean_target_style=False,
+ use_pairwise_patch_error=True,
+ **ebsynth_config
+ )
+ # task
+ index_dict = self.get_index_dict(index_style)
+ task_group = self.get_task_group(index_style, len(frames_guide))
+ # run
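+ # Each in-between frame m is remapped from its left and right keyframes (l -> m and r -> m),
+ # and the two results are combined with the weights returned by get_weight(l, m, r).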
+ for tasks in task_group:
+ index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
+ source_guide, target_guide, source_style = [], [], []
+ for l, m, r in tasks_batch:
+ # l -> m
+ source_guide.append(frames_guide[l])
+ target_guide.append(frames_guide[m])
+ source_style.append(frames_style[index_dict[l]])
+ # r -> m
+ source_guide.append(frames_guide[r])
+ target_guide.append(frames_guide[m])
+ source_style.append(frames_style[index_dict[r]])
+ source_guide = np.stack(source_guide)
+ target_guide = np.stack(target_guide)
+ source_style = np.stack(source_style)
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ if save_path is not None:
+ for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
+ weight_l, weight_r = self.get_weight(l, m, r)
+ frame = frame_l * weight_l + frame_r * weight_r
+ frame = frame.clip(0, 255).astype("uint8")
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
+
+
+class InterpolationModeSingleFrameRunner:
+ def __init__(self):
+ pass
+
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
+ # check input
+ tracking_window_size = ebsynth_config["tracking_window_size"]
+ if tracking_window_size * 2 >= batch_size:
+ raise ValueError("batch_size should be larger than tracking_window_size * 2")
+ frame_style = frames_style[0]
+ frame_guide = frames_guide[index_style[0]]
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frame_style.shape[0],
+ image_width=frame_style.shape[1],
+ channel=3,
+ **ebsynth_config
+ )
+ # run
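+ # Render the video in overlapping chunks (stride batch_size - 2 * tracking_window_size); frames are
+ # written strictly in order, and the tail of a non-final chunk is left for the next chunk to render.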
+ frame_id, n = 0, len(frames_guide)
+ for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
+ if i + batch_size > n:
+ l, r = max(n - batch_size, 0), n
+ else:
+ l, r = i, i + batch_size
+ source_guide = np.stack([frame_guide] * (r-l))
+ target_guide = np.stack([frames_guide[i] for i in range(l, r)])
+ source_style = np.stack([frame_style] * (r-l))
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ for i, frame in zip(range(l, r), target_style):
+ if i==frame_id:
+ frame = frame.clip(0, 255).astype("uint8")
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
+ frame_id += 1
+ if r < n and r-frame_id <= tracking_window_size:
+ break
diff --git a/diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py b/diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..885dcf8f76ad77865054f0c033f8541ae08b1e04
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
@@ -0,0 +1 @@
+from .blip_pretrain import *
diff --git a/diffsynth/extensions/ImageQualityMetric/BLIP/blip.py b/diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b24c3c17fdeff6949c3692164362abb8d8d0989
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
@@ -0,0 +1,77 @@
+'''
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
+'''
+
+import warnings
+warnings.filterwarnings("ignore")
+
+import torch
+import os
+from urllib.parse import urlparse
+from timm.models.hub import download_cached_file
+from transformers import BertTokenizer
+from .vit import VisionTransformer, interpolate_pos_embed
+
+
+def default_bert():
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
+ model_path = os.path.join(project_root, 'models', 'QualityMetric')
+ return os.path.join(model_path, "bert-base-uncased")
+
+
+def init_tokenizer(bert_model_path):
+ tokenizer = BertTokenizer.from_pretrained(bert_model_path)
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
+ return tokenizer
+
+
+def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
+
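+ # Build a ViT-B/16 (width 768, depth 12) or ViT-L/16 (width 1024, depth 24) visual encoder
+ # and return it together with its embedding width.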
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
+ if vit=='base':
+ vision_width = 768
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
+ drop_path_rate=0 or drop_path_rate
+ )
+ elif vit=='large':
+ vision_width = 1024
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
+ drop_path_rate=0.1 or drop_path_rate
+ )
+ return visual_encoder, vision_width
+
+
+def is_url(url_or_filename):
+ parsed = urlparse(url_or_filename)
+ return parsed.scheme in ("http", "https")
+
+def load_checkpoint(model,url_or_filename):
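+ # Load a BLIP checkpoint from a URL or local path, resize its position embeddings to match the
+ # current visual encoder, and drop any tensors whose shape differs before loading with strict=False.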
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+ checkpoint = torch.load(cached_file, map_location='cpu')
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
+ else:
+ raise RuntimeError('checkpoint url or path is invalid')
+
+ state_dict = checkpoint['model']
+
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
+ model.visual_encoder_m)
+ for key in model.state_dict().keys():
+ if key in state_dict.keys():
+ if state_dict[key].shape!=model.state_dict()[key].shape:
+ print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
+ del state_dict[key]
+
+ msg = model.load_state_dict(state_dict,strict=False)
+ print('load checkpoint from %s'%url_or_filename)
+ return model,msg
+
diff --git a/diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py b/diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba711e2776fd086190ca940248e022b4e083819a
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py
@@ -0,0 +1,44 @@
+'''
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
+'''
+
+import transformers
+transformers.logging.set_verbosity_error()
+
+from torch import nn
+import os
+from .med import BertConfig, BertModel
+from .blip import create_vit, init_tokenizer
+
+class BLIP_Pretrain(nn.Module):
+ def __init__(self,
+ med_config = "med_config.json",
+ image_size = 224,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ embed_dim = 256,
+ queue_size = 57600,
+ momentum = 0.995,
+ bert_model_path = ""
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
+
+ self.tokenizer = init_tokenizer(bert_model_path)
+ encoder_config = BertConfig.from_json_file(med_config)
+ encoder_config.encoder_width = vision_width
+ self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
+
+ text_width = self.text_encoder.config.hidden_size
+
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
+ self.text_proj = nn.Linear(text_width, embed_dim)
+
diff --git a/diffsynth/extensions/ImageQualityMetric/BLIP/med.py b/diffsynth/extensions/ImageQualityMetric/BLIP/med.py
new file mode 100644
index 0000000000000000000000000000000000000000..5905a346c8f3a511fc56d1e052467dcef3246e4b
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/BLIP/med.py
@@ -0,0 +1,947 @@
+'''
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+'''
+
+import math
+from typing import Tuple
+
+import torch
+from torch import Tensor, device, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.config = config
+
+ def forward(
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ embeddings = inputs_embeds
+
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
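+ # (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size)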
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+ self.self = BertSelfAttention(config, is_cross_attention)
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
+ if self.config.add_cross_attention:
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ mode=None,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+
+ if mode=='multimodal':
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ mode='multimodal',
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ mode=mode,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ mode=mode,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """ Initialize the weights """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+ input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ causal_mask = torch.cat(
+ [
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ device = input_ids.device
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = inputs_embeds.device
+ elif encoder_embeds is not None:
+ input_shape = encoder_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = encoder_embeds.device
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
+ device, is_decoder)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+ else:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if encoder_embeds is None:
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ else:
+ embedding_output = encoder_embeds
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ mode=mode,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=True,
+ reduction='mean',
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ mode=mode,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores[:, :-1, :].contiguous()
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ if reduction=='none':
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+ "is_decoder": True,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
diff --git a/diffsynth/extensions/ImageQualityMetric/BLIP/vit.py b/diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..cef7b650a95f56266775cf0f18b28bc0f74987ab
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
@@ -0,0 +1,301 @@
+'''
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
+ * Based on timm code base
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+from timm.models.vision_transformer import _cfg, PatchEmbed
+from timm.models.registry import register_model
+from timm.models.layers import trunc_normal_, DropPath
+from timm.models.helpers import named_apply, adapt_input_conv
+
+# from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim ** -0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.attn_gradients = None
+ self.attention_map = None
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def forward(self, x, register_hook=False):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ if register_hook:
+ self.save_attention_map(attn)
+ attn.register_hook(self.save_attn_gradients)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ # if use_grad_checkpointing:
+ # self.attn = checkpoint_wrapper(self.attn)
+ # self.mlp = checkpoint_wrapper(self.mlp)
+
+ def forward(self, x, register_hook=False):
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
+ https://arxiv.org/abs/2010.11929
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
+ use_grad_checkpointing=False, ckpt_layer=0):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ norm_layer: (nn.Module): normalization layer
+ """
+ super().__init__()
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
+ )
+ for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def forward(self, x, register_blk=-1):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + self.pos_embed[:,:x.size(1),:]
+ x = self.pos_drop(x)
+
+ for i,blk in enumerate(self.blocks):
+ x = blk(x, register_blk==i)
+ x = self.norm(x)
+
+ return x
+
+ @torch.jit.ignore()
+ def load_pretrained(self, checkpoint_path, prefix=''):
+ _load_weights(self, checkpoint_path, prefix)
+
+
+@torch.no_grad()
+def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
+ """
+ import numpy as np
+
+ def _n2p(w, t=True):
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
+ w = w.flatten()
+ if t:
+ if w.ndim == 4:
+ w = w.transpose([3, 2, 0, 1])
+ elif w.ndim == 3:
+ w = w.transpose([2, 0, 1])
+ elif w.ndim == 2:
+ w = w.transpose([1, 0])
+ return torch.from_numpy(w)
+
+ w = np.load(checkpoint_path)
+ if not prefix and 'opt/target/embedding/kernel' in w:
+ prefix = 'opt/target/'
+
+ if hasattr(model.patch_embed, 'backbone'):
+ # hybrid
+ backbone = model.patch_embed.backbone
+ stem_only = not hasattr(backbone, 'stem')
+ stem = backbone if stem_only else backbone.stem
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
+ if not stem_only:
+ for i, stage in enumerate(backbone.stages):
+ for j, block in enumerate(stage.blocks):
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
+ for r in range(3):
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
+ if block.downsample is not None:
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
+ else:
+ embed_conv_w = adapt_input_conv(
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
+ if pos_embed_w.shape != model.pos_embed.shape:
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+ model.pos_embed.copy_(pos_embed_w)
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
+# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+ for i, block in enumerate(model.blocks.children()):
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+ block.attn.qkv.weight.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+ block.attn.qkv.bias.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+ for r in range(2):
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
+
+
+def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
+ # interpolate position embedding
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = visual_encoder.patch_embed.num_patches
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+
+    if orig_size != new_size:
+ # class_token and dist_token are kept unchanged
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+        print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2))
+
+ return new_pos_embed
+ else:
+ return pos_embed_checkpoint
\ No newline at end of file
diff --git a/diffsynth/extensions/ImageQualityMetric/__init__.py b/diffsynth/extensions/ImageQualityMetric/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcfb7c02b0ce2b6a2fbe345d87c31e0d1bb3a128
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/__init__.py
@@ -0,0 +1,148 @@
+from modelscope import snapshot_download
+from typing_extensions import Literal, TypeAlias
+import os
+from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore
+from diffsynth.extensions.ImageQualityMetric.imagereward import ImageRewardScore
+from diffsynth.extensions.ImageQualityMetric.pickscore import PickScore
+from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore
+from diffsynth.extensions.ImageQualityMetric.hps import HPScore_v2
+from diffsynth.extensions.ImageQualityMetric.mps import MPScore
+
+
+preference_model_id: TypeAlias = Literal[
+ "ImageReward",
+ "Aesthetic",
+ "PickScore",
+ "CLIP",
+ "HPSv2",
+ "HPSv2.1",
+ "MPS",
+]
+model_dict = {
+ "ImageReward": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "ImageReward/ImageReward.safetensors",
+ "ImageReward/med_config.json",
+ "bert-base-uncased/config.json",
+ "bert-base-uncased/model.safetensors",
+ "bert-base-uncased/tokenizer.json",
+ "bert-base-uncased/tokenizer_config.json",
+ "bert-base-uncased/vocab.txt",
+ ],
+ "load_path": {
+ "imagereward": "ImageReward/ImageReward.safetensors",
+ "med_config": "ImageReward/med_config.json",
+ "bert_model_path": "bert-base-uncased",
+ },
+ "model_class": ImageRewardScore
+ },
+ "Aesthetic": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
+ "clip-vit-large-patch14/config.json",
+ "clip-vit-large-patch14/merges.txt",
+ "clip-vit-large-patch14/model.safetensors",
+ "clip-vit-large-patch14/preprocessor_config.json",
+ "clip-vit-large-patch14/special_tokens_map.json",
+ "clip-vit-large-patch14/tokenizer.json",
+ "clip-vit-large-patch14/tokenizer_config.json",
+ "clip-vit-large-patch14/vocab.json",
+ ],
+ "load_path": {
+ "aesthetic_predictor": "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
+ "clip-large": "clip-vit-large-patch14",
+ },
+ "model_class": AestheticScore
+ },
+ "PickScore": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "PickScore_v1/*",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
+ ],
+ "load_path": {
+ "pickscore": "PickScore_v1",
+ "clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
+ },
+ "model_class": PickScore
+ },
+ "CLIP": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
+ "bpe_simple_vocab_16e6.txt.gz",
+ ],
+ "load_path": {
+ "open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
+ "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
+ },
+ "model_class": CLIPScore
+ },
+ "HPSv2": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "HPS_v2/HPS_v2_compressed.safetensors",
+ "bpe_simple_vocab_16e6.txt.gz",
+ ],
+ "load_path": {
+ "hpsv2": "HPS_v2/HPS_v2_compressed.safetensors",
+ "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
+ },
+ "model_class": HPScore_v2,
+ "extra_kwargs": {"model_version": "v2"}
+ },
+ "HPSv2.1": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "HPS_v2/HPS_v2.1_compressed.safetensors",
+ "bpe_simple_vocab_16e6.txt.gz",
+ ],
+ "load_path": {
+ "hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors",
+ "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
+ },
+ "model_class": HPScore_v2,
+ "extra_kwargs": {"model_version": "v21"}
+ },
+ "MPS": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
+ ],
+ "load_path": {
+ "mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
+ "clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
+ },
+ "model_class": MPScore
+ },
+}
+
+
+def download_preference_model(model_name: preference_model_id, cache_dir="models"):
+ metadata = model_dict[model_name]
+ snapshot_download(model_id=metadata["model_id"], allow_file_pattern=metadata["allow_file_pattern"], cache_dir=cache_dir)
+ load_path = metadata["load_path"]
+ load_path = {key: os.path.join(cache_dir, metadata["model_id"], path) for key, path in load_path.items()}
+ return load_path
+
+
+def load_preference_model(model_name: preference_model_id, device="cuda", path=None):
+ model_class = model_dict[model_name]["model_class"]
+ extra_kwargs = model_dict[model_name].get("extra_kwargs", {})
+ preference_model = model_class(device=device, path=path, **extra_kwargs)
+ return preference_model
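+
+
+# Minimal usage sketch (an illustrative assumption, not part of the shipped API surface:
+# it presumes ModelScope download access, a CUDA device, and that the PickScore class
+# exposes the same score(images, prompt) interface as the other metrics; "image.jpg"
+# and the prompt are placeholders):
+#
+#     path = download_preference_model("PickScore", cache_dir="models")
+#     model = load_preference_model("PickScore", device="cuda", path=path)
+#     scores = model.score(["image.jpg"], prompt="a photo of a cat")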
diff --git a/diffsynth/extensions/ImageQualityMetric/aesthetic.py b/diffsynth/extensions/ImageQualityMetric/aesthetic.py
new file mode 100644
index 0000000000000000000000000000000000000000..13da98a1f45ca7eea0411e18c307cc5d0154488f
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/aesthetic.py
@@ -0,0 +1,148 @@
+import os
+from typing import List, Optional, Union
+
+from PIL import Image
+import torch
+from transformers import AutoProcessor, AutoModel
+from safetensors.torch import load_file
+from .config import MODEL_PATHS
+
+class MLP(torch.nn.Module):
+ def __init__(self, input_size: int, xcol: str = "emb", ycol: str = "avg_rating"):
+ super().__init__()
+ self.input_size = input_size
+ self.xcol = xcol
+ self.ycol = ycol
+ self.layers = torch.nn.Sequential(
+ torch.nn.Linear(self.input_size, 1024),
+ #torch.nn.ReLU(),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(1024, 128),
+ #torch.nn.ReLU(),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(128, 64),
+ #torch.nn.ReLU(),
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(64, 16),
+ #torch.nn.ReLU(),
+ torch.nn.Linear(16, 1),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.layers(x)
+
+ def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
+ x = batch[self.xcol]
+ y = batch[self.ycol].reshape(-1, 1)
+ x_hat = self.layers(x)
+ loss = torch.nn.functional.mse_loss(x_hat, y)
+ return loss
+
+ def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
+ x = batch[self.xcol]
+ y = batch[self.ycol].reshape(-1, 1)
+ x_hat = self.layers(x)
+ loss = torch.nn.functional.mse_loss(x_hat, y)
+ return loss
+
+ def configure_optimizers(self) -> torch.optim.Optimizer:
+ return torch.optim.Adam(self.parameters(), lr=1e-3)
+
+
+class AestheticScore(torch.nn.Module):
+    def __init__(self, device: torch.device, path: dict = MODEL_PATHS):
+ super().__init__()
+ self.device = device
+ self.aes_model_path = path.get("aesthetic_predictor")
+ # Load the MLP model
+ self.model = MLP(768)
+ try:
+ if self.aes_model_path.endswith(".safetensors"):
+ state_dict = load_file(self.aes_model_path)
+ else:
+ state_dict = torch.load(self.aes_model_path)
+ self.model.load_state_dict(state_dict)
+ except Exception as e:
+ raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
+
+ self.model.to(device)
+ self.model.eval()
+
+ # Load the CLIP model and processor
+ clip_model_name = path.get('clip-large')
+ self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
+ self.processor = AutoProcessor.from_pretrained(clip_model_name)
+
+ def _calculate_score(self, image: torch.Tensor) -> float:
+ """Calculate the aesthetic score for a single image.
+
+ Args:
+ image (torch.Tensor): The processed image tensor.
+
+ Returns:
+ float: The aesthetic score.
+ """
+ with torch.no_grad():
+ # Get image embeddings
+ image_embs = self.model2.get_image_features(image)
+ image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
+
+ # Compute score
+ score = self.model(image_embs).cpu().flatten().item()
+
+ return score
+
+ @torch.no_grad()
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
+ """Score the images based on their aesthetic quality.
+
+ Args:
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+
+ Returns:
+ List[float]: List of scores for the images.
+ """
+ try:
+ if isinstance(images, (str, Image.Image)):
+ # Single image
+ if isinstance(images, str):
+ pil_image = Image.open(images)
+ else:
+ pil_image = images
+
+ # Prepare image inputs
+ image_inputs = self.processor(
+ images=pil_image,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ ).to(self.device)
+
+ return [self._calculate_score(image_inputs["pixel_values"])]
+ elif isinstance(images, list):
+ # Multiple images
+ scores = []
+ for one_image in images:
+ if isinstance(one_image, str):
+ pil_image = Image.open(one_image)
+ elif isinstance(one_image, Image.Image):
+ pil_image = one_image
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+
+ # Prepare image inputs
+ image_inputs = self.processor(
+ images=pil_image,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ ).to(self.device)
+
+ scores.append(self._calculate_score(image_inputs["pixel_values"]))
+ return scores
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+ except Exception as e:
+ raise RuntimeError(f"Error in scoring images: {e}")
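+
+
+# Illustrative usage sketch (assumes the weights referenced by MODEL_PATHS have
+# already been downloaded, e.g. via download_preference_model("Aesthetic");
+# "image.jpg" is a placeholder path):
+#
+#     scorer = AestheticScore(device=torch.device("cuda"))
+#     scores = scorer.score(["image.jpg"])   # the prompt argument is unused for this metric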
diff --git a/diffsynth/extensions/ImageQualityMetric/clip.py b/diffsynth/extensions/ImageQualityMetric/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..f70941e0a45db61be87e21c347e97ad8bb390fff
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/clip.py
@@ -0,0 +1,97 @@
+from typing import List, Union
+from PIL import Image
+import torch
+from .open_clip import create_model_and_transforms, get_tokenizer
+from .config import MODEL_PATHS
+
+class CLIPScore(torch.nn.Module):
+    def __init__(self, device: torch.device, path: dict = MODEL_PATHS):
+        """Initialize the CLIPScore with a model and tokenizer.
+
+        Args:
+            device (torch.device): The device to load the model on.
+            path (dict): Mapping of weight names to local file paths.
+        """
+        super().__init__()
+        self.device = device
+
+ # Create model and transforms
+ self.model, _, self.preprocess_val = create_model_and_transforms(
+ "ViT-H-14",
+ # "laion2B-s32B-b79K",
+ pretrained=path.get("open_clip"),
+ precision="amp",
+ device=device,
+ jit=False,
+ force_quick_gelu=False,
+ force_custom_text=False,
+ force_patch_dropout=False,
+ force_image_size=None,
+ pretrained_image=False,
+ image_mean=None,
+ image_std=None,
+ light_augmentation=True,
+ aug_cfg={},
+ output_dict=True,
+ with_score_predictor=False,
+ with_region_predictor=False,
+ )
+
+ # Initialize tokenizer
+ self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
+ self.model = self.model.to(device)
+ self.model.eval()
+
+ def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
+ """Calculate the CLIP score for a single image and prompt.
+
+ Args:
+ image (torch.Tensor): The processed image tensor.
+ prompt (str): The prompt text.
+
+ Returns:
+ float: The CLIP score.
+ """
+ with torch.no_grad():
+ # Process the prompt
+ text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
+
+ # Calculate the CLIP score
+ outputs = self.model(image, text)
+ image_features, text_features = outputs["image_features"], outputs["text_features"]
+ logits_per_image = image_features @ text_features.T
+ clip_score = torch.diagonal(logits_per_image).cpu().numpy()
+
+ return clip_score[0].item()
+
+ @torch.no_grad()
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
+ """Score the images based on the prompt.
+
+ Args:
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+ prompt (str): The prompt text.
+
+ Returns:
+ List[float]: List of CLIP scores for the images.
+ """
+ if isinstance(images, (str, Image.Image)):
+ # Single image
+ if isinstance(images, str):
+ image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
+ else:
+ image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
+ return [self._calculate_score(image, prompt)]
+ elif isinstance(images, list):
+ # Multiple images
+ scores = []
+ for one_images in images:
+ if isinstance(one_images, str):
+ image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
+ elif isinstance(one_images, Image.Image):
+ image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+ scores.append(self._calculate_score(image, prompt))
+ return scores
+ else:
+ raise TypeError("The type of parameter images is illegal.")
diff --git a/diffsynth/extensions/ImageQualityMetric/config.py b/diffsynth/extensions/ImageQualityMetric/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..60faadcb1e5554c8f8f29a64fc55c3150d8a8bbe
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/config.py
@@ -0,0 +1,23 @@
+import os
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
+model_path = os.path.join(project_root, 'models', 'QualityMetric')
+
+
+def get_model_path(model_name):
+ return os.path.join(model_path, model_name)
+
+
+MODEL_PATHS = {
+ "aesthetic_predictor": get_model_path("aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
+ "open_clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
+ "hpsv2": get_model_path("HPS_v2/HPS_v2_compressed.safetensors"),
+ "hpsv2.1": get_model_path("HPS_v2/HPS_v2.1_compressed.safetensors"),
+ "imagereward": get_model_path("ImageReward/ImageReward.safetensors"),
+ "med_config": get_model_path("ImageReward/med_config.json"),
+ "clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K"),
+ "clip-large": get_model_path("clip-vit-large-patch14"),
+ "mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
+    "pickscore": get_model_path("PickScore_v1"),
+    "open_clip_bpe": get_model_path("bpe_simple_vocab_16e6.txt.gz"),
+    "bert_model_path": get_model_path("bert-base-uncased"),
+}
\ No newline at end of file
diff --git a/diffsynth/extensions/ImageQualityMetric/hps.py b/diffsynth/extensions/ImageQualityMetric/hps.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4b266bd261a95676ba700d38c3a63b143bbbb40
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/hps.py
@@ -0,0 +1,118 @@
+from typing import List, Union
+from PIL import Image
+import torch
+from .open_clip import create_model_and_transforms, get_tokenizer
+from safetensors.torch import load_file
+import os
+from .config import MODEL_PATHS
+
+class HPScore_v2(torch.nn.Module):
+    def __init__(self, device: torch.device, path: dict = MODEL_PATHS, model_version: str = "v2"):
+        """Initialize the HPScore_v2 selector with a model and tokenizer.
+
+        Args:
+            device (torch.device): The device to load the model on.
+            path (dict): Mapping of weight names to local file paths.
+            model_version (str): The version of the model to load, "v2" or "v21". Default is "v2".
+        """
+        super().__init__()
+        self.device = device
+
+ if model_version == "v2":
+ safetensors_path = path.get("hpsv2")
+ elif model_version == "v21":
+ safetensors_path = path.get("hpsv2.1")
+ else:
+ raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
+
+ # Create model and transforms
+ model, _, self.preprocess_val = create_model_and_transforms(
+ "ViT-H-14",
+ # "laion2B-s32B-b79K",
+ pretrained=path.get("open_clip"),
+ precision="amp",
+ device=device,
+ jit=False,
+ force_quick_gelu=False,
+ force_custom_text=False,
+ force_patch_dropout=False,
+ force_image_size=None,
+ pretrained_image=False,
+ image_mean=None,
+ image_std=None,
+ light_augmentation=True,
+ aug_cfg={},
+ output_dict=True,
+ with_score_predictor=False,
+ with_region_predictor=False,
+ )
+
+ # Load model weights
+ try:
+ state_dict = load_file(safetensors_path)
+ model.load_state_dict(state_dict)
+ except Exception as e:
+ raise ValueError(f"Error loading model weights from {safetensors_path}: {e}")
+
+ # Initialize tokenizer and model
+ self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
+ model = model.to(device)
+ model.eval()
+ self.model = model
+
+ def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
+ """Calculate the HPS score for a single image and prompt.
+
+ Args:
+ image (torch.Tensor): The processed image tensor.
+ prompt (str): The prompt text.
+
+ Returns:
+ float: The HPS score.
+ """
+ with torch.no_grad():
+ # Process the prompt
+ text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
+
+ # Calculate the HPS score
+ outputs = self.model(image, text)
+ image_features, text_features = outputs["image_features"], outputs["text_features"]
+ logits_per_image = image_features @ text_features.T
+ hps_score = torch.diagonal(logits_per_image).cpu().numpy()
+
+ return hps_score[0].item()
+
+ @torch.no_grad()
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
+ """Score the images based on the prompt.
+
+ Args:
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+ prompt (str): The prompt text.
+
+ Returns:
+ List[float]: List of HPS scores for the images.
+ """
+ try:
+ if isinstance(images, (str, Image.Image)):
+ # Single image
+ if isinstance(images, str):
+ image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
+ else:
+ image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
+ return [self._calculate_score(image, prompt)]
+ elif isinstance(images, list):
+ # Multiple images
+ scores = []
+ for one_images in images:
+ if isinstance(one_images, str):
+ image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
+ elif isinstance(one_images, Image.Image):
+ image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+ scores.append(self._calculate_score(image, prompt))
+ return scores
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+ except Exception as e:
+ raise RuntimeError(f"Error in scoring images: {e}")
diff --git a/diffsynth/extensions/ImageQualityMetric/imagereward.py b/diffsynth/extensions/ImageQualityMetric/imagereward.py
new file mode 100644
index 0000000000000000000000000000000000000000..27607904b23fa1691c5a6966eb4030cd813567b0
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/imagereward.py
@@ -0,0 +1,212 @@
+import os
+import torch
+from PIL import Image
+from typing import List, Union
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+from .BLIP.blip_pretrain import BLIP_Pretrain
+from torchvision.transforms import InterpolationMode
+from safetensors.torch import load_file
+from .config import MODEL_PATHS
+BICUBIC = InterpolationMode.BICUBIC
+
+def _convert_image_to_rgb(image):
+ return image.convert("RGB")
+
+def _transform(n_px):
+ return Compose([
+ Resize(n_px, interpolation=BICUBIC),
+ CenterCrop(n_px),
+ _convert_image_to_rgb,
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+class MLP(torch.nn.Module):
+ def __init__(self, input_size):
+ super().__init__()
+ self.input_size = input_size
+
+ self.layers = torch.nn.Sequential(
+ torch.nn.Linear(self.input_size, 1024),
+ #nn.ReLU(),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(1024, 128),
+ #nn.ReLU(),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(128, 64),
+ #nn.ReLU(),
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(64, 16),
+ #nn.ReLU(),
+ torch.nn.Linear(16, 1)
+ )
+
+ # initial MLP param
+ for name, param in self.layers.named_parameters():
+ if 'weight' in name:
+ torch.nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
+ if 'bias' in name:
+ torch.nn.init.constant_(param, val=0)
+
+ def forward(self, input):
+ return self.layers(input)
+
+class ImageReward(torch.nn.Module):
+ def __init__(self, med_config, device='cpu', bert_model_path=""):
+ super().__init__()
+ self.device = device
+
+ self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path)
+ self.preprocess = _transform(224)
+ self.mlp = MLP(768)
+
+ self.mean = 0.16717362830052426
+ self.std = 1.0333394966054072
+
+ def score_grad(self, prompt_ids, prompt_attention_mask, image):
+ """Calculate the score with gradient for a single image and prompt.
+
+ Args:
+ prompt_ids (torch.Tensor): Tokenized prompt IDs.
+ prompt_attention_mask (torch.Tensor): Attention mask for the prompt.
+ image (torch.Tensor): The processed image tensor.
+
+ Returns:
+ torch.Tensor: The reward score.
+ """
+ image_embeds = self.blip.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+ text_output = self.blip.text_encoder(
+ prompt_ids,
+ attention_mask=prompt_attention_mask,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+ txt_features = text_output.last_hidden_state[:, 0, :]
+ rewards = self.mlp(txt_features)
+ rewards = (rewards - self.mean) / self.std
+ return rewards
+
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
+ """Score the images based on the prompt.
+
+ Args:
+            images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+            prompt (str): The prompt text. Defaults to "".
+
+ Returns:
+ List[float]: List of scores for the images.
+ """
+ if isinstance(images, (str, Image.Image)):
+ # Single image
+ if isinstance(images, str):
+ pil_image = Image.open(images)
+ else:
+ pil_image = images
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+ return [self._calculate_score(prompt, image).item()]
+ elif isinstance(images, list):
+ # Multiple images
+ scores = []
+ for one_image in images:
+ if isinstance(one_image, str):
+ pil_image = Image.open(one_image)
+ elif isinstance(one_image, Image.Image):
+ pil_image = one_image
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+ scores.append(self._calculate_score(prompt, image).item())
+ return scores
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+
+ def _calculate_score(self, prompt: str, image: torch.Tensor) -> torch.Tensor:
+ """Calculate the score for a single image and prompt.
+
+ Args:
+ prompt (str): The prompt text.
+ image (torch.Tensor): The processed image tensor.
+
+ Returns:
+ torch.Tensor: The reward score.
+ """
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
+ image_embeds = self.blip.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+ text_output = self.blip.text_encoder(
+ text_input.input_ids,
+ attention_mask=text_input.attention_mask,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+ txt_features = text_output.last_hidden_state[:, 0, :].float()
+ rewards = self.mlp(txt_features)
+ rewards = (rewards - self.mean) / self.std
+ return rewards
+
+ def inference_rank(self, prompt: str, generations_list: List[Union[str, Image.Image]]) -> tuple:
+ """Rank the images based on the prompt.
+
+ Args:
+ prompt (str): The prompt text.
+ generations_list (List[Union[str, Image.Image]]): List of image paths or PIL images.
+
+ Returns:
+ tuple: (indices, rewards) where indices are the ranks and rewards are the scores.
+ """
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
+ txt_set = []
+ for generation in generations_list:
+ if isinstance(generation, str):
+ pil_image = Image.open(generation)
+ elif isinstance(generation, Image.Image):
+ pil_image = generation
+ else:
+ raise TypeError("The type of parameter generations_list is illegal.")
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+ image_embeds = self.blip.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+ text_output = self.blip.text_encoder(
+ text_input.input_ids,
+ attention_mask=text_input.attention_mask,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+ txt_set.append(text_output.last_hidden_state[:, 0, :])
+ txt_features = torch.cat(txt_set, 0).float()
+ rewards = self.mlp(txt_features)
+ rewards = (rewards - self.mean) / self.std
+ rewards = torch.squeeze(rewards)
+ _, rank = torch.sort(rewards, dim=0, descending=True)
+ _, indices = torch.sort(rank, dim=0)
+ indices = indices + 1
+ return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
+
+
+class ImageRewardScore(torch.nn.Module):
+    def __init__(self, device: Union[str, torch.device], path: dict = MODEL_PATHS):
+ super().__init__()
+ self.device = device if isinstance(device, torch.device) else torch.device(device)
+ model_path = path.get("imagereward")
+ med_config = path.get("med_config")
+ state_dict = load_file(model_path)
+ self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device)
+ self.model.load_state_dict(state_dict, strict=False)
+ self.model.eval()
+
+ @torch.no_grad()
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
+ """Score the images based on the prompt.
+
+ Args:
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+ prompt (str): The prompt text.
+
+ Returns:
+ List[float]: List of scores for the images.
+ """
+ return self.model.score(images, prompt)
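+
+
+# Illustrative usage sketch (assumes the ImageReward and bert-base-uncased weights
+# are available at the MODEL_PATHS locations; the image file names are placeholders):
+#
+#     scorer = ImageRewardScore(device="cuda")
+#     scores = scorer.score(["a.jpg", "b.jpg"], prompt="a red sports car")
+#     ranks, rewards = scorer.model.inference_rank("a red sports car", ["a.jpg", "b.jpg"])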
diff --git a/diffsynth/extensions/ImageQualityMetric/mps.py b/diffsynth/extensions/ImageQualityMetric/mps.py
new file mode 100644
index 0000000000000000000000000000000000000000..d15aad4b81026a743911512bcc569520182b31c5
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/mps.py
@@ -0,0 +1,129 @@
+import gc
+import json
+import numpy as np
+import torch
+import torch.nn.functional as F
+from dataclasses import dataclass
+from io import BytesIO
+from typing import Any, Optional, Tuple, Union, List
+
+from PIL import Image
+from tqdm.auto import tqdm
+from torch import nn, einsum
+from safetensors.torch import load_file
+from transformers import CLIPConfig, CLIPFeatureExtractor, CLIPImageProcessor
+from transformers import CLIPModel as HFCLIPModel
+from transformers import AutoProcessor, AutoModel, AutoTokenizer
+
+from .trainer.models.base_model import BaseModelConfig
+from .trainer.models.cross_modeling import Cross_model
+from .trainer.models import clip_model
+from .config import MODEL_PATHS
+
+class MPScore(torch.nn.Module):
+    def __init__(self, device: Union[str, torch.device], path: dict = MODEL_PATHS, condition: str = 'overall'):
+        """Initialize the MPScore model with a processor, tokenizer, and model.
+
+        Args:
+            device (Union[str, torch.device]): The device to load the model on.
+            path (dict): Mapping of weight names to local file paths.
+            condition (str): Scoring condition, one of 'overall', 'aesthetics', 'quality', or 'semantic'.
+        """
+        super().__init__()
+        self.device = device
+ processor_name_or_path = path.get("clip")
+ self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
+ self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
+ self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True)
+ state_dict = load_file(path.get("mps"))
+ self.model.load_state_dict(state_dict, strict=False)
+ self.model.to(device)
+ self.condition = condition
+
+ def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
+ """Calculate the reward score for a single image and prompt.
+
+ Args:
+ image (torch.Tensor): The processed image tensor.
+ prompt (str): The prompt text.
+
+ Returns:
+ float: The reward score.
+ """
+ def _tokenize(caption):
+ input_ids = self.tokenizer(
+ caption,
+ max_length=self.tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt"
+ ).input_ids
+ return input_ids
+
+ text_input = _tokenize(prompt).to(self.device)
+ if self.condition == 'overall':
+ condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things'
+ elif self.condition == 'aesthetics':
+ condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry'
+ elif self.condition == 'quality':
+ condition_prompt = 'shape, face, hair, hands, limbs, structure, instance, texture'
+ elif self.condition == 'semantic':
+ condition_prompt = 'quantity, attributes, position, number, location'
+ else:
+ raise ValueError(
+ f"Unsupported condition: {self.condition}. Choose 'overall', 'aesthetics', 'quality', or 'semantic'.")
+ condition_batch = _tokenize(condition_prompt).repeat(text_input.shape[0], 1).to(self.device)
+
+ with torch.no_grad():
+ text_f, text_features = self.model.model.get_text_features(text_input)
+
+ image_f = self.model.model.get_image_features(image.half())
+ condition_f, _ = self.model.model.get_text_features(condition_batch)
+
+ sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
+ sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
+ sim_text_condition = sim_text_condition / sim_text_condition.max()
+ mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
+ mask = mask.repeat(1, image_f.shape[1], 1)
+ image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]
+
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+ image_score = self.model.logit_scale.exp() * text_features @ image_features.T
+
+ return image_score[0].cpu().numpy().item()
+
+ @torch.no_grad()
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
+ """Score the images based on the prompt.
+
+ Args:
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+ prompt (str): The prompt text.
+
+ Returns:
+ List[float]: List of reward scores for the images.
+ """
+ if isinstance(images, (str, Image.Image)):
+ # Single image
+ if isinstance(images, str):
+ image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device)
+ else:
+ image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device)
+ return [self._calculate_score(image, prompt)]
+ elif isinstance(images, list):
+ # Multiple images
+ scores = []
+ for one_images in images:
+ if isinstance(one_images, str):
+ image = self.image_processor(Image.open(one_images), return_tensors="pt")["pixel_values"].to(self.device)
+ elif isinstance(one_images, Image.Image):
+ image = self.image_processor(one_images, return_tensors="pt")["pixel_values"].to(self.device)
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+ scores.append(self._calculate_score(image, prompt))
+ return scores
+ else:
+ raise TypeError("The type of parameter images is illegal.")
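+
+
+# Illustrative usage sketch (assumes the MPS checkpoint and CLIP processor files
+# are available at the MODEL_PATHS locations; the image path and prompt are placeholders):
+#
+#     scorer = MPScore(device="cuda", condition="overall")
+#     scores = scorer.score("image.jpg", prompt="a watercolor landscape")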
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py b/diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1560db0b543b7b8857f39d7de435c834380666ab
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py
@@ -0,0 +1,14 @@
+from .coca_model import CoCa
+from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
+from .factory import list_models, add_model_config, get_model_config, load_checkpoint
+from .loss import ClipLoss, DistillClipLoss, CoCaLoss
+from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
+from .openai import load_openai_model, list_openai_models
+from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
+from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
+from .tokenizer import SimpleTokenizer
+from .transform import image_transform, AugmentationCfg
+from .utils import freeze_batch_norm_2d
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py b/diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..039453af70d1c865dd7cc6016f732aff2f7dc3d2
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
@@ -0,0 +1,458 @@
+from typing import Optional
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+import numpy as np
+from dataclasses import dataclass
+
+from .transformer import (
+ LayerNormFp32,
+ LayerNorm,
+ QuickGELU,
+ MultimodalTransformer,
+)
+from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
+
+try:
+ from transformers import (
+ BeamSearchScorer,
+ LogitsProcessorList,
+ TopPLogitsWarper,
+ TopKLogitsWarper,
+ RepetitionPenaltyLogitsProcessor,
+ MinLengthLogitsProcessor,
+ MaxLengthCriteria,
+ StoppingCriteriaList
+ )
+
+ GENERATION_TYPES = {
+ "top_k": TopKLogitsWarper,
+ "top_p": TopPLogitsWarper,
+ "beam_search": "beam_search"
+ }
+ _has_transformers = True
+except ImportError as e:
+ GENERATION_TYPES = {
+ "top_k": None,
+ "top_p": None,
+ "beam_search": "beam_search"
+ }
+ _has_transformers = False
+
+
+@dataclass
+class MultimodalCfg(CLIPTextCfg):
+ mlp_ratio: int = 4
+ dim_head: int = 64
+ heads: int = 8
+ n_queries: int = 256
+ attn_pooler_heads: int = 8
+
+
+def _build_text_decoder_tower(
+ embed_dim,
+ multimodal_cfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+):
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
+ act_layer = QuickGELU if quick_gelu else nn.GELU
+ norm_layer = (
+ LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+ )
+
+ decoder = MultimodalTransformer(
+ context_length=multimodal_cfg.context_length,
+ width=multimodal_cfg.width,
+ heads=multimodal_cfg.heads,
+ layers=multimodal_cfg.layers,
+ ls_init_value=multimodal_cfg.ls_init_value,
+ output_dim=embed_dim,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+
+ return decoder
+
+
+class CoCa(nn.Module):
+ def __init__(
+ self,
+ embed_dim,
+ multimodal_cfg: MultimodalCfg,
+ text_cfg: CLIPTextCfg,
+ vision_cfg: CLIPVisionCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+ pad_id: int = 0,
+ ):
+ super().__init__()
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
+ text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
+ vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
+
+ self.text = _build_text_tower(
+ embed_dim=embed_dim,
+ text_cfg=text_cfg,
+ quick_gelu=quick_gelu,
+ cast_dtype=cast_dtype,
+ )
+
+ vocab_size = (
+ text_cfg.vocab_size # for hf models
+ if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
+ else text_cfg.vocab_size
+ )
+
+ self.visual = _build_vision_tower(
+ embed_dim=embed_dim,
+ vision_cfg=vision_cfg,
+ quick_gelu=quick_gelu,
+ cast_dtype=cast_dtype,
+ )
+
+ self.text_decoder = _build_text_decoder_tower(
+ vocab_size,
+ multimodal_cfg=multimodal_cfg,
+ quick_gelu=quick_gelu,
+ cast_dtype=cast_dtype,
+ )
+
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ self.pad_id = pad_id
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.visual.set_grad_checkpointing(enable)
+ self.text.set_grad_checkpointing(enable)
+ self.text_decoder.set_grad_checkpointing(enable)
+
+ def _encode_image(self, images, normalize=True):
+ image_latent, tokens_embs = self.visual(images)
+ image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
+ return image_latent, tokens_embs
+
+ def _encode_text(self, text, normalize=True, embed_cls=True):
+ text = text[:, :-1] if embed_cls else text # make space for CLS token
+ text_latent, token_emb = self.text(text)
+ text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
+ return text_latent, token_emb
+
+ def encode_image(self, images, normalize=True):
+ image_latent, _ = self._encode_image(images, normalize=normalize)
+ return image_latent
+
+ def encode_text(self, text, normalize=True, embed_cls=True):
+ text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
+ return text_latent
+
+ def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
+ text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
+ if image_latent is None or image_embs is None:
+ image_latent, image_embs = self._encode_image(image)
+
+ # TODO: add assertion to avoid bugs?
+ labels = text[:, -token_embs.shape[1]:]
+
+ logits = self.text_decoder(image_embs, token_embs)
+ return {
+ "image_features": image_latent,
+ "text_features": text_latent,
+ "logits": logits,
+ "labels": labels,
+ "logit_scale": self.logit_scale.exp()
+ }
+
+ def generate(
+ self,
+ image,
+ text=None,
+ seq_len=30,
+ max_seq_len=77,
+ temperature=1.,
+ generation_type="beam_search",
+ top_p=0.1, # keep tokens in the 1 - top_p quantile
+ top_k=1, # keeps the top_k most probable tokens
+ pad_token_id=None,
+ eos_token_id=None,
+ sot_token_id=None,
+ num_beams=6,
+ num_beam_groups=3,
+ min_seq_len=5,
+ stopping_criteria=None,
+ repetition_penalty=1.0,
+ fixed_output_length=False # if True output.shape == (batch_size, seq_len)
+ ):
+ # taking many ideas and components from HuggingFace GenerationMixin
+ # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
+ assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
+ assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
+
+ with torch.no_grad():
+ sot_token_id = 49406 if sot_token_id is None else sot_token_id
+ eos_token_id = 49407 if eos_token_id is None else eos_token_id
+ pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
+ logit_processor = LogitsProcessorList(
+ [
+ MinLengthLogitsProcessor(min_seq_len, eos_token_id),
+ RepetitionPenaltyLogitsProcessor(repetition_penalty),
+ ]
+ )
+
+ if stopping_criteria is None:
+ stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
+
+ stopping_criteria = StoppingCriteriaList(
+ stopping_criteria
+ )
+
+ device = image.device
+
+ if generation_type == "beam_search":
+ output = self._generate_beamsearch(
+ image_inputs = image,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ sot_token_id=sot_token_id,
+ num_beams=num_beams,
+ num_beam_groups=num_beam_groups,
+ min_seq_len=min_seq_len,
+ stopping_criteria=stopping_criteria,
+ logit_processor=logit_processor,
+ )
+ if fixed_output_length and output.shape[1] < seq_len:
+ return torch.cat(
+ (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
+ dim=1
+ )
+ return output
+
+ elif generation_type == "top_p":
+ logit_warper = GENERATION_TYPES[generation_type](top_p)
+ elif generation_type == "top_k":
+ logit_warper = GENERATION_TYPES[generation_type](top_k)
+ else:
+ raise ValueError(
+ f"generation_type has to be one of "
+ f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
+ )
+
+ image_latent, image_embs = self._encode_image(image)
+
+ if text is None:
+ text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
+
+ was_training = self.training
+ num_dims = len(text.shape)
+
+ if num_dims == 1:
+ text = text[None, :]
+
+ cur_len = text.shape[1]
+ self.eval()
+ out = text
+
+ while True:
+ x = out[:, -max_seq_len:]
+ cur_len = x.shape[1]
+ logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
+ mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
+ sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
+
+ if mask.all():
+ if not fixed_output_length:
+ break
+ else:
+ logits = logits[~mask, :]
+ filtered_logits = logit_processor(x[~mask, :], logits)
+ filtered_logits = logit_warper(x[~mask, :], filtered_logits)
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
+
+ if (cur_len + 1 == seq_len):
+ sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
+ else:
+ sample[~mask, :] = torch.multinomial(probs, 1)
+
+ out = torch.cat((out, sample), dim=-1)
+
+ cur_len += 1
+
+ if stopping_criteria(out, None):
+ break
+
+ if num_dims == 1:
+ out = out.squeeze(0)
+
+ self.train(was_training)
+ return out
+
+ def _generate_beamsearch(
+ self,
+ image_inputs,
+ pad_token_id=None,
+ eos_token_id=None,
+ sot_token_id=None,
+ num_beams=6,
+ num_beam_groups=3,
+ min_seq_len=5,
+ stopping_criteria=None,
+ logit_processor=None,
+ logit_warper=None,
+ ):
+ device = image_inputs.device
+ batch_size = image_inputs.shape[0]
+ image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
+ image_latent, image_embs = self._encode_image(image_inputs)
+
+ input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
+ input_ids = input_ids * sot_token_id
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=num_beams,
+ device=device,
+ num_beam_groups=num_beam_groups,
+ )
+ # instantiate logits processors
+ logits_processor = (
+ LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
+ if logit_processor is None
+ else logit_processor
+ )
+
+ batch_size = len(beam_scorer._beam_hyps)
+ num_beams = beam_scorer.num_beams
+ num_beam_groups = beam_scorer.num_beam_groups
+ num_sub_beams = num_beams // num_beam_groups
+ batch_beam_size, cur_len = input_ids.shape
+ beam_indices = None
+
+ if num_beams * batch_size != batch_beam_size:
+ raise ValueError(
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+ )
+
+ beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
+        # initialise the score of the first beam of each group with 0 and the rest with -1e9. This ensures that
+        # the beams in the same group don't produce the same tokens every time.
+ beam_scores[:, ::num_sub_beams] = 0
+ beam_scores = beam_scores.view((batch_size * num_beams,))
+
+ while True:
+
+ # predicted tokens in cur_len step
+ current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
+
+ # indices which will form the beams in the next time step
+ reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
+
+ # do one decoder step on all beams of all sentences in batch
+ model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
+ outputs = self(
+ model_inputs['images'],
+ model_inputs['text'],
+ embed_cls=False,
+ image_latent=image_latent,
+ image_embs=image_embs
+ )
+
+ for beam_group_idx in range(num_beam_groups):
+ group_start_idx = beam_group_idx * num_sub_beams
+ group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
+ group_size = group_end_idx - group_start_idx
+
+ # indices of beams of current group among all sentences in batch
+ batch_group_indices = []
+
+ for batch_idx in range(batch_size):
+ batch_group_indices.extend(
+ [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
+ )
+ group_input_ids = input_ids[batch_group_indices]
+
+                # select outputs of beams of the current group only
+ next_token_logits = outputs['logits'][batch_group_indices, -1, :]
+ vocab_size = next_token_logits.shape[-1]
+
+ next_token_scores_processed = logits_processor(
+ group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
+ )
+ next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
+ next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
+
+ # reshape for beam search
+ next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
+
+ next_token_scores, next_tokens = torch.topk(
+ next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
+ )
+
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
+ next_tokens = next_tokens % vocab_size
+
+ # stateless
+ process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
+ beam_outputs = beam_scorer.process(
+ group_input_ids,
+ next_token_scores,
+ next_tokens,
+ next_indices,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ beam_indices=process_beam_indices,
+ )
+ beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
+ beam_idx = beam_outputs["next_beam_indices"]
+
+ input_ids[batch_group_indices] = group_input_ids[beam_idx]
+ group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+ current_tokens[batch_group_indices] = group_input_ids[:, -1]
+
+ # (beam_idx // group_size) -> batch_idx
+ # (beam_idx % group_size) -> offset of idx inside the group
+ reordering_indices[batch_group_indices] = (
+ num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
+ )
+
+ input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
+
+ # increase cur_len
+ cur_len = cur_len + 1
+ if beam_scorer.is_done or stopping_criteria(input_ids, None):
+ break
+
+ final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
+ sequence_outputs = beam_scorer.finalize(
+ input_ids,
+ beam_scores,
+ next_tokens,
+ next_indices,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ max_length=stopping_criteria.max_length,
+ beam_indices=final_beam_indices,
+ )
+ return sequence_outputs['sequences']
+
+
+def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
+ if past:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+
+ attention_mask = kwargs.get("attention_mask", None)
+ position_ids = kwargs.get("position_ids", None)
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ else:
+ position_ids = None
+ return {
+ "text": input_ids,
+ "images": image_inputs,
+ "past_key_values": past,
+ "position_ids": position_ids,
+ "attention_mask": attention_mask,
+ }
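+
+
+# Minimal illustration of the position_ids handling in prepare_inputs_for_generation
+# above (the attention mask below is a made-up example, not data from the model):
+#
+#     mask = torch.tensor([[1, 1, 1, 0]])
+#     pos = mask.long().cumsum(-1) - 1      # -> tensor([[0, 1, 2, 2]])
+#     pos.masked_fill_(mask == 0, 1)        # -> tensor([[0, 1, 2, 1]])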
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/constants.py b/diffsynth/extensions/ImageQualityMetric/open_clip/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/constants.py
@@ -0,0 +1,2 @@
+OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/factory.py b/diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bd51a1bb6b69e0e69147c8b7cb8d7bd4899b349
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
@@ -0,0 +1,433 @@
+import json
+import logging
+import os
+import pathlib
+import re
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+
+from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
+ resize_pos_embed, get_cast_dtype
+from .coca_model import CoCa
+from .loss import ClipLoss, DistillClipLoss, CoCaLoss
+from .openai import load_openai_model
+from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
+from .transform import image_transform, AugmentationCfg
+from .tokenizer import HFTokenizer, SimpleTokenizer
+
+
+HF_HUB_PREFIX = 'hf-hub:'
+_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
+_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
+
+
+def _natural_key(string_):
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
+
+
+def _rescan_model_configs():
+ global _MODEL_CONFIGS
+
+ config_ext = ('.json',)
+ config_files = []
+ for config_path in _MODEL_CONFIG_PATHS:
+ if config_path.is_file() and config_path.suffix in config_ext:
+ config_files.append(config_path)
+ elif config_path.is_dir():
+ for ext in config_ext:
+ config_files.extend(config_path.glob(f'*{ext}'))
+
+ for cf in config_files:
+ with open(cf, 'r') as f:
+ model_cfg = json.load(f)
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
+ _MODEL_CONFIGS[cf.stem] = model_cfg
+
+ _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
+
+
+_rescan_model_configs() # initial populate of model config registry
+
+
+def list_models():
+ """ enumerate available model architectures based on config files """
+ return list(_MODEL_CONFIGS.keys())
+
+
+def add_model_config(path):
+ """ add model config path or file and update registry """
+ if not isinstance(path, Path):
+ path = Path(path)
+ _MODEL_CONFIG_PATHS.append(path)
+ _rescan_model_configs()
+
+
+def get_model_config(model_name):
+ if model_name in _MODEL_CONFIGS:
+ return deepcopy(_MODEL_CONFIGS[model_name])
+ else:
+ return None
+
+
+def get_tokenizer(model_name, open_clip_bpe_path=None):
+ if model_name.startswith(HF_HUB_PREFIX):
+ tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
+ else:
+ config = get_model_config(model_name)
+        if 'hf_tokenizer_name' in config['text_cfg']:
+            tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name'])
+        else:
+            tokenizer = SimpleTokenizer(open_clip_bpe_path)
+ return tokenizer
+
+
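+# Illustrative note (not upstream open_clip behaviour): with the local BPE file shipped
+# for these metrics, get_tokenizer takes the SimpleTokenizer branch above, e.g. with a
+# hypothetical local path:
+#
+#     tokenizer = get_tokenizer("ViT-H-14", "models/QualityMetric/bpe_simple_vocab_16e6.txt.gz")
+
+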
+def load_state_dict(checkpoint_path: str, map_location='cpu'):
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+ if next(iter(state_dict.items()))[0].startswith('module'):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+ return state_dict
+
+
+def load_checkpoint(model, checkpoint_path, strict=True):
+ state_dict = load_state_dict(checkpoint_path)
+ # detect old format and make compatible with new format
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
+ state_dict = convert_to_custom_text_state_dict(state_dict)
+ resize_pos_embed(state_dict, model)
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
+ return incompatible_keys
+
+
+def create_model(
+ model_name: str,
+ pretrained: Optional[str] = None,
+ precision: str = 'fp32',
+ device: Union[str, torch.device] = 'cpu',
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ force_custom_text: bool = False,
+ force_patch_dropout: Optional[float] = None,
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
+ pretrained_image: bool = False,
+ pretrained_hf: bool = True,
+ cache_dir: Optional[str] = None,
+ output_dict: Optional[bool] = None,
+ require_pretrained: bool = False,
+):
+ has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
+ if has_hf_hub_prefix:
+ model_id = model_name[len(HF_HUB_PREFIX):]
+ checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
+ config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
+
+ with open(config_path, 'r', encoding='utf-8') as f:
+ config = json.load(f)
+ pretrained_cfg = config['preprocess_cfg']
+ model_cfg = config['model_cfg']
+ else:
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
+ checkpoint_path = None
+ pretrained_cfg = {}
+ model_cfg = None
+
+ if isinstance(device, str):
+ device = torch.device(device)
+
+ if pretrained and pretrained.lower() == 'openai':
+ logging.info(f'Loading pretrained {model_name} from OpenAI.')
+ model = load_openai_model(
+ model_name,
+ precision=precision,
+ device=device,
+ jit=jit,
+ cache_dir=cache_dir,
+ )
+
+ # to always output dict even if it is clip
+ if output_dict and hasattr(model, "output_dict"):
+ model.output_dict = True
+ else:
+ model_cfg = model_cfg or get_model_config(model_name)
+ if model_cfg is not None:
+ logging.info(f'Loaded {model_name} model config.')
+ else:
+ logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
+ raise RuntimeError(f'Model config for {model_name} not found.')
+
+ if force_quick_gelu:
+ # override for use of QuickGELU on non-OpenAI transformer models
+ model_cfg["quick_gelu"] = True
+
+ if force_patch_dropout is not None:
+ # override the default patch dropout value
+ model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
+
+ if force_image_size is not None:
+ # override model config's image size
+ model_cfg["vision_cfg"]["image_size"] = force_image_size
+
+ if pretrained_image:
+ if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
+ # pretrained weight loading for timm models set via vision_cfg
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
+ else:
+ assert False, 'pretrained image towers currently only supported for timm models'
+
+ cast_dtype = get_cast_dtype(precision)
+ is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
+ custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
+
+ if custom_text:
+ if is_hf_model:
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
+ if "coca" in model_name:
+ model = CoCa(**model_cfg, cast_dtype=cast_dtype)
+ else:
+ model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
+ else:
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
+
+ pretrained_loaded = False
+ if pretrained:
+ checkpoint_path = ''
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
+ if pretrained_cfg:
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
+ elif os.path.exists(pretrained):
+ checkpoint_path = pretrained
+
+ if checkpoint_path:
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
+ load_checkpoint(model, checkpoint_path)
+ else:
+ error_str = (
+                f'Pretrained weights ({pretrained}) not found for model {model_name}. '
+                f'Available pretrained tags: {list_pretrained_tags_by_model(model_name)}.')
+ logging.warning(error_str)
+ raise RuntimeError(error_str)
+ pretrained_loaded = True
+ elif has_hf_hub_prefix:
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
+ load_checkpoint(model, checkpoint_path)
+ pretrained_loaded = True
+
+ if require_pretrained and not pretrained_loaded:
+ # callers of create_model_from_pretrained always expect pretrained weights
+ raise RuntimeError(
+ f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
+
+ model.to(device=device)
+ if precision in ("fp16", "bf16"):
+ convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
+
+    # set image mean / std metadata from pretrained_cfg if available, or use defaults
+ model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
+ model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
+
+ # to always output dict even if it is clip
+ if output_dict and hasattr(model, "output_dict"):
+ model.output_dict = True
+
+ if jit:
+ model = torch.jit.script(model)
+
+ return model
+
+
+def create_loss(args):
+ if args.distill:
+ return DistillClipLoss(
+ local_loss=args.local_loss,
+ gather_with_grad=args.gather_with_grad,
+ cache_labels=True,
+ rank=args.rank,
+ world_size=args.world_size,
+ use_horovod=args.horovod,
+ )
+ elif "coca" in args.model.lower():
+ return CoCaLoss(
+ caption_loss_weight=args.coca_caption_loss_weight,
+ clip_loss_weight=args.coca_contrastive_loss_weight,
+ local_loss=args.local_loss,
+ gather_with_grad=args.gather_with_grad,
+ cache_labels=True,
+ rank=args.rank,
+ world_size=args.world_size,
+ use_horovod=args.horovod,
+ )
+ return ClipLoss(
+ local_loss=args.local_loss,
+ gather_with_grad=args.gather_with_grad,
+ cache_labels=True,
+ rank=args.rank,
+ world_size=args.world_size,
+ use_horovod=args.horovod,
+ )
+
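+# Lightweight MLP regression head; attached below as the optional score predictor
+# on top of the CLIP visual projection.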
+class MLP(torch.nn.Module):
+ def __init__(self, input_size):
+ super().__init__()
+ self.input_size = input_size
+ self.layers = torch.nn.Sequential(
+ torch.nn.Linear(self.input_size, 1024),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(1024, 128),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(128, 64),
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(64, 16),
+ torch.nn.Linear(16, 1)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+# class semantic_head(torch.nn.Module):
+# def __init__(self, input_size):
+# super().__init__()
+# self.input_size = input_size # for ViT-L-14 is 1024
+# self.seg_head = torch.nn.Sequential(
+# torch.nn.Linear(input_size, 128),
+# torch.nn.Dropout(0.2),
+# torch.nn.Linear(128, 64),
+# torch.nn.Dropout(0.1),
+# torch.nn.Linear(64, 16),
+# torch.nn.Linear(16, 1),
+# )
+# self.sigmoid = torch.nn.Sigmoid()
+
+# def forward(self, x):
+# return self.sigmoid(self.seg_head(x))
+
+def create_model_and_transforms(
+ model_name: str,
+ pretrained: Optional[str] = None,
+ precision: str = 'fp32',
+ device: Union[str, torch.device] = 'cpu',
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ force_custom_text: bool = False,
+ force_patch_dropout: Optional[float] = None,
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
+ pretrained_image: bool = False,
+ pretrained_hf: bool = True,
+ image_mean: Optional[Tuple[float, ...]] = None,
+ image_std: Optional[Tuple[float, ...]] = None,
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
+ cache_dir: Optional[str] = None,
+ light_augmentation = False,
+ output_dict: Optional[bool] = None,
+ with_score_predictor: bool = False,
+ with_region_predictor: bool = False
+):
+ model = create_model(
+ model_name,
+ pretrained,
+ precision=precision,
+ device=device,
+ jit=jit,
+ force_quick_gelu=force_quick_gelu,
+ force_custom_text=force_custom_text,
+ force_patch_dropout=force_patch_dropout,
+ force_image_size=force_image_size,
+ pretrained_image=pretrained_image,
+ pretrained_hf=pretrained_hf,
+ cache_dir=cache_dir,
+ output_dict=output_dict,
+ )
+
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
+ image_std = image_std or getattr(model.visual, 'image_std', None)
+
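+    # optional auxiliary heads on top of the visual projection: a scalar score predictor (MLP)
+    # and a linear region predictor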
+ if with_score_predictor:
+ model.score_predictor = MLP(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
+
+ if with_region_predictor:
+ # model.region_predictor = semantic_head(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
+ model.region_predictor = torch.nn.Linear(model.visual.proj.size(0), 1).to(device=device, dtype=model.visual.proj.dtype)
+ # preprocess_train = image_transform_region(
+ # model.visual.image_size,
+ # is_train=True,
+ # mean=image_mean,
+ # std=image_std
+ # )
+ # preprocess_val = image_transform_region(
+ # model.visual.image_size,
+ # is_train=False,
+ # mean=image_mean,
+ # std=image_std
+ # )
+
+ if light_augmentation:
+ preprocess_val = image_transform(
+ model.visual.image_size,
+ is_train=False,
+ mean=image_mean,
+ std=image_std,
+ resize_longest_max=True,
+ )
+ preprocess_train = preprocess_val
+ else:
+ preprocess_train = image_transform(
+ model.visual.image_size,
+ is_train=True,
+ mean=image_mean,
+ std=image_std
+ )
+ preprocess_val = image_transform(
+ model.visual.image_size,
+ is_train=False,
+ mean=image_mean,
+ std=image_std
+ )
+
+ return model, preprocess_train, preprocess_val
+
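+# Usage sketch (illustrative; the model name and flags below are assumptions, not used elsewhere here):
+#   model, preprocess_train, preprocess_val = create_model_and_transforms(
+#       'ViT-H-14', pretrained=None, with_score_predictor=True)
+# This builds the CLIP backbone from the bundled config, attaches the MLP score head,
+# and returns matching train/val image transforms.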
+
+def create_model_from_pretrained(
+ model_name: str,
+ pretrained: Optional[str] = None,
+ precision: str = 'fp32',
+ device: Union[str, torch.device] = 'cpu',
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ force_custom_text: bool = False,
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
+ return_transform: bool = True,
+ image_mean: Optional[Tuple[float, ...]] = None,
+ image_std: Optional[Tuple[float, ...]] = None,
+ cache_dir: Optional[str] = None,
+):
+ model = create_model(
+ model_name,
+ pretrained,
+ precision=precision,
+ device=device,
+ jit=jit,
+ force_quick_gelu=force_quick_gelu,
+ force_custom_text=force_custom_text,
+ force_image_size=force_image_size,
+ cache_dir=cache_dir,
+ require_pretrained=True,
+ )
+
+ if not return_transform:
+ return model
+
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
+ image_std = image_std or getattr(model.visual, 'image_std', None)
+ preprocess = image_transform(
+ model.visual.image_size,
+ is_train=False,
+ mean=image_mean,
+ std=image_std,
+ )
+
+ return model, preprocess
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py b/diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py b/diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..e236222bafce0358445ea16953ca0b2d5a84758a
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py
@@ -0,0 +1,45 @@
+# HF architecture dict:
+arch_dict = {
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
+ "roberta": {
+ "config_names": {
+ "context_length": "max_position_embeddings",
+ "vocab_size": "vocab_size",
+ "width": "hidden_size",
+ "heads": "num_attention_heads",
+ "layers": "num_hidden_layers",
+ "layer_attr": "layer",
+ "token_embeddings_attr": "embeddings"
+ },
+ "pooler": "mean_pooler",
+ },
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
+ "xlm-roberta": {
+ "config_names": {
+ "context_length": "max_position_embeddings",
+ "vocab_size": "vocab_size",
+ "width": "hidden_size",
+ "heads": "num_attention_heads",
+ "layers": "num_hidden_layers",
+ "layer_attr": "layer",
+ "token_embeddings_attr": "embeddings"
+ },
+ "pooler": "mean_pooler",
+ },
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
+ "mt5": {
+ "config_names": {
+ # unlimited seqlen
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
+ "context_length": "",
+ "vocab_size": "vocab_size",
+ "width": "d_model",
+ "heads": "num_heads",
+ "layers": "num_layers",
+ "layer_attr": "block",
+ "token_embeddings_attr": "embed_tokens"
+ },
+ "pooler": "mean_pooler",
+ },
+}
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py b/diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbccc812757bf10b122ff14096980e0e38d1d221
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
@@ -0,0 +1,176 @@
+""" huggingface model adapter
+
+Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
+"""
+
+import re
+
+import torch
+import torch.nn as nn
+from torch import TensorType
+
+try:
+ import transformers
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
+ BaseModelOutputWithPoolingAndCrossAttentions
+except ImportError as e:
+ transformers = None
+
+
+ class BaseModelOutput:
+ pass
+
+
+ class PretrainedConfig:
+ pass
+
+from .hf_configs import arch_dict
+
+
+# utils
+def _camel2snake(s):
+    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
+
+    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
+ # calculated ground-truth and cache if enabled
+ if self.prev_num_logits != num_logits or device not in self.labels:
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
+ if self.world_size > 1 and self.local_loss:
+ labels = labels + num_logits * self.rank
+ if self.cache_labels:
+ self.labels[device] = labels
+ self.prev_num_logits = num_logits
+ else:
+ labels = self.labels[device]
+ return labels
+
+ def get_logits(self, image_features, text_features, logit_scale):
+ if self.world_size > 1:
+ all_image_features, all_text_features = gather_features(
+ image_features, text_features,
+ self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
+
+ if self.local_loss:
+ logits_per_image = logit_scale * image_features @ all_text_features.T
+ logits_per_text = logit_scale * text_features @ all_image_features.T
+ else:
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
+ logits_per_text = logits_per_image.T
+ else:
+ logits_per_image = logit_scale * image_features @ text_features.T
+ logits_per_text = logit_scale * text_features @ image_features.T
+
+ return logits_per_image, logits_per_text
+
+ def forward(self, image_features, text_features, logit_scale, output_dict=False):
+ device = image_features.device
+ logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
+
+ labels = self.get_ground_truth(device, logits_per_image.shape[0])
+
+ total_loss = (
+ F.cross_entropy(logits_per_image, labels) +
+ F.cross_entropy(logits_per_text, labels)
+ ) / 2
+ return total_loss
+
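+# PreferenceLoss: images are grouped per prompt via `num_images`; for each group the
+# image-vs-own-prompt logits are gathered and a cross-entropy is taken against `labels`,
+# the index of the preferred image in that group.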
+class PreferenceLoss(nn.Module):
+
+ def forward(self, logits_per_image, num_images, labels):
+
+ paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
+ paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)
+
+ ce_loss = F.cross_entropy(paired_logits, labels)
+ return ce_loss
+
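+# HPSLoss: each image is scored against two candidate texts; the cross-entropy terms for
+# preferring text 0 and text 1 are weighted by the corresponding preference labels and summed.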
+class HPSLoss(nn.Module):
+
+ def forward(self, text_logits, labels):
+
+ device = text_logits.device
+ text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
+ label_0, label_1 = labels.chunk(2, dim=-1)
+
+ index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)
+ text_0_logits = text_0_logits[index, index]
+ text_1_logits = text_1_logits[index, index]
+ text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
+ text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
+ text_1_labels = text_0_labels + 1
+
+ text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
+ text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")
+
+ text_loss = label_0 * text_0_loss + label_1 * text_1_loss
+
+ # absolute_example_weight = 1 / num_per_prompt
+ # denominator = absolute_example_weight.sum()
+ # weight_per_example = absolute_example_weight / denominator
+ # text_loss *= weight_per_example
+
+ text_loss = text_loss.sum()
+ return text_loss
+
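+# RankingLoss: pairwise hinge loss over images sharing a prompt; each pair's score
+# difference must agree with its rank-label difference (lower label = preferred) by at
+# least `margin`.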
+class RankingLoss(nn.Module):
+
+ def forward(self, logits_per_image, num_images, labels, margin = 1.0):
+ paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
+ label_list = [label for label in labels.split(num_images.tolist())]
+ # ranked_logits = [torch.index_select(paired_logits_list[i], 0, rank) for i, rank in enumerate(label_list)]
+
+ paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-1)
+ padded_labels = pad_sequence(label_list, batch_first=True, padding_value=10)
+
+ # regulized_logits = torch.log(torch.sigmoid(paired_logits))
+
+ diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
+ # diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
+ # diff_label = torch.clamp(padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2), min=-1, max=1)
+ diff_label = - (padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2))
+ mask = torch.triu(torch.ones(diff.shape[1], diff.shape[1]), diagonal=1).bool().detach()
+
+ loss = torch.clamp(margin - torch.mul(diff[:, ~mask],diff_label[:,~mask]), min=0).mean()
+ return loss
+
+class CoCaLoss(ClipLoss):
+ def __init__(
+ self,
+ caption_loss_weight,
+ clip_loss_weight,
+ pad_id=0, # pad_token for open_clip custom tokenizer
+ local_loss=False,
+ gather_with_grad=False,
+ cache_labels=False,
+ rank=0,
+ world_size=1,
+ use_horovod=False,
+ ):
+ super().__init__(
+ local_loss=local_loss,
+ gather_with_grad=gather_with_grad,
+ cache_labels=cache_labels,
+ rank=rank,
+ world_size=world_size,
+ use_horovod=use_horovod
+ )
+
+ self.clip_loss_weight = clip_loss_weight
+ self.caption_loss_weight = caption_loss_weight
+ self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
+
+ def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
+ clip_loss = super().forward(image_features, text_features, logit_scale)
+ clip_loss = self.clip_loss_weight * clip_loss
+
+ caption_loss = self.caption_loss(
+ logits.permute(0, 2, 1),
+ labels,
+ )
+ caption_loss = caption_loss * self.caption_loss_weight
+
+ if output_dict:
+ return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
+
+ return clip_loss, caption_loss
+
+
+class DistillClipLoss(ClipLoss):
+
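+    # soft cross-entropy between teacher and student logits (KL divergence up to a constant)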
+ def dist_loss(self, teacher_logits, student_logits):
+ return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
+
+ def forward(
+ self,
+ image_features,
+ text_features,
+ logit_scale,
+ dist_image_features,
+ dist_text_features,
+ dist_logit_scale,
+ output_dict=False,
+ ):
+ logits_per_image, logits_per_text = \
+ self.get_logits(image_features, text_features, logit_scale)
+
+ dist_logits_per_image, dist_logits_per_text = \
+ self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
+
+ labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
+
+ contrastive_loss = (
+ F.cross_entropy(logits_per_image, labels) +
+ F.cross_entropy(logits_per_text, labels)
+ ) / 2
+
+ distill_loss = (
+ self.dist_loss(dist_logits_per_image, logits_per_image) +
+ self.dist_loss(dist_logits_per_text, logits_per_text)
+ ) / 2
+
+ if output_dict:
+ return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
+
+ return contrastive_loss, distill_loss
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/model.py b/diffsynth/extensions/ImageQualityMetric/open_clip/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e347c42fc8df6464ca28e59adadba61e53a38add
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/model.py
@@ -0,0 +1,461 @@
+""" CLIP Model
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+from dataclasses import dataclass
+import logging
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+from .hf_model import HFTextEncoder
+from .modified_resnet import ModifiedResNet
+from .timm_model import TimmModel
+from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
+from .utils import to_2tuple
+
+
+@dataclass
+class CLIPVisionCfg:
+ layers: Union[Tuple[int, int, int, int], int] = 12
+ width: int = 768
+ head_width: int = 64
+ mlp_ratio: float = 4.0
+ patch_size: int = 16
+ image_size: Union[Tuple[int, int], int] = 224
+ ls_init_value: Optional[float] = None # layer scale initial value
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
+ input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
+ attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
+ n_queries: int = 256 # n_queries for attentional pooler
+ attn_pooler_heads: int = 8 # n heads for attentional_pooling
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
+ timm_proj_bias: bool = False # enable bias final projection
+ timm_drop: float = 0. # head dropout
+ timm_drop_path: Optional[float] = None # backbone stochastic depth
+ output_tokens: bool = False
+
+
+@dataclass
+class CLIPTextCfg:
+ context_length: int = 77
+ vocab_size: int = 49408
+ width: int = 512
+ heads: int = 8
+ layers: int = 12
+ ls_init_value: Optional[float] = None # layer scale initial value
+ hf_model_name: str = None
+ hf_tokenizer_name: str = None
+ hf_model_pretrained: bool = True
+ proj: str = 'mlp'
+ pooler_type: str = 'mean_pooler'
+ embed_cls: bool = False
+ pad_id: int = 0
+ output_tokens: bool = False
+
+
+def get_cast_dtype(precision: str):
+ cast_dtype = None
+ if precision == 'bf16':
+ cast_dtype = torch.bfloat16
+ elif precision == 'fp16':
+ cast_dtype = torch.float16
+ return cast_dtype
+
+
+def _build_vision_tower(
+ embed_dim: int,
+ vision_cfg: CLIPVisionCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None
+):
+ if isinstance(vision_cfg, dict):
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
+
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
+ # memory efficient in recent PyTorch releases (>= 1.10).
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
+ act_layer = QuickGELU if quick_gelu else nn.GELU
+
+ if vision_cfg.timm_model_name:
+ visual = TimmModel(
+ vision_cfg.timm_model_name,
+ pretrained=vision_cfg.timm_model_pretrained,
+ pool=vision_cfg.timm_pool,
+ proj=vision_cfg.timm_proj,
+ proj_bias=vision_cfg.timm_proj_bias,
+ drop=vision_cfg.timm_drop,
+ drop_path=vision_cfg.timm_drop_path,
+ embed_dim=embed_dim,
+ image_size=vision_cfg.image_size,
+ )
+ act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
+ elif isinstance(vision_cfg.layers, (tuple, list)):
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
+ visual = ModifiedResNet(
+ layers=vision_cfg.layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ image_size=vision_cfg.image_size,
+ width=vision_cfg.width,
+ )
+ else:
+ vision_heads = vision_cfg.width // vision_cfg.head_width
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+ visual = VisionTransformer(
+ image_size=vision_cfg.image_size,
+ patch_size=vision_cfg.patch_size,
+ width=vision_cfg.width,
+ layers=vision_cfg.layers,
+ heads=vision_heads,
+ mlp_ratio=vision_cfg.mlp_ratio,
+ ls_init_value=vision_cfg.ls_init_value,
+ patch_dropout=vision_cfg.patch_dropout,
+ input_patchnorm=vision_cfg.input_patchnorm,
+ global_average_pool=vision_cfg.global_average_pool,
+ attentional_pool=vision_cfg.attentional_pool,
+ n_queries=vision_cfg.n_queries,
+ attn_pooler_heads=vision_cfg.attn_pooler_heads,
+ output_tokens=vision_cfg.output_tokens,
+ output_dim=embed_dim,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+
+ return visual
+
+
+def _build_text_tower(
+ embed_dim: int,
+ text_cfg: CLIPTextCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+):
+ if isinstance(text_cfg, dict):
+ text_cfg = CLIPTextCfg(**text_cfg)
+
+ if text_cfg.hf_model_name:
+ text = HFTextEncoder(
+ text_cfg.hf_model_name,
+ output_dim=embed_dim,
+ proj=text_cfg.proj,
+ pooler_type=text_cfg.pooler_type,
+ pretrained=text_cfg.hf_model_pretrained,
+ output_tokens=text_cfg.output_tokens,
+ )
+ else:
+ act_layer = QuickGELU if quick_gelu else nn.GELU
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+
+ text = TextTransformer(
+ context_length=text_cfg.context_length,
+ vocab_size=text_cfg.vocab_size,
+ width=text_cfg.width,
+ heads=text_cfg.heads,
+ layers=text_cfg.layers,
+ ls_init_value=text_cfg.ls_init_value,
+ output_dim=embed_dim,
+ embed_cls=text_cfg.embed_cls,
+ output_tokens=text_cfg.output_tokens,
+ pad_id=text_cfg.pad_id,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+ return text
+
+
+class CLIP(nn.Module):
+ output_dict: torch.jit.Final[bool]
+
+ def __init__(
+ self,
+ embed_dim: int,
+ vision_cfg: CLIPVisionCfg,
+ text_cfg: CLIPTextCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+ output_dict: bool = False,
+ ):
+ super().__init__()
+ self.output_dict = output_dict
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
+
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
+ self.transformer = text.transformer
+ self.vocab_size = text.vocab_size
+ self.token_embedding = text.token_embedding
+ self.positional_embedding = text.positional_embedding
+ self.ln_final = text.ln_final
+ self.text_projection = text.text_projection
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
+
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
+
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
+ locked_layers = []
+ locked_layers.append(self.token_embedding)
+ self.positional_embedding.requires_grad = False
+ if unlocked_layers > 0:
+ locked_layers.append(self.transformer.resblocks[:-unlocked_layers])
+ else:
+ locked_layers.append(self.transformer)
+ locked_layers.append(self.ln_final)
+ self.text_projection.requires_grad = False
+
+ # freeze layers
+ for module in locked_layers:
+ for n, p in module.named_parameters():
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.visual.set_grad_checkpointing(enable)
+ self.transformer.grad_checkpointing = enable
+
+ def encode_image(self, image, normalize: bool = False):
+ features = self.visual(image)
+ return F.normalize(features, dim=-1) if normalize else features
+
+ def encode_text(self, text, normalize: bool = False):
+ cast_dtype = self.transformer.get_cast_dtype()
+
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.to(cast_dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, attn_mask=self.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+ return F.normalize(x, dim=-1) if normalize else x
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image, normalize=True)
+ text_features = self.encode_text(text, normalize=True)
+ if self.output_dict:
+ return {
+ "image_features": image_features,
+ "text_features": text_features,
+ "logit_scale": self.logit_scale.exp()
+ }
+ return image_features, text_features, self.logit_scale.exp()
+
+
+class CustomTextCLIP(nn.Module):
+ output_dict: torch.jit.Final[bool]
+
+ def __init__(
+ self,
+ embed_dim: int,
+ vision_cfg: CLIPVisionCfg,
+ text_cfg: CLIPTextCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+ output_dict: bool = False,
+ ):
+ super().__init__()
+ self.output_dict = output_dict
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
+
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
+ self.text.lock(unlocked_layers, freeze_layer_norm)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.visual.set_grad_checkpointing(enable)
+ self.text.set_grad_checkpointing(enable)
+
+ def encode_image(self, image, normalize: bool = False):
+ features = self.visual(image)
+ return F.normalize(features, dim=-1) if normalize else features
+
+ def encode_text(self, text, normalize: bool = False):
+ features = self.text(text)
+ return F.normalize(features, dim=-1) if normalize else features
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image, normalize=True)
+ text_features = self.encode_text(text, normalize=True)
+ if self.output_dict:
+ return {
+ "image_features": image_features,
+ "text_features": text_features,
+ "logit_scale": self.logit_scale.exp()
+ }
+ return image_features, text_features, self.logit_scale.exp()
+
+
+def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
+
+ def _convert_weights(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.to(dtype)
+ if l.bias is not None:
+ l.bias.data = l.bias.data.to(dtype)
+
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.to(dtype)
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.to(dtype)
+
+ model.apply(_convert_weights)
+
+
+convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
+
+
+# used to maintain checkpoint compatibility
+def convert_to_custom_text_state_dict(state_dict: dict):
+ if 'text_projection' in state_dict:
+ # old format state_dict, move text tower -> .text
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if any(k.startswith(p) for p in (
+ 'text_projection',
+ 'positional_embedding',
+ 'token_embedding',
+ 'transformer',
+ 'ln_final',
+ )):
+ k = 'text.' + k
+ new_state_dict[k] = v
+ return new_state_dict
+ return state_dict
+
+
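+# Rebuild a CLIP module from an OpenAI-format state dict, inferring the architecture
+# hyper-parameters (widths, layers, patch/image size) from tensor shapes.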
+def build_model_from_openai_state_dict(
+ state_dict: dict,
+ quick_gelu=True,
+ cast_dtype=torch.float16,
+):
+ vit = "visual.proj" in state_dict
+
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len(
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_size = vision_patch_size * grid_size
+ else:
+ counts: list = [
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+ vision_layers = tuple(counts)
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+ vision_patch_size = None
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ image_size = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+ vision_cfg = CLIPVisionCfg(
+ layers=vision_layers,
+ width=vision_width,
+ patch_size=vision_patch_size,
+ image_size=image_size,
+ )
+ text_cfg = CLIPTextCfg(
+ context_length=context_length,
+ vocab_size=vocab_size,
+ width=transformer_width,
+ heads=transformer_heads,
+ layers=transformer_layers,
+ )
+ model = CLIP(
+ embed_dim,
+ vision_cfg=vision_cfg,
+ text_cfg=text_cfg,
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
+ cast_dtype=cast_dtype,
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ state_dict.pop(key, None)
+
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
+ model.load_state_dict(state_dict)
+ return model.eval()
+
+
+def trace_model(model, batch_size=256, device=torch.device('cpu')):
+ model.eval()
+ image_size = model.visual.image_size
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
+ model = torch.jit.trace_module(
+ model,
+ inputs=dict(
+ forward=(example_images, example_text),
+ encode_text=(example_text,),
+ encode_image=(example_images,)
+ ))
+ model.visual.image_size = image_size
+ return model
+
+
+def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
+ # Rescale the grid of position embeddings when loading from state_dict
+ old_pos_embed = state_dict.get('visual.positional_embedding', None)
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
+ return
+ grid_size = to_2tuple(model.visual.grid_size)
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
+ if new_seq_len == old_pos_embed.shape[0]:
+ return
+
+ if extra_tokens:
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
+ else:
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
+
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
+ pos_emb_img = F.interpolate(
+ pos_emb_img,
+ size=grid_size,
+ mode=interpolation,
+ antialias=antialias,
+ align_corners=False,
+ )
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
+ if pos_emb_tok is not None:
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
+ else:
+ new_pos_embed = pos_emb_img
+ state_dict['visual.positional_embedding'] = new_pos_embed
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/model_configs/ViT-H-14.json b/diffsynth/extensions/ImageQualityMetric/open_clip/model_configs/ViT-H-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/model_configs/ViT-H-14.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 32,
+ "width": 1280,
+ "head_width": 80,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 1024,
+ "heads": 16,
+ "layers": 24
+ }
+}
\ No newline at end of file
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py b/diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a8d3aeda91ecb394303becbbfccc8acd8cddcd9
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py
@@ -0,0 +1,181 @@
+from collections import OrderedDict
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .utils import freeze_batch_norm_2d
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.act1 = nn.ReLU(inplace=True)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.act2 = nn.ReLU(inplace=True)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.act3 = nn.ReLU(inplace=True)
+
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.act1(self.bn1(self.conv1(x)))
+ out = self.act2(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.act3(out)
+ return out
+
+
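+# CLIP-style attention pooling: the spatial mean is prepended as a query token and the
+# pooled output is the attended representation of that token.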
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x, key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0.,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+
+ return x[0]
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.image_size = image_size
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.act1 = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.act2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.act3 = nn.ReLU(inplace=True)
+ self.avgpool = nn.AvgPool2d(2)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
+
+ self.init_parameters()
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def init_parameters(self):
+ if self.attnpool is not None:
+ std = self.attnpool.c_proj.in_features ** -0.5
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
+ for param in self.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ # FIXME support for non-transformer
+ pass
+
+ def stem(self, x):
+ x = self.act1(self.bn1(self.conv1(x)))
+ x = self.act2(self.bn2(self.conv2(x)))
+ x = self.act3(self.bn3(self.conv3(x)))
+ x = self.avgpool(x)
+ return x
+
+ def forward(self, x):
+ x = self.stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/openai.py b/diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc4e13e876d6a7a3463b457e62c517cb063b1356
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
@@ -0,0 +1,144 @@
+""" OpenAI pretrained model functions
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+
+import os
+import warnings
+from typing import List, Optional, Union
+
+import torch
+
+from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
+from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
+
+__all__ = ["list_openai_models", "load_openai_model"]
+
+
+def list_openai_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list_pretrained_models_by_tag('openai')
+
+
+def load_openai_model(
+ name: str,
+ precision: Optional[str] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ jit: bool = True,
+ cache_dir: Optional[str] = None,
+):
+ """Load a CLIP model
+
+ Parameters
+ ----------
+ name : str
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+ precision: str
+ Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
+ device : Union[str, torch.device]
+ The device to put the loaded model
+ jit : bool
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+ cache_dir : Optional[str]
+ The directory to cache the downloaded model weights
+
+ Returns
+ -------
+ model : torch.nn.Module
+ The CLIP model
+ preprocess : Callable[[PIL.Image], torch.Tensor]
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+ """
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ if precision is None:
+ precision = 'fp32' if device == 'cpu' else 'fp16'
+
+ if get_pretrained_url(name, 'openai'):
+ model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
+ elif os.path.isfile(name):
+ model_path = name
+ else:
+ raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
+
+ try:
+ # loading JIT archive
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ state_dict = None
+ except RuntimeError:
+ # loading saved state dict
+ if jit:
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+ jit = False
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ if not jit:
+ # Build a non-jit model from the OpenAI jitted model state dict
+ cast_dtype = get_cast_dtype(precision)
+ try:
+ model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
+ except KeyError:
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
+ model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
+
+ # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
+ model = model.to(device)
+ if precision.startswith('amp') or precision == 'fp32':
+ model.float()
+ elif precision == 'bf16':
+ convert_weights_to_lp(model, dtype=torch.bfloat16)
+
+ return model
+
+ # patch the device names
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+ def patch_device(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 (typically for CPU)
+ if precision == 'fp32':
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+ model.float()
+
+ # ensure image_size attr available at consistent location for both jit and non-jit
+ model.visual.image_size = model.input_resolution.item()
+ return model
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py b/diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..87e7e527497d643fdf6ac931ac73b6e887a90d0d
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
@@ -0,0 +1,376 @@
+import hashlib
+import os
+import urllib
+import warnings
+from functools import partial
+from typing import Dict, Union
+
+from tqdm import tqdm
+
+from .version import __version__
+
+try:
+ from huggingface_hub import hf_hub_download
+ hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
+ _has_hf_hub = True
+except ImportError:
+ hf_hub_download = None
+ _has_hf_hub = False
+
+
+def _pcfg(url='', hf_hub='', mean=None, std=None):
+ return dict(
+ url=url,
+ hf_hub=hf_hub,
+ mean=mean,
+ std=std,
+ )
+
+
+_RN50 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
+ yfcc15m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
+ cc12m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
+)
+
+_RN50_quickgelu = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
+ yfcc15m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
+ cc12m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
+)
+
+_RN101 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
+ yfcc15m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
+)
+
+_RN101_quickgelu = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
+ yfcc15m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
+)
+
+_RN50x4 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
+)
+
+_RN50x16 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
+)
+
+_RN50x64 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
+)
+
+_VITB32 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
+ laion2b_e16=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
+ laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
+)
+
+_VITB32_quickgelu = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
+)
+
+_VITB16 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
+ # laion400m_32k=_pcfg(
+ # url="",
+ # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ # laion400m_64k=_pcfg(
+ # url="",
+ # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
+)
+
+_VITB16_PLUS_240 = dict(
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
+)
+
+_VITL14 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
+ laion2b_s32b_b82k=_pcfg(
+ hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+)
+
+_VITL14_336 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
+)
+
+_VITH14 = dict(
+ laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
+)
+
+_VITg14 = dict(
+ laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
+)
+
+_VITbigG14 = dict(
+ laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
+)
+
+_robertaViTB32 = dict(
+ laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
+)
+
+_xlmRobertaBaseViTB32 = dict(
+ laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
+)
+
+_xlmRobertaLargeFrozenViTH14 = dict(
+ frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
+)
+
+_convnext_base = dict(
+ laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
+)
+
+_convnext_base_w = dict(
+ laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
+ laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
+ laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
+)
+
+_convnext_base_w_320 = dict(
+ laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
+ laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
+)
+
+_convnext_large_d = dict(
+ laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
+)
+
+_convnext_large_d_320 = dict(
+ laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
+ laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
+)
+
+_convnext_xxlarge = dict(
+ laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
+ laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
+ laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
+)
+
+_coca_VITB32 = dict(
+ laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
+ mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
+)
+
+_coca_VITL14 = dict(
+ laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
+ mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
+)
+
+
+_PRETRAINED = {
+ "RN50": _RN50,
+ "RN50-quickgelu": _RN50_quickgelu,
+ "RN101": _RN101,
+ "RN101-quickgelu": _RN101_quickgelu,
+ "RN50x4": _RN50x4,
+ "RN50x16": _RN50x16,
+ "RN50x64": _RN50x64,
+ "ViT-B-32": _VITB32,
+ "ViT-B-32-quickgelu": _VITB32_quickgelu,
+ "ViT-B-16": _VITB16,
+ "ViT-B-16-plus-240": _VITB16_PLUS_240,
+ "ViT-L-14": _VITL14,
+ "ViT-L-14-336": _VITL14_336,
+ "ViT-H-14": _VITH14,
+ "ViT-g-14": _VITg14,
+ "ViT-bigG-14": _VITbigG14,
+ "roberta-ViT-B-32": _robertaViTB32,
+ "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
+ "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
+ "convnext_base": _convnext_base,
+ "convnext_base_w": _convnext_base_w,
+ "convnext_base_w_320": _convnext_base_w_320,
+ "convnext_large_d": _convnext_large_d,
+ "convnext_large_d_320": _convnext_large_d_320,
+ "convnext_xxlarge": _convnext_xxlarge,
+ "coca_ViT-B-32": _coca_VITB32,
+ "coca_ViT-L-14": _coca_VITL14,
+}
+
+
+def _clean_tag(tag: str):
+ # normalize pretrained tags
+ return tag.lower().replace('-', '_')
+
+
+def list_pretrained(as_str: bool = False):
+ """ returns list of pretrained models
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
+ """
+ return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
+
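+# For example, list_pretrained(as_str=True) yields strings such as 'ViT-g-14:laion2b_s34b_b88k',
+# while the default returns the corresponding ('ViT-g-14', 'laion2b_s34b_b88k') tuples.
+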
+
+def list_pretrained_models_by_tag(tag: str):
+ """ return all models having the specified pretrain tag """
+ models = []
+ tag = _clean_tag(tag)
+ for k in _PRETRAINED.keys():
+ if tag in _PRETRAINED[k]:
+ models.append(k)
+ return models
+
+
+def list_pretrained_tags_by_model(model: str):
+ """ return all pretrain tags for the specified model architecture """
+ tags = []
+ if model in _PRETRAINED:
+ tags.extend(_PRETRAINED[model].keys())
+ return tags
+
+
+def is_pretrained_cfg(model: str, tag: str):
+ if model not in _PRETRAINED:
+ return False
+ return _clean_tag(tag) in _PRETRAINED[model]
+
+
+def get_pretrained_cfg(model: str, tag: str):
+ if model not in _PRETRAINED:
+ return {}
+ model_pretrained = _PRETRAINED[model]
+ return model_pretrained.get(_clean_tag(tag), {})
+
+
+def get_pretrained_url(model: str, tag: str):
+ cfg = get_pretrained_cfg(model, _clean_tag(tag))
+ return cfg.get('url', '')
+
+
+def download_pretrained_from_url(
+ url: str,
+ cache_dir: Union[str, None] = None,
+):
+ if not cache_dir:
+ cache_dir = os.path.expanduser("~/.cache/clip")
+ os.makedirs(cache_dir, exist_ok=True)
+ filename = os.path.basename(url)
+
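+    # The expected checksum is recovered from the URL itself: 'openaipublic' URLs embed the full
+    # sha256 as a path component, while 'mlfoundations' release filenames end with a short sha prefix.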
+ if 'openaipublic' in url:
+ expected_sha256 = url.split("/")[-2]
+ elif 'mlfoundations' in url:
+ expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
+ else:
+ expected_sha256 = ''
+
+ download_target = os.path.join(cache_dir, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if expected_sha256:
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+ else:
+ return download_target
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+def has_hf_hub(necessary=False):
+ if not _has_hf_hub and necessary:
+ # if no HF Hub module installed, and it is necessary to continue, raise error
+ raise RuntimeError(
+ 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
+ return _has_hf_hub
+
+
+def download_pretrained_from_hf(
+ model_id: str,
+ filename: str = 'open_clip_pytorch_model.bin',
+ revision=None,
+ cache_dir: Union[str, None] = None,
+):
+ has_hf_hub(True)
+ cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
+ return cached_file
+
+
+def download_pretrained(
+ cfg: Dict,
+ force_hf_hub: bool = False,
+ cache_dir: Union[str, None] = None,
+):
+ target = ''
+ if not cfg:
+ return target
+
+ download_url = cfg.get('url', '')
+ download_hf_hub = cfg.get('hf_hub', '')
+ if download_hf_hub and force_hf_hub:
+ # use HF hub even if url exists
+ download_url = ''
+
+ if download_url:
+ target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
+ elif download_hf_hub:
+ has_hf_hub(True)
+ # we assume the hf_hub entries in pretrained config combine model_id + filename in
+ # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
+ # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
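+        # For example, the hf_hub entry 'laion/CLIP-ViT-g-14-laion2B-s34B-b88K/' splits into
+        # model_id='laion/CLIP-ViT-g-14-laion2B-s34B-b88K' and filename='', so the default
+        # 'open_clip_pytorch_model.bin' is used.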
+ model_id, filename = os.path.split(download_hf_hub)
+ if filename:
+ target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
+ else:
+ target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
+
+ return target
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py b/diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..23c0631c81dcb43829b7374fac09406ecefcb436
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py
@@ -0,0 +1,243 @@
+import argparse
+import json
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Optional, Tuple
+
+import torch
+
+try:
+ from huggingface_hub import (
+ create_repo,
+ get_hf_file_metadata,
+ hf_hub_download,
+ hf_hub_url,
+ repo_type_and_id_from_hf_id,
+ upload_folder,
+ )
+ from huggingface_hub.utils import EntryNotFoundError
+ _has_hf_hub = True
+except ImportError:
+ _has_hf_hub = False
+
+from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
+from .tokenizer import HFTokenizer
+
+
+def save_config_for_hf(
+ model,
+ config_path: str,
+ model_config: Optional[dict]
+):
+ preprocess_cfg = {
+ 'mean': model.visual.image_mean,
+ 'std': model.visual.image_std,
+ }
+ hf_config = {
+ 'model_cfg': model_config,
+ 'preprocess_cfg': preprocess_cfg,
+ }
+
+ with config_path.open('w') as f:
+ json.dump(hf_config, f, indent=2)
+
+
+def save_for_hf(
+ model,
+ tokenizer: HFTokenizer,
+ model_config: dict,
+ save_directory: str,
+ weights_filename='open_clip_pytorch_model.bin',
+ config_filename='open_clip_config.json',
+):
+ save_directory = Path(save_directory)
+ save_directory.mkdir(exist_ok=True, parents=True)
+
+ weights_path = save_directory / weights_filename
+ torch.save(model.state_dict(), weights_path)
+
+ tokenizer.save_pretrained(save_directory)
+
+ config_path = save_directory / config_filename
+ save_config_for_hf(model, config_path, model_config=model_config)
+
+
+def push_to_hf_hub(
+ model,
+ tokenizer,
+ model_config: Optional[dict],
+ repo_id: str,
+ commit_message: str = 'Add model',
+ token: Optional[str] = None,
+ revision: Optional[str] = None,
+ private: bool = False,
+ create_pr: bool = False,
+ model_card: Optional[dict] = None,
+):
+ if not isinstance(tokenizer, HFTokenizer):
+ # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
+ tokenizer = HFTokenizer('openai/clip-vit-large-patch14')
+
+ # Create repo if it doesn't exist yet
+ repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
+
+ # Infer complete repo_id from repo_url
+ # Can be different from the input `repo_id` if repo_owner was implicit
+ _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
+ repo_id = f"{repo_owner}/{repo_name}"
+
+ # Check if README file already exist in repo
+ try:
+ get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
+ has_readme = True
+ except EntryNotFoundError:
+ has_readme = False
+
+ # Dump model and push to Hub
+ with TemporaryDirectory() as tmpdir:
+ # Save model weights and config.
+ save_for_hf(
+ model,
+ tokenizer=tokenizer,
+ model_config=model_config,
+ save_directory=tmpdir,
+ )
+
+ # Add readme if it does not exist
+ if not has_readme:
+ model_card = model_card or {}
+ model_name = repo_id.split('/')[-1]
+ readme_path = Path(tmpdir) / "README.md"
+ readme_text = generate_readme(model_card, model_name)
+ readme_path.write_text(readme_text)
+
+ # Upload model and return
+ return upload_folder(
+ repo_id=repo_id,
+ folder_path=tmpdir,
+ revision=revision,
+ create_pr=create_pr,
+ commit_message=commit_message,
+ )
+
+
+def push_pretrained_to_hf_hub(
+ model_name,
+ pretrained: str,
+ repo_id: str,
+ image_mean: Optional[Tuple[float, ...]] = None,
+ image_std: Optional[Tuple[float, ...]] = None,
+ commit_message: str = 'Add model',
+ token: Optional[str] = None,
+ revision: Optional[str] = None,
+ private: bool = False,
+ create_pr: bool = False,
+ model_card: Optional[dict] = None,
+):
+ model, preprocess_eval = create_model_from_pretrained(
+ model_name,
+ pretrained=pretrained,
+ image_mean=image_mean,
+ image_std=image_std,
+ )
+
+ model_config = get_model_config(model_name)
+ assert model_config
+
+ tokenizer = get_tokenizer(model_name)
+
+ push_to_hf_hub(
+ model=model,
+ tokenizer=tokenizer,
+ model_config=model_config,
+ repo_id=repo_id,
+ commit_message=commit_message,
+ token=token,
+ revision=revision,
+ private=private,
+ create_pr=create_pr,
+ model_card=model_card,
+ )
+
+
+def generate_readme(model_card: dict, model_name: str):
+ readme_text = "---\n"
+ readme_text += "tags:\n- zero-shot-image-classification\n- clip\n"
+ readme_text += "library_tag: open_clip\n"
+ readme_text += f"license: {model_card.get('license', 'mit')}\n"
+ if 'details' in model_card and 'Dataset' in model_card['details']:
+ readme_text += 'datasets:\n'
+ readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
+ readme_text += "---\n"
+ readme_text += f"# Model card for {model_name}\n"
+ if 'description' in model_card:
+ readme_text += f"\n{model_card['description']}\n"
+ if 'details' in model_card:
+ readme_text += f"\n## Model Details\n"
+ for k, v in model_card['details'].items():
+ if isinstance(v, (list, tuple)):
+ readme_text += f"- **{k}:**\n"
+ for vi in v:
+ readme_text += f" - {vi}\n"
+ elif isinstance(v, dict):
+ readme_text += f"- **{k}:**\n"
+ for ki, vi in v.items():
+ readme_text += f" - {ki}: {vi}\n"
+ else:
+ readme_text += f"- **{k}:** {v}\n"
+ if 'usage' in model_card:
+ readme_text += f"\n## Model Usage\n"
+ readme_text += model_card['usage']
+ readme_text += '\n'
+
+ if 'comparison' in model_card:
+ readme_text += f"\n## Model Comparison\n"
+ readme_text += model_card['comparison']
+ readme_text += '\n'
+
+ if 'citation' in model_card:
+ readme_text += f"\n## Citation\n"
+ if not isinstance(model_card['citation'], (list, tuple)):
+ citations = [model_card['citation']]
+ else:
+ citations = model_card['citation']
+ for c in citations:
+ readme_text += f"```bibtex\n{c}\n```\n"
+
+ return readme_text
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
+ parser.add_argument(
+ "--model", type=str, help="Name of the model to use.",
+ )
+ parser.add_argument(
+ "--pretrained", type=str,
+ help="Use a pretrained CLIP model weights with the specified tag or file path.",
+ )
+ parser.add_argument(
+ "--repo-id", type=str,
+ help="Destination HF Hub repo-id ie 'organization/model_id'.",
+ )
+ parser.add_argument(
+ '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
+ help='Override default image mean value of dataset')
+ parser.add_argument(
+ '--image-std', type=float, nargs='+', default=None, metavar='STD',
+        help='Override default image std deviation of dataset')
+ args = parser.parse_args()
+
+ print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')
+
+ # FIXME add support to pass model_card json / template from file via cmd line
+
+ push_pretrained_to_hf_hub(
+ args.model,
+ args.pretrained,
+ args.repo_id,
+ image_mean=args.image_mean, # override image mean/std if trained w/ non defaults
+ image_std=args.image_std,
+ )
+
+ print(f'{args.model} saved.')
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py b/diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc71a693f9a42ec01fd88d307661bc382b4d05bc
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
@@ -0,0 +1,127 @@
+""" timm model adapter
+
+Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
+"""
+import logging
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+try:
+ import timm
+ from timm.models.layers import Mlp, to_2tuple
+ try:
+ # old timm imports < 0.8.1
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
+ from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
+ except ImportError:
+ # new timm imports >= 0.8.1
+ from timm.layers import RotAttentionPool2d
+ from timm.layers import AttentionPool2d as AbsAttentionPool2d
+except ImportError:
+ timm = None
+
+from .utils import freeze_batch_norm_2d
+
+
+class TimmModel(nn.Module):
+ """ timm model adapter
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
+ """
+
+ def __init__(
+ self,
+ model_name,
+ embed_dim,
+ image_size=224,
+ pool='avg',
+ proj='linear',
+ proj_bias=False,
+ drop=0.,
+ drop_path=None,
+ pretrained=False,
+ ):
+ super().__init__()
+ if timm is None:
+ raise RuntimeError("Please `pip install timm` to use timm models.")
+
+ self.image_size = to_2tuple(image_size)
+ timm_kwargs = {}
+ if drop_path is not None:
+ timm_kwargs['drop_path_rate'] = drop_path
+ self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs)
+ feat_size = self.trunk.default_cfg.get('pool_size', None)
+ feature_ndim = 1 if not feat_size else 2
+ if pool in ('abs_attn', 'rot_attn'):
+ assert feature_ndim == 2
+ # if attn pooling used, remove both classifier and default pool
+ self.trunk.reset_classifier(0, global_pool='')
+ else:
+ # reset global pool if pool config set, otherwise leave as network default
+ reset_kwargs = dict(global_pool=pool) if pool else {}
+ self.trunk.reset_classifier(0, **reset_kwargs)
+ prev_chs = self.trunk.num_features
+
+ head_layers = OrderedDict()
+ if pool == 'abs_attn':
+ head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
+ prev_chs = embed_dim
+ elif pool == 'rot_attn':
+ head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
+ prev_chs = embed_dim
+ else:
+ assert proj, 'projection layer needed if non-attention pooling is used.'
+
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
+ if proj == 'linear':
+ head_layers['drop'] = nn.Dropout(drop)
+ head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
+ elif proj == 'mlp':
+ head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
+
+ self.head = nn.Sequential(head_layers)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ """ lock modules
+ Args:
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
+ """
+ if not unlocked_groups:
+ # lock full model
+ for param in self.trunk.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self.trunk)
+ else:
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
+ try:
+ # FIXME import here until API stable and in an official release
+ from timm.models.helpers import group_parameters, group_modules
+ except ImportError:
+ raise RuntimeError(
+ 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
+ matcher = self.trunk.group_matcher()
+ gparams = group_parameters(self.trunk, matcher)
+ max_layer_id = max(gparams.keys())
+ max_layer_id = max_layer_id - unlocked_groups
+ for group_idx in range(max_layer_id + 1):
+ group = gparams[group_idx]
+ for param in group:
+ self.trunk.get_parameter(param).requires_grad = False
+ if freeze_bn_stats:
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
+ freeze_batch_norm_2d(self.trunk, gmodules)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ try:
+ self.trunk.set_grad_checkpointing(enable)
+ except Exception as e:
+ logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
+
+ def forward(self, x):
+ x = self.trunk(x)
+ x = self.head(x)
+ return x
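+
+# A minimal usage sketch (assumes timm is installed and 'convnext_base' is an available timm model):
+#   tower = TimmModel('convnext_base', embed_dim=512, pool='avg', proj='linear')
+#   feats = tower(torch.randn(1, 3, 224, 224))  # -> tensor of shape [1, 512]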
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py b/diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..22ec4880b13ec73594d5c19b3d3be83aadb55aba
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
@@ -0,0 +1,211 @@
+""" CLIP tokenizer
+
+Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+import gzip
+import html
+import os
+from functools import lru_cache
+from typing import Union, List
+
+import ftfy
+import regex as re
+import torch
+
+# https://stackoverflow.com/q/62691279
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+@lru_cache()
+def default_bpe():
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
+ quality_metric_path = os.path.join(project_root, 'models', 'QualityMetric')
+ return os.path.join(quality_metric_path, "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+    Returns a mapping from utf-8 bytes to corresponding unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
+    while avoiding mapping to whitespace/control characters that the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v+'</w>' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ if not special_tokens:
+            special_tokens = ['<start_of_text>', '<end_of_text>']
+        else:
+            special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
+ vocab.extend(special_tokens)
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {t:t for t in special_tokens}
+ special = "|".join(special_tokens)
+ self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ self.vocab_size = len(self.encoder)
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+            return token+'</w>'
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+            except ValueError:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+ return text
+
+ def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+ """
+ Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ----------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = self.encoder[""]
+ eot_token = self.encoder[""]
+ all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ tokens = tokens[:context_length] # Truncate
+ tokens[-1] = eot_token
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
+
+
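+# A minimal usage sketch (assumes the BPE vocab file resolved by default_bpe() exists on disk):
+#   tokenizer = SimpleTokenizer()
+#   tokens = tokenizer(["a photo of a cat"])  # -> LongTensor of shape [1, 77]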
+
+class HFTokenizer:
+ """HuggingFace tokenizer wrapper"""
+
+ def __init__(self, tokenizer_name: str):
+ from transformers import AutoTokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+ def save_pretrained(self, dest):
+ self.tokenizer.save_pretrained(dest)
+
+ def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
+ # same cleaning as for default tokenizer, except lowercasing
+ # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
+ if isinstance(texts, str):
+ texts = [texts]
+ texts = [whitespace_clean(basic_clean(text)) for text in texts]
+ input_ids = self.tokenizer(
+ texts,
+ return_tensors='pt',
+ max_length=context_length,
+ padding='max_length',
+ truncation=True,
+ ).input_ids
+ return input_ids
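+
+# A minimal usage sketch (model id is illustrative): HFTokenizer('xlm-roberta-base')(["a photo of a cat"])
+# returns padded input_ids of shape [1, 77], mirroring SimpleTokenizer's output shape.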
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/transform.py b/diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe4e21fa5b515f2412049f9274bd06fbe77fb9b9
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
@@ -0,0 +1,216 @@
+import warnings
+from dataclasses import dataclass, asdict
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torchvision.transforms.functional as F
+from functools import partial
+from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
+ CenterCrop
+
+from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+
+
+@dataclass
+class AugmentationCfg:
+ scale: Tuple[float, float] = (0.9, 1.0)
+ ratio: Optional[Tuple[float, float]] = None
+ color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
+ interpolation: Optional[str] = None
+ re_prob: Optional[float] = None
+ re_count: Optional[int] = None
+ use_timm: bool = False
+
+
+class ResizeMaxSize(nn.Module):
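+    """Resize so the longest image side equals max_size, then pad (with `fill`) to a square of that size."""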
+
+ def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
+ super().__init__()
+ if not isinstance(max_size, int):
+ raise TypeError(f"Size should be int. Got {type(max_size)}")
+ self.max_size = max_size
+ self.interpolation = interpolation
+        self.fn = min if fn == 'min' else max
+ self.fill = fill
+
+ def forward(self, img):
+ if isinstance(img, torch.Tensor):
+ height, width = img.shape[1:]
+ else:
+ width, height = img.size
+ scale = self.max_size / float(max(height, width))
+ if scale != 1.0:
+ new_size = tuple(round(dim * scale) for dim in (height, width))
+ img = F.resize(img, new_size, self.interpolation)
+ pad_h = self.max_size - new_size[0]
+ pad_w = self.max_size - new_size[1]
+ img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
+ return img
+
+
+def _convert_to_rgb_or_rgba(image):
+ if image.mode == 'RGBA':
+ return image
+ else:
+ return image.convert('RGB')
+
+# def transform_and_split(merged, transform_fn, normalize_fn):
+# transformed = transform_fn(merged)
+# crop_img, crop_label = torch.split(transformed, [3,1], dim=0)
+
+# # crop_img = _convert_to_rgb(crop_img)
+# crop_img = normalize_fn(ToTensor()(crop_img))
+# return crop_img, crop_label
+
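+# _convert_to_rgb_or_rgba keeps RGBA inputs intact so a mask can ride along as the alpha channel;
+# MaskAwareNormalize below normalizes only the first three (RGB) channels of 4-channel tensors.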
+class MaskAwareNormalize(nn.Module):
+ def __init__(self, mean, std):
+ super().__init__()
+ self.normalize = Normalize(mean=mean, std=std)
+
+ def forward(self, tensor):
+ if tensor.shape[0] == 4:
+ return torch.cat([self.normalize(tensor[:3]), tensor[3:]], dim=0)
+ else:
+ return self.normalize(tensor)
+
+def image_transform(
+ image_size: int,
+ is_train: bool,
+ mean: Optional[Tuple[float, ...]] = None,
+ std: Optional[Tuple[float, ...]] = None,
+ resize_longest_max: bool = False,
+ fill_color: int = 0,
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
+):
+ mean = mean or OPENAI_DATASET_MEAN
+ if not isinstance(mean, (list, tuple)):
+ mean = (mean,) * 3
+
+ std = std or OPENAI_DATASET_STD
+ if not isinstance(std, (list, tuple)):
+ std = (std,) * 3
+
+ if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
+ # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
+ image_size = image_size[0]
+
+ if isinstance(aug_cfg, dict):
+ aug_cfg = AugmentationCfg(**aug_cfg)
+ else:
+ aug_cfg = aug_cfg or AugmentationCfg()
+ normalize = MaskAwareNormalize(mean=mean, std=std)
+ if is_train:
+ aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
+ use_timm = aug_cfg_dict.pop('use_timm', False)
+ if use_timm:
+ assert False, "not tested for augmentation with mask"
+ from timm.data import create_transform # timm can still be optional
+ if isinstance(image_size, (tuple, list)):
+ assert len(image_size) >= 2
+ input_size = (3,) + image_size[-2:]
+ else:
+ input_size = (3, image_size, image_size)
+ # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
+ aug_cfg_dict.setdefault('interpolation', 'random')
+ aug_cfg_dict.setdefault('color_jitter', None) # disable by default
+ train_transform = create_transform(
+ input_size=input_size,
+ is_training=True,
+ hflip=0.,
+ mean=mean,
+ std=std,
+ re_mode='pixel',
+ **aug_cfg_dict,
+ )
+ else:
+ train_transform = Compose([
+ _convert_to_rgb_or_rgba,
+ ToTensor(),
+ RandomResizedCrop(
+ image_size,
+ scale=aug_cfg_dict.pop('scale'),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ normalize,
+ ])
+ if aug_cfg_dict:
+ warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
+ return train_transform
+ else:
+ transforms = [
+ _convert_to_rgb_or_rgba,
+ ToTensor(),
+ ]
+ if resize_longest_max:
+ transforms.extend([
+ ResizeMaxSize(image_size, fill=fill_color)
+ ])
+ else:
+ transforms.extend([
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
+ CenterCrop(image_size),
+ ])
+ transforms.extend([
+ normalize,
+ ])
+ return Compose(transforms)
+
+
+# def image_transform_region(
+# image_size: int,
+# is_train: bool,
+# mean: Optional[Tuple[float, ...]] = None,
+# std: Optional[Tuple[float, ...]] = None,
+# resize_longest_max: bool = False,
+# fill_color: int = 0,
+# aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
+# ):
+# mean = mean or OPENAI_DATASET_MEAN
+# if not isinstance(mean, (list, tuple)):
+# mean = (mean,) * 3
+
+# std = std or OPENAI_DATASET_STD
+# if not isinstance(std, (list, tuple)):
+# std = (std,) * 3
+
+# if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
+# # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
+# image_size = image_size[0]
+
+# if isinstance(aug_cfg, dict):
+# aug_cfg = AugmentationCfg(**aug_cfg)
+# else:
+# aug_cfg = aug_cfg or AugmentationCfg()
+# normalize = Normalize(mean=mean, std=std)
+# if is_train:
+# aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
+
+# transform = Compose([
+# RandomResizedCrop(
+# image_size,
+# scale=aug_cfg_dict.pop('scale'),
+# interpolation=InterpolationMode.BICUBIC,
+# ),
+# ])
+# train_transform = Compose([
+# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize)
+# ])
+# return train_transform
+# else:
+# if resize_longest_max:
+# transform = [
+# ResizeMaxSize(image_size, fill=fill_color)
+# ]
+# val_transform = Compose([
+# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
+# ])
+# else:
+# transform = [
+# Resize(image_size, interpolation=InterpolationMode.BICUBIC),
+# CenterCrop(image_size),
+# ]
+# val_transform = Compose([
+# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
+# ])
+# return val_transform
\ No newline at end of file
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py b/diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7465c1b20bf388a17e0f4f80f7b8eee3b564af92
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
@@ -0,0 +1,727 @@
+from collections import OrderedDict
+import math
+from typing import Callable, Optional, Sequence, Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint
+
+from .utils import to_2tuple
+
+
+class LayerNormFp32(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
+ return x.to(orig_type)
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ return x.to(orig_type)
+
+
+class QuickGELU(nn.Module):
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x):
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class PatchDropout(nn.Module):
+ """
+ https://arxiv.org/abs/2212.00794
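+    Randomly keeps only a (1 - prob) fraction of patch tokens during training (the CLS token can be
+    excluded from dropping via exclude_first_token); at eval time, or when prob == 0, it is a no-op.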
+ """
+
+ def __init__(self, prob, exclude_first_token=True):
+ super().__init__()
+ assert 0 <= prob < 1.
+ self.prob = prob
+ self.exclude_first_token = exclude_first_token # exclude CLS token
+
+ def forward(self, x):
+ if not self.training or self.prob == 0.:
+ return x
+
+ if self.exclude_first_token:
+ cls_tokens, x = x[:, :1], x[:, 1:]
+ else:
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
+
+ batch = x.size()[0]
+ num_tokens = x.size()[1]
+
+ batch_indices = torch.arange(batch)
+ batch_indices = batch_indices[..., None]
+
+ keep_prob = 1 - self.prob
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
+
+ rand = torch.randn(batch, num_tokens)
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
+
+ x = x[batch_indices, patch_indices_keep]
+
+ if self.exclude_first_token:
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=True,
+ scaled_cosine=False,
+ scale_heads=False,
+ logit_scale_max=math.log(1. / 0.01),
+ attn_drop=0.,
+ proj_drop=0.
+ ):
+ super().__init__()
+ self.scaled_cosine = scaled_cosine
+ self.scale_heads = scale_heads
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+ self.logit_scale_max = logit_scale_max
+
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
+ if qkv_bias:
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
+ else:
+ self.in_proj_bias = None
+
+ if self.scaled_cosine:
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
+ else:
+ self.logit_scale = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ if self.scale_heads:
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
+ else:
+ self.head_scale = None
+ self.out_proj = nn.Linear(dim, dim)
+ self.out_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
+ L, N, C = x.shape
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+
+ if self.logit_scale is not None:
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
+ attn = attn.view(-1, L, L)
+ else:
+ q = q * self.scale
+ attn = torch.bmm(q, k.transpose(-1, -2))
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
+ attn_mask = new_attn_mask
+ attn += attn_mask
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = torch.bmm(attn, v)
+ if self.head_scale is not None:
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
+ x = x.view(-1, L, C)
+ x = x.transpose(0, 1).reshape(L, N, C)
+ x = self.out_proj(x)
+ x = self.out_drop(x)
+ return x
+
+
+class AttentionalPooler(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ context_dim: int,
+ n_head: int = 8,
+ n_queries: int = 256,
+ norm_layer: Callable = LayerNorm
+ ):
+ super().__init__()
+ self.query = nn.Parameter(torch.randn(n_queries, d_model))
+ self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
+ self.ln_q = norm_layer(d_model)
+ self.ln_k = norm_layer(context_dim)
+
+ def forward(self, x: torch.Tensor):
+ x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
+ N = x.shape[1]
+ q = self.ln_q(self.query)
+ out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
+ return out.permute(1, 0, 2) # LND -> NLD
+
+ def _repeat(self, query, N: int):
+ return query.unsqueeze(1).repeat(1, N, 1)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ n_head: int,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ is_cross_attention: bool = False,
+ ):
+ super().__init__()
+
+ self.ln_1 = norm_layer(d_model)
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+ if is_cross_attention:
+ self.ln_1_kv = norm_layer(d_model)
+
+ self.ln_2 = norm_layer(d_model)
+ mlp_width = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, mlp_width)),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(mlp_width, d_model))
+ ]))
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+ def attention(
+ self,
+ q_x: torch.Tensor,
+ k_x: Optional[torch.Tensor] = None,
+ v_x: Optional[torch.Tensor] = None,
+ attn_mask: Optional[torch.Tensor] = None,
+ ):
+ k_x = k_x if k_x is not None else q_x
+ v_x = v_x if v_x is not None else q_x
+
+ attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
+ return self.attn(
+ q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
+ )[0]
+
+ def forward(
+ self,
+ q_x: torch.Tensor,
+ k_x: Optional[torch.Tensor] = None,
+ v_x: Optional[torch.Tensor] = None,
+ attn_mask: Optional[torch.Tensor] = None,
+ ):
+ k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
+ v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
+
+ x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
+ return x
+
+
+class CustomResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ n_head: int,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ scale_cosine_attn: bool = False,
+ scale_heads: bool = False,
+ scale_attn: bool = False,
+ scale_fc: bool = False,
+ ):
+ super().__init__()
+
+ self.ln_1 = norm_layer(d_model)
+ self.attn = Attention(
+ d_model, n_head,
+ scaled_cosine=scale_cosine_attn,
+ scale_heads=scale_heads,
+ )
+ self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+ self.ln_2 = norm_layer(d_model)
+ mlp_width = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, mlp_width)),
+ ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(mlp_width, d_model))
+ ]))
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ width: int,
+ layers: int,
+ heads: int,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ ):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.grad_checkpointing = False
+
+ self.resblocks = nn.ModuleList([
+ ResidualAttentionBlock(
+ width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
+ for _ in range(layers)
+ ])
+
+ def get_cast_dtype(self) -> torch.dtype:
+ return self.resblocks[0].mlp.c_fc.weight.dtype
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ for r in self.resblocks:
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
+ x = checkpoint(r, x, None, None, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+
+class VisionTransformer(nn.Module):
+ output_tokens: torch.jit.Final[bool]
+
+ def __init__(
+ self,
+ image_size: int,
+ patch_size: int,
+ width: int,
+ layers: int,
+ heads: int,
+ mlp_ratio: float,
+ ls_init_value: float = None,
+ global_average_pool: bool = False,
+ attentional_pool: bool = False,
+ n_queries: int = 256,
+ attn_pooler_heads: int = 8,
+ output_dim: int = 512,
+ patch_dropout: float = 0.,
+ input_patchnorm: bool = False,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ output_tokens: bool = False
+ ):
+ super().__init__()
+ self.output_tokens = output_tokens
+ image_height, image_width = self.image_size = to_2tuple(image_size)
+ patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
+ self.grid_size = (image_height // patch_height, image_width // patch_width)
+ self.output_dim = output_dim
+
+ # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
+ self.input_patchnorm = input_patchnorm
+
+ if input_patchnorm:
+ patch_input_dim = patch_height * patch_width * 3
+ self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
+ self.conv1 = nn.Linear(patch_input_dim, width)
+ else:
+ self.patchnorm_pre_ln = nn.Identity()
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ # class embeddings and positional embeddings
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
+
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
+
+ self.ln_pre = norm_layer(width)
+ self.transformer = Transformer(
+ width,
+ layers,
+ heads,
+ mlp_ratio,
+ ls_init_value=ls_init_value,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+
+ self.global_average_pool = global_average_pool
+ if attentional_pool:
+ self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
+ self.ln_post = norm_layer(output_dim)
+ self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
+ else:
+ self.attn_pool = None
+ self.ln_post = norm_layer(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ self.init_parameters()
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ for param in self.parameters():
+ param.requires_grad = False
+
+ if unlocked_groups != 0:
+ groups = [
+ [
+ self.conv1,
+ self.class_embedding,
+ self.positional_embedding,
+ self.ln_pre,
+ ],
+ *self.transformer.resblocks[:-1],
+ [
+ self.transformer.resblocks[-1],
+ self.ln_post,
+ ],
+ self.proj,
+ ]
+
+ def _unlock(x):
+ if isinstance(x, Sequence):
+ for g in x:
+ _unlock(g)
+ else:
+ if isinstance(x, torch.nn.Parameter):
+ x.requires_grad = True
+ else:
+ for p in x.parameters():
+ p.requires_grad = True
+
+ _unlock(groups[-unlocked_groups:])
+
+ def init_parameters(self):
+ # FIXME OpenAI CLIP did not define an init for the VisualTransformer
+ # TODO experiment if default PyTorch init, below, or alternate init is best.
+
+ # nn.init.normal_(self.class_embedding, std=self.scale)
+ # nn.init.normal_(self.positional_embedding, std=self.scale)
+ #
+ # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+ # attn_std = self.transformer.width ** -0.5
+ # fc_std = (2 * self.transformer.width) ** -0.5
+ # for block in self.transformer.resblocks:
+ # nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ # nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+ #
+ # if self.text_projection is not None:
+ # nn.init.normal_(self.text_projection, std=self.scale)
+ pass
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.transformer.grad_checkpointing = enable
+
+ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ if self.global_average_pool:
+ return x.mean(dim=1), x
+ else:
+ return x[:, 0], x[:, 1:]
+
+ def forward(self, x: torch.Tensor, skip_pool: bool = False):
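+        # skip_pool=True returns the full token sequence right after the transformer blocks,
+        # before attentional/global pooling, ln_post, and the output projection are applied.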
+
+ # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
+ if self.input_patchnorm:
+ # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
+ x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1])
+ x = x.permute(0, 2, 4, 1, 3, 5)
+ x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
+ x = self.patchnorm_pre_ln(x)
+ x = self.conv1(x)
+ else:
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+
+ # class embeddings and positional embeddings
+ x = torch.cat(
+ [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+ x = self.patch_dropout(x)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ if skip_pool:
+ return x
+
+ if self.attn_pool is not None:
+ x = self.attn_pool(x)
+ x = self.ln_post(x)
+ pooled, tokens = self._global_pool(x)
+ else:
+ pooled, tokens = self._global_pool(x)
+ pooled = self.ln_post(pooled)
+
+ if self.proj is not None:
+ pooled = pooled @ self.proj
+
+ if self.output_tokens:
+ return pooled, tokens
+
+ return pooled
+
+
+class TextTransformer(nn.Module):
+ output_tokens: torch.jit.Final[bool]
+
+ def __init__(
+ self,
+ context_length: int = 77,
+ vocab_size: int = 49408,
+ width: int = 512,
+ heads: int = 8,
+ layers: int = 12,
+ ls_init_value: float = None,
+ output_dim: int = 512,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ embed_cls: bool = False,
+ pad_id: int = 0,
+ output_tokens: bool = False,
+ ):
+ super().__init__()
+ self.output_tokens = output_tokens
+ self.num_pos = self.context_length = context_length
+ self.vocab_size = vocab_size
+ self.width = width
+ self.output_dim = output_dim
+ self.heads = heads
+ self.pad_id = pad_id
+
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
+
+ if embed_cls:
+ self.cls_emb = nn.Parameter(torch.empty(width))
+ self.num_pos += 1
+ else:
+ self.cls_emb = None
+
+ self.token_embedding = nn.Embedding(vocab_size, width)
+ self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
+ self.transformer = Transformer(
+ width=width,
+ layers=layers,
+ heads=heads,
+ ls_init_value=ls_init_value,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+ self.ln_final = norm_layer(width)
+
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
+
+ self.init_parameters()
+
+ def init_parameters(self):
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+ if self.cls_emb is not None:
+ nn.init.normal_(self.cls_emb, std=0.01)
+
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+ attn_std = self.transformer.width ** -0.5
+ fc_std = (2 * self.transformer.width) ** -0.5
+ for block in self.transformer.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.transformer.grad_checkpointing = enable
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.num_pos, self.num_pos)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ def build_cls_mask(self, text, cast_dtype: torch.dtype):
+ cls_mask = (text != self.pad_id).unsqueeze(1)
+ cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
+ additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
+ additive_mask.fill_(0)
+ additive_mask.masked_fill_(~cls_mask, float("-inf"))
+ additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
+ return additive_mask
+
+ def _repeat(self, t, N: int):
+ return t.reshape(1, 1, -1).repeat(N, 1, 1)
+
+ def forward(self, text):
+ cast_dtype = self.transformer.get_cast_dtype()
+ seq_len = text.shape[1]
+
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
+ attn_mask = self.attn_mask
+ if self.cls_emb is not None:
+ seq_len += 1
+ x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
+ cls_mask = self.build_cls_mask(text, cast_dtype)
+ attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
+
+ x = x + self.positional_embedding[:seq_len].to(cast_dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, attn_mask=attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ if self.cls_emb is not None:
+ pooled, tokens = x[:, -1], x[:, :-1]
+ pooled = self.ln_final(pooled)
+ else:
+ x = self.ln_final(x)
+ pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
+
+ if self.text_projection is not None:
+ pooled = pooled @ self.text_projection
+
+ if self.output_tokens:
+ return pooled, tokens
+
+ return pooled
+
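+# A minimal usage sketch for the text tower (default CLIP-style sizes):
+#   text_model = TextTransformer(context_length=77, vocab_size=49408, width=512, heads=8, layers=12)
+#   emb = text_model(torch.randint(0, 49408, (2, 77)))  # -> tensor of shape [2, 512]
+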
+
+class MultimodalTransformer(Transformer):
+ def __init__(
+ self,
+ width: int,
+ layers: int,
+ heads: int,
+ context_length: int = 77,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ output_dim: int = 512,
+ ):
+
+ super().__init__(
+ width=width,
+ layers=layers,
+ heads=heads,
+ mlp_ratio=mlp_ratio,
+ ls_init_value=ls_init_value,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+ self.context_length = context_length
+ self.cross_attn = nn.ModuleList([
+ ResidualAttentionBlock(
+ width,
+ heads,
+ mlp_ratio,
+ ls_init_value=ls_init_value,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ is_cross_attention=True,
+ )
+ for _ in range(layers)
+ ])
+
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
+
+ self.ln_final = norm_layer(width)
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
+
+ def init_parameters(self):
+        proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
+        attn_std = self.width ** -0.5
+        fc_std = (2 * self.width) ** -0.5
+        for block in self.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+        for block in self.cross_attn:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ def forward(self, image_embs, text_embs):
+        text_embs = text_embs.permute(1, 0, 2)  # NLD -> LND
+ image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
+ seq_len = text_embs.shape[0]
+
+ for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
+ text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
+ text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
+ else:
+ text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
+ text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
+
+ x = text_embs.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x)
+
+ if self.text_projection is not None:
+ x = x @ self.text_projection
+
+ return x
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.grad_checkpointing = enable
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/utils.py b/diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51e80c5e296b24cae130ab0459baf268e0db7673
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
@@ -0,0 +1,60 @@
+from itertools import repeat
+import collections.abc
+
+from torch import nn as nn
+from torchvision.ops.misc import FrozenBatchNorm2d
+
+
+def freeze_batch_norm_2d(module, module_match={}, name=''):
+ """
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
+
+ Args:
+ module (torch.nn.Module): Any PyTorch module.
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
+ name (str): Full module name (prefix)
+
+ Returns:
+ torch.nn.Module: Resulting module
+
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
+ """
+ res = module
+ is_match = True
+ if module_match:
+ is_match = name in module_match
+ if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
+ res = FrozenBatchNorm2d(module.num_features)
+ res.num_features = module.num_features
+ res.affine = module.affine
+ if module.affine:
+ res.weight.data = module.weight.data.clone().detach()
+ res.bias.data = module.bias.data.clone().detach()
+ res.running_mean.data = module.running_mean.data
+ res.running_var.data = module.running_var.data
+ res.eps = module.eps
+ else:
+ for child_name, child in module.named_children():
+ full_child_name = '.'.join([name, child_name]) if name else child_name
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
+ if new_child is not child:
+ res.add_module(child_name, new_child)
+ return res
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = lambda n, x: _ntuple(n)(x)
diff --git a/diffsynth/extensions/ImageQualityMetric/open_clip/version.py b/diffsynth/extensions/ImageQualityMetric/open_clip/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..48aa744fb053599044caf0253b889b5cfe5b78e7
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/open_clip/version.py
@@ -0,0 +1 @@
+__version__ = '2.16.0'
diff --git a/diffsynth/extensions/ImageQualityMetric/pickscore.py b/diffsynth/extensions/ImageQualityMetric/pickscore.py
new file mode 100644
index 0000000000000000000000000000000000000000..7370e099724997d98f1c4ad3fc5f14c861202665
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/pickscore.py
@@ -0,0 +1,112 @@
+import torch
+from PIL import Image
+from transformers import AutoProcessor, AutoModel
+from typing import List, Union
+import os
+from .config import MODEL_PATHS
+
+class PickScore(torch.nn.Module):
+    def __init__(self, device: Union[str, torch.device], path: dict = MODEL_PATHS):
+        """Initialize the PickScore selector with a CLIP processor and the PickScore reward model.
+
+        Args:
+            device (Union[str, torch.device]): The device to load the model on.
+            path (dict): Mapping with "clip" (processor) and "pickscore" (model) locations.
+        """
+        super().__init__()
+        self.device = device if isinstance(device, torch.device) else torch.device(device)
+        processor_name_or_path = path.get("clip")
+        model_pretrained_name_or_path = path.get("pickscore")
+        self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
+        self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)
+
+ def _calculate_score(self, image: torch.Tensor, prompt: str, softmax: bool = False) -> float:
+ """Calculate the score for a single image and prompt.
+
+ Args:
+ image (torch.Tensor): The processed image tensor.
+ prompt (str): The prompt text.
+ softmax (bool): Whether to apply softmax to the scores.
+
+ Returns:
+ float: The score for the image.
+ """
+ with torch.no_grad():
+ # Prepare text inputs
+ text_inputs = self.processor(
+ text=prompt,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ ).to(self.device)
+
+ # Embed images and text
+ image_embs = self.model.get_image_features(pixel_values=image)
+ image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
+ text_embs = self.model.get_text_features(**text_inputs)
+ text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
+
+ # Compute score
+ score = (text_embs @ image_embs.T)[0]
+ if softmax:
+ # Apply logit scale and softmax
+ score = torch.softmax(self.model.logit_scale.exp() * score, dim=-1)
+
+ return score.cpu().item()
+
+ @torch.no_grad()
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
+ """Score the images based on the prompt.
+
+ Args:
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+ prompt (str): The prompt text.
+ softmax (bool): Whether to apply softmax to the scores.
+
+ Returns:
+ List[float]: List of scores for the images.
+ """
+ try:
+ if isinstance(images, (str, Image.Image)):
+ # Single image
+ if isinstance(images, str):
+ pil_image = Image.open(images)
+ else:
+ pil_image = images
+
+ # Prepare image inputs
+ image_inputs = self.processor(
+ images=pil_image,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ ).to(self.device)
+
+ return [self._calculate_score(image_inputs["pixel_values"], prompt, softmax)]
+ elif isinstance(images, list):
+ # Multiple images
+ scores = []
+ for one_image in images:
+ if isinstance(one_image, str):
+ pil_image = Image.open(one_image)
+ elif isinstance(one_image, Image.Image):
+ pil_image = one_image
+ else:
+ raise TypeError("Unsupported image type: expected str or PIL.Image.Image.")
+
+ # Prepare image inputs
+ image_inputs = self.processor(
+ images=pil_image,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ ).to(self.device)
+
+ scores.append(self._calculate_score(image_inputs["pixel_values"], prompt, softmax))
+ return scores
+ else:
+ raise TypeError("Unsupported type for `images`: expected str, PIL.Image.Image, or a list of them.")
+ except Exception as e:
+ raise RuntimeError(f"Error in scoring images: {e}") from e
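+
+# Example usage (sketch; assumes MODEL_PATHS in .config points to downloaded CLIP and PickScore weights):
+#   scorer = PickScore(device="cuda")
+#   scores = scorer.score(["a.jpg", "b.jpg"], prompt="a photo of a cat", softmax=True)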
diff --git a/diffsynth/extensions/ImageQualityMetric/trainer/__init__.py b/diffsynth/extensions/ImageQualityMetric/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf4f59d6c0977e578ab67ec92c916c7e38842715
--- /dev/null
+++ b/diffsynth/extensions/ImageQualityMetric/trainer/__init__.py
@@ -0,0 +1 @@
+from .models import *
\ No newline at end of file
diff --git a/diffsynth/extensions/RIFE/__init__.py b/diffsynth/extensions/RIFE/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e76c391f0b085b3628592990a868ac09f37cced7
--- /dev/null
+++ b/diffsynth/extensions/RIFE/__init__.py
@@ -0,0 +1,242 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from PIL import Image
+
+
+def warp(tenInput, tenFlow, device):
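+ # Backward-warp `tenInput` with the optical flow `tenFlow`: build a normalized sampling grid,
+ # shift it by the normalized flow, and resample with grid_sample. Note that `backwarp_tenGrid`
+ # is local to this function, so the grid is rebuilt on every call rather than cached.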
+ backwarp_tenGrid = {}
+ k = (str(tenFlow.device), str(tenFlow.size()))
+ if k not in backwarp_tenGrid:
+ tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
+ 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
+ tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
+ 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
+ backwarp_tenGrid[k] = torch.cat(
+ [tenHorizontal, tenVertical], 1).to(device)
+
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
+ tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
+
+ g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
+ return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
+
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+ return nn.Sequential(
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation, bias=True),
+ nn.PReLU(out_planes)
+ )
+
+
+class IFBlock(nn.Module):
+ def __init__(self, in_planes, c=64):
+ super(IFBlock, self).__init__()
+ self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
+ self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
+ self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
+ self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
+ self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
+ self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
+ self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
+
+ def forward(self, x, flow, scale=1):
+ x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+ flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
+ feat = self.conv0(torch.cat((x, flow), 1))
+ feat = self.convblock0(feat) + feat
+ feat = self.convblock1(feat) + feat
+ feat = self.convblock2(feat) + feat
+ feat = self.convblock3(feat) + feat
+ flow = self.conv1(feat)
+ mask = self.conv2(feat)
+ flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
+ mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+ return flow, mask
+
+
+class IFNet(nn.Module):
+ def __init__(self, **kwargs):
+ super(IFNet, self).__init__()
+ self.block0 = IFBlock(7+4, c=90)
+ self.block1 = IFBlock(7+4, c=90)
+ self.block2 = IFBlock(7+4, c=90)
+ self.block_tea = IFBlock(10+4, c=90)
+
+ def forward(self, x, scale_list=[4, 2, 1], training=False):
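+ # `x` is img0 and img1 concatenated along the channel dimension; the flow and the fusion mask
+ # are refined coarse-to-fine by the three IFBlocks at the resolutions given in `scale_list`.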
+ if not training:
+ channel = x.shape[1] // 2
+ img0 = x[:, :channel]
+ img1 = x[:, channel:]
+ flow_list = []
+ merged = []
+ mask_list = []
+ warped_img0 = img0
+ warped_img1 = img1
+ flow = (x[:, :4]).detach() * 0
+ mask = (x[:, :1]).detach() * 0
+ block = [self.block0, self.block1, self.block2]
+ for i in range(3):
+ f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
+ f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
+ flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
+ mask = mask + (m0 + (-m1)) / 2
+ mask_list.append(mask)
+ flow_list.append(flow)
+ warped_img0 = warp(img0, flow[:, :2], device=x.device)
+ warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
+ merged.append((warped_img0, warped_img1))
+ '''
+ c0 = self.contextnet(img0, flow[:, :2])
+ c1 = self.contextnet(img1, flow[:, 2:4])
+ tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
+ res = tmp[:, 1:4] * 2 - 1
+ '''
+ for i in range(3):
+ mask_list[i] = torch.sigmoid(mask_list[i])
+ merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
+ return flow_list, mask_list[2], merged
+
+ @staticmethod
+ def state_dict_converter():
+ return IFNetStateDictConverter()
+
+
+class IFNetStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict), {"upcast_to_float32": True}
+
+
+class RIFEInterpolater:
+ def __init__(self, model, device="cuda"):
+ self.model = model
+ self.device = device
+ # IFNet does not support float16
+ self.torch_dtype = torch.float32
+
+ @staticmethod
+ def from_model_manager(model_manager):
+ return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
+
+ def process_image(self, image):
+ width, height = image.size
+ if width % 32 != 0 or height % 32 != 0:
+ width = (width + 31) // 32 * 32
+ height = (height + 31) // 32 * 32
+ image = image.resize((width, height))
+ image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
+ return image
+
+ def process_images(self, images):
+ images = [self.process_image(image) for image in images]
+ images = torch.stack(images)
+ return images
+
+ def decode_images(self, images):
+ images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
+ images = [Image.fromarray(image) for image in images]
+ return images
+
+ def add_interpolated_images(self, images, interpolated_images):
+ output_images = []
+ for image, interpolated_image in zip(images, interpolated_images):
+ output_images.append(image)
+ output_images.append(interpolated_image)
+ output_images.append(images[-1])
+ return output_images
+
+
+ @torch.no_grad()
+ def interpolate_(self, images, scale=1.0):
+ input_tensor = self.process_images(images)
+ input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
+ input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
+ flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
+ output_images = self.decode_images(merged[2].cpu())
+ if output_images[0].size != images[0].size:
+ output_images = [image.resize(images[0].size) for image in output_images]
+ return output_images
+
+
+ @torch.no_grad()
+ def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
+ # Preprocess
+ processed_images = self.process_images(images)
+
+ for iter in range(num_iter):
+ # Input
+ input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
+
+ # Interpolate
+ output_tensor = []
+ for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
+ output_tensor.append(merged[2].cpu())
+
+ # Output
+ output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
+ processed_images = self.add_interpolated_images(processed_images, output_tensor)
+ processed_images = torch.stack(processed_images)
+
+ # To images
+ output_images = self.decode_images(processed_images)
+ if output_images[0].size != images[0].size:
+ output_images = [image.resize(images[0].size) for image in output_images]
+ return output_images
+
+
+class RIFESmoother(RIFEInterpolater):
+ def __init__(self, model, device="cuda"):
+ super(RIFESmoother, self).__init__(model, device=device)
+
+ @staticmethod
+ def from_model_manager(model_manager):
+ return RIFESmoother(model_manager.fetch_model("rife"), device=model_manager.device)
+
+ def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
+ output_tensor = []
+ for batch_id in range(0, input_tensor.shape[0], batch_size):
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
+ output_tensor.append(merged[2].cpu())
+ output_tensor = torch.concat(output_tensor, dim=0)
+ return output_tensor
+
+ @torch.no_grad()
+ def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
+ # Preprocess
+ processed_images = self.process_images(rendered_frames)
+
+ for iter in range(num_iter):
+ # Input
+ input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
+
+ # Interpolate
+ output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
+
+ # Blend
+ input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
+ output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
+
+ # Add to frames
+ processed_images[1:-1] = output_tensor
+
+ # To images
+ output_images = self.decode_images(processed_images)
+ if output_images[0].size != rendered_frames[0].size:
+ output_images = [image.resize(rendered_frames[0].size) for image in output_images]
+ return output_images
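+
+# Example usage (sketch; assumes a ModelManager with the RIFE weights already loaded):
+#   interpolater = RIFEInterpolater.from_model_manager(model_manager)
+#   frames_2x = interpolater.interpolate(frames, num_iter=1)  # roughly doubles the frame count
+#   smoother = RIFESmoother.from_model_manager(model_manager)
+#   smoothed_frames = smoother(rendered_frames)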
diff --git a/diffsynth/extensions/__init__.py b/diffsynth/extensions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/diffsynth/lora/__init__.py b/diffsynth/lora/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..33bd89c21b74ec0bba8b7b69468a3901e45360fd
--- /dev/null
+++ b/diffsynth/lora/__init__.py
@@ -0,0 +1,45 @@
+import torch
+
+
+
+class GeneralLoRALoader:
+ def __init__(self, device="cpu", torch_dtype=torch.float32):
+ self.device = device
+ self.torch_dtype = torch_dtype
+
+
+ def get_name_dict(self, lora_state_dict):
+ lora_name_dict = {}
+ for key in lora_state_dict:
+ if ".lora_B." not in key:
+ continue
+ keys = key.split(".")
+ if len(keys) > keys.index("lora_B") + 2:
+ keys.pop(keys.index("lora_B") + 1)
+ keys.pop(keys.index("lora_B"))
+ if keys[0] == "diffusion_model":
+ keys.pop(0)
+ keys.pop(-1)
+ target_name = ".".join(keys)
+ lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
+ return lora_name_dict
+
+
+ def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
+ updated_num = 0
+ lora_name_dict = self.get_name_dict(state_dict_lora)
+ for name, module in model.named_modules():
+ if name in lora_name_dict:
+ weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
+ weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
+ if len(weight_up.shape) == 4:
+ weight_up = weight_up.squeeze(3).squeeze(2)
+ weight_down = weight_down.squeeze(3).squeeze(2)
+ weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+ else:
+ weight_lora = alpha * torch.mm(weight_up, weight_down)
+ state_dict = module.state_dict()
+ state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
+ module.load_state_dict(state_dict)
+ updated_num += 1
+ print(f"{updated_num} tensors are updated by LoRA.")
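+
+# Example usage (sketch; `lora_state_dict` is a LoRA checkpoint loaded elsewhere):
+#   loader = GeneralLoRALoader(device="cuda", torch_dtype=torch.bfloat16)
+#   loader.load(model, lora_state_dict, alpha=1.0)  # adds alpha * (B @ A) onto every matched weight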
diff --git a/diffsynth/lora/flux_lora.py b/diffsynth/lora/flux_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb53b7342221dd749c9a704a09f6669f08f684b7
--- /dev/null
+++ b/diffsynth/lora/flux_lora.py
@@ -0,0 +1,324 @@
+import torch, math
+from . import GeneralLoRALoader
+from ..utils import ModelConfig
+from ..models.utils import load_state_dict
+from typing import Union
+
+
+class FluxLoRALoader(GeneralLoRALoader):
+ def __init__(self, device="cpu", torch_dtype=torch.float32):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+
+ self.diffusers_rename_dict = {
+ "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight",
+ "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight",
+ "transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight",
+ "transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight",
+ "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight",
+ "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight",
+ "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight",
+ "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight",
+ "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
+ "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
+ "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight",
+ "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight",
+ "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight",
+ "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight",
+ }
+
+ self.civitai_rename_dict = {
+ "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
+ "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
+ "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
+ "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
+ "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
+ "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
+ "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
+ }
+
+ def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
+ super().load(model, state_dict_lora, alpha)
+
+
+ def convert_state_dict(self,state_dict):
+
+ def guess_block_id(name,model_resource):
+ if model_resource == 'civitai':
+ names = name.split("_")
+ for i in names:
+ if i.isdigit():
+ return i, name.replace(f"_{i}_", "_blockid_")
+ if model_resource == 'diffusers':
+ names = name.split(".")
+ for i in names:
+ if i.isdigit():
+ return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
+ return None, None
+
+ def guess_resource(state_dict):
+ for k in state_dict:
+ if "lora_unet_" in k:
+ return 'civitai'
+ elif k.startswith("transformer."):
+ return 'diffusers'
+ return None
+
+ model_resource = guess_resource(state_dict)
+ if model_resource is None:
+ return state_dict
+
+ rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
+ def guess_alpha(state_dict):
+ for name, param in state_dict.items():
+ if ".alpha" in name:
+ for suffix in [".lora_down.weight", ".lora_A.weight"]:
+ name_ = name.replace(".alpha", suffix)
+ if name_ in state_dict:
+ lora_alpha = param.item() / state_dict[name_].shape[0]
+ lora_alpha = math.sqrt(lora_alpha)
+ return lora_alpha
+
+ return 1
+
+ alpha = guess_alpha(state_dict)
+
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ block_id, source_name = guess_block_id(name,model_resource)
+ if alpha != 1:
+ param *= alpha
+ if source_name in rename_dict:
+ target_name = rename_dict[source_name]
+ target_name = target_name.replace(".blockid.", f".{block_id}.")
+ state_dict_[target_name] = param
+ else:
+ state_dict_[name] = param
+
+ if model_resource == 'diffusers':
+ for name in list(state_dict_.keys()):
+ if "single_blocks." in name and ".a_to_q." in name:
+ mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
+ if mlp is None:
+ dim = 4
+ if 'lora_A' in name:
+ dim = 1
+ mlp = torch.zeros(dim * state_dict_[name].shape[0],
+ *state_dict_[name].shape[1:],
+ dtype=state_dict_[name].dtype)
+ else:
+ state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
+ if 'lora_A' in name:
+ param = torch.concat([
+ state_dict_.pop(name),
+ state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
+ state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
+ mlp,
+ ], dim=0)
+ elif 'lora_B' in name:
+ d, r = state_dict_[name].shape
+ param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
+ param[:d, :r] = state_dict_.pop(name)
+ param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
+ param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
+ param[3*d:, 3*r:] = mlp
+ else:
+ param = torch.concat([
+ state_dict_.pop(name),
+ state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
+ state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
+ mlp,
+ ], dim=0)
+ name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
+ state_dict_[name_] = param
+ for name in list(state_dict_.keys()):
+ for component in ["a", "b"]:
+ if f".{component}_to_q." in name:
+ name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
+ if 'lora_A' in name:
+ param = torch.concat([
+ state_dict_[name],
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
+ ], dim=0)
+ elif 'lora_B' in name:
+ origin = state_dict_[name]
+ d, r = origin.shape
+ param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
+ param[:d, :r] = state_dict_[name]
+ param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
+ param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
+ else:
+ param = torch.concat([
+ state_dict_[name],
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
+ ], dim=0)
+ state_dict_[name_] = param
+ state_dict_.pop(name)
+ state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
+ state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
+ return state_dict_
+
+
+class LoraMerger(torch.nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
+ self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
+ self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
+ self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
+ self.bias = torch.nn.Parameter(torch.randn((dim,)))
+ self.activation = torch.nn.Sigmoid()
+ self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
+ self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
+
+ def forward(self, base_output, lora_outputs):
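+ # Compute a per-channel sigmoid gate from the normalized base output, the normalized LoRA
+ # outputs, and their interaction, then add the gated (and summed) LoRA outputs back onto the base output.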
+ norm_base_output = self.norm_base(base_output)
+ norm_lora_outputs = self.norm_lora(lora_outputs)
+ gate = self.activation(
+ norm_base_output * self.weight_base \
+ + norm_lora_outputs * self.weight_lora \
+ + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
+ )
+ output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
+ return output
+
+
+class FluxLoraPatcher(torch.nn.Module):
+ def __init__(self, lora_patterns=None):
+ super().__init__()
+ if lora_patterns is None:
+ lora_patterns = self.default_lora_patterns()
+ model_dict = {}
+ for lora_pattern in lora_patterns:
+ name, dim = lora_pattern["name"], lora_pattern["dim"]
+ model_dict[name.replace(".", "___")] = LoraMerger(dim)
+ self.model_dict = torch.nn.ModuleDict(model_dict)
+
+ def default_lora_patterns(self):
+ lora_patterns = []
+ lora_dict = {
+ "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
+ "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
+ }
+ for i in range(19):
+ for suffix in lora_dict:
+ lora_patterns.append({
+ "name": f"blocks.{i}.{suffix}",
+ "dim": lora_dict[suffix]
+ })
+ lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
+ for i in range(38):
+ for suffix in lora_dict:
+ lora_patterns.append({
+ "name": f"single_blocks.{i}.{suffix}",
+ "dim": lora_dict[suffix]
+ })
+ return lora_patterns
+
+ def forward(self, base_output, lora_outputs, name):
+ return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
+
+ @staticmethod
+ def state_dict_converter():
+ return FluxLoraPatcherStateDictConverter()
+
+
+class FluxLoraPatcherStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_civitai(self, state_dict):
+ return state_dict
+
+
+class FluxLoRAFuser:
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+ self.device = device
+ self.torch_dtype = torch_dtype
+
+ def Matrix_Decomposition_lowrank(self, A, k):
+ U, S, V = torch.svd_lowrank(A.float(), q=k)
+ S_k = torch.diag(S[:k])
+ U_hat = U @ S_k
+ return U_hat, V.t()
+
+ def LoRA_State_Dicts_Decomposition(self, lora_state_dicts=[], q=4):
+ lora_1 = lora_state_dicts[0]
+ state_dict_ = {}
+ for k,v in lora_1.items():
+ if 'lora_A.' in k:
+ lora_B_name = k.replace('lora_A.', 'lora_B.')
+ lora_B = lora_1[lora_B_name]
+ weight = torch.mm(lora_B, v)
+ for lora_dict in lora_state_dicts[1:]:
+ lora_A_ = lora_dict[k]
+ lora_B_ = lora_dict[lora_B_name]
+ weight_ = torch.mm(lora_B_, lora_A_)
+ weight += weight_
+ new_B, new_A = self.Matrix_Decomposition_lowrank(weight, q)
+ state_dict_[lora_B_name] = new_B.to(dtype=torch.bfloat16)
+ state_dict_[k] = new_A.to(dtype=torch.bfloat16)
+ return state_dict_
+
+ def __call__(self, lora_configs: list[Union[ModelConfig, str]]):
+ loras = []
+ loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
+ for lora_config in lora_configs:
+ if isinstance(lora_config, str):
+ lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
+ else:
+ lora_config.download_if_necessary()
+ lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
+ lora = loader.convert_state_dict(lora)
+ loras.append(lora)
+ lora = self.LoRA_State_Dicts_Decomposition(loras)
+ return lora
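+
+# Example usage (sketch; the file names are placeholders): fuse several FLUX LoRAs into one
+# low-rank LoRA by applying a truncated SVD to the sum of their delta weights.
+#   fuser = FluxLoRAFuser(device="cuda", torch_dtype=torch.bfloat16)
+#   fused_lora = fuser(["lora_a.safetensors", "lora_b.safetensors"])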
diff --git a/diffsynth/pipelines/__init__.py b/diffsynth/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2ad5516a01787b9e2ce5ba54228466dd7b57d8e
--- /dev/null
+++ b/diffsynth/pipelines/__init__.py
@@ -0,0 +1,15 @@
+from .sd_image import SDImagePipeline
+from .sd_video import SDVideoPipeline
+from .sdxl_image import SDXLImagePipeline
+from .sdxl_video import SDXLVideoPipeline
+from .sd3_image import SD3ImagePipeline
+from .hunyuan_image import HunyuanDiTImagePipeline
+from .svd_video import SVDVideoPipeline
+from .flux_image import FluxImagePipeline
+from .cog_video import CogVideoPipeline
+from .omnigen_image import OmnigenImagePipeline
+from .pipeline_runner import SDVideoPipelineRunner
+from .hunyuan_video import HunyuanVideoPipeline
+from .step_video import StepVideoPipeline
+from .wan_video import WanVideoPipeline
+KolorsImagePipeline = SDXLImagePipeline
diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a4f01cff55dc0fcca02dc5234227bd65efc7434
--- /dev/null
+++ b/diffsynth/pipelines/base.py
@@ -0,0 +1,127 @@
+import torch
+import numpy as np
+from PIL import Image
+from torchvision.transforms import GaussianBlur
+
+
+
+class BasePipeline(torch.nn.Module):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
+ super().__init__()
+ self.device = device
+ self.torch_dtype = torch_dtype
+ self.height_division_factor = height_division_factor
+ self.width_division_factor = width_division_factor
+ self.cpu_offload = False
+ self.model_names = []
+
+
+ def check_resize_height_width(self, height, width):
+ if height % self.height_division_factor != 0:
+ height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
+ print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
+ if width % self.width_division_factor != 0:
+ width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
+ print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
+ return height, width
+
+
+ def preprocess_image(self, image):
+ image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
+ return image
+
+
+ def preprocess_images(self, images):
+ return [self.preprocess_image(image) for image in images]
+
+
+ def vae_output_to_image(self, vae_output):
+ image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
+ image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
+ return image
+
+
+ def vae_output_to_video(self, vae_output):
+ video = vae_output.cpu().permute(1, 2, 0).numpy()
+ video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
+ return video
+
+
+ def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
+ if len(latents) > 0:
+ blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
+ height, width = value.shape[-2:]
+ weight = torch.ones_like(value)
+ for latent, mask, scale in zip(latents, masks, scales):
+ mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
+ mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
+ mask = blur(mask)
+ value += latent * mask * scale
+ weight += mask * scale
+ value /= weight
+ return value
+
+
+ def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
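+ # Predict noise once with the global prompt and once per local prompt, then blend the
+ # predictions with the blurred, resized masks via merge_latents.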
+ if special_kwargs is None:
+ noise_pred_global = inference_callback(prompt_emb_global)
+ else:
+ noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
+ if special_local_kwargs_list is None:
+ noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
+ else:
+ noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
+ noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
+ return noise_pred
+
+
+ def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
+ local_prompts = local_prompts or []
+ masks = masks or []
+ mask_scales = mask_scales or []
+ extended_prompt_dict = self.prompter.extend_prompt(prompt)
+ prompt = extended_prompt_dict.get("prompt", prompt)
+ local_prompts += extended_prompt_dict.get("prompts", [])
+ masks += extended_prompt_dict.get("masks", [])
+ mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
+ return prompt, local_prompts, masks, mask_scales
+
+
+ def enable_cpu_offload(self):
+ self.cpu_offload = True
+
+
+ def load_models_to_device(self, loadmodel_names=[]):
+ # only load models to device if cpu_offload is enabled
+ if not self.cpu_offload:
+ return
+ # offload the unneeded models to cpu
+ for model_name in self.model_names:
+ if model_name not in loadmodel_names:
+ model = getattr(self, model_name)
+ if model is not None:
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
+ for module in model.modules():
+ if hasattr(module, "offload"):
+ module.offload()
+ else:
+ model.cpu()
+ # load the needed models to device
+ for model_name in loadmodel_names:
+ model = getattr(self, model_name)
+ if model is not None:
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
+ for module in model.modules():
+ if hasattr(module, "onload"):
+ module.onload()
+ else:
+ model.to(self.device)
+ # free the cached CUDA memory
+ torch.cuda.empty_cache()
+
+
+ def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
+ generator = None if seed is None else torch.Generator(device).manual_seed(seed)
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ return noise
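+
+# Example (sketch; `pipe` is any subclassed pipeline): with the default division factor of 64,
+# pipe.check_resize_height_width(500, 700) returns (512, 704), and
+# pipe.generate_noise((1, 4, 64, 64), seed=0) yields a reproducible latent initialization.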
diff --git a/diffsynth/pipelines/cog_video.py b/diffsynth/pipelines/cog_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..f42d295187e718617cc7d4e327067700f2a689fd
--- /dev/null
+++ b/diffsynth/pipelines/cog_video.py
@@ -0,0 +1,135 @@
+from ..models import ModelManager, FluxTextEncoder2, CogDiT, CogVAEEncoder, CogVAEDecoder
+from ..prompters import CogPrompter
+from ..schedulers import EnhancedDDIMScheduler
+from .base import BasePipeline
+import torch
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+from einops import rearrange
+
+
+
+class CogVideoPipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
+ self.scheduler = EnhancedDDIMScheduler(rescale_zero_terminal_snr=True, prediction_type="v_prediction")
+ self.prompter = CogPrompter()
+ # models
+ self.text_encoder: FluxTextEncoder2 = None
+ self.dit: CogDiT = None
+ self.vae_encoder: CogVAEEncoder = None
+ self.vae_decoder: CogVAEDecoder = None
+
+
+ def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
+ self.text_encoder = model_manager.fetch_model("flux_text_encoder_2")
+ self.dit = model_manager.fetch_model("cog_dit")
+ self.vae_encoder = model_manager.fetch_model("cog_vae_encoder")
+ self.vae_decoder = model_manager.fetch_model("cog_vae_decoder")
+ self.prompter.fetch_models(self.text_encoder)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
+ pipe = CogVideoPipeline(
+ device=model_manager.device,
+ torch_dtype=model_manager.torch_dtype
+ )
+ pipe.fetch_models(model_manager, prompt_refiner_classes)
+ return pipe
+
+
+ def tensor2video(self, frames):
+ frames = rearrange(frames, "C T H W -> T H W C")
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+
+ def encode_prompt(self, prompt, positive=True):
+ prompt_emb = self.prompter.encode_prompt(prompt, device=self.device, positive=positive)
+ return {"prompt_emb": prompt_emb}
+
+
+ def prepare_extra_input(self, latents):
+ return {"image_rotary_emb": self.dit.prepare_rotary_positional_embeddings(latents.shape[3], latents.shape[4], latents.shape[2], device=self.device)}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ input_video=None,
+ cfg_scale=7.0,
+ denoising_strength=1.0,
+ num_frames=49,
+ height=480,
+ width=720,
+ num_inference_steps=20,
+ tiled=False,
+ tile_size=(60, 90),
+ tile_stride=(30, 45),
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
+
+ # Prepare latent tensors
+ noise = self.generate_noise((1, 16, num_frames // 4 + 1, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
+
+ if denoising_strength == 1.0:
+ latents = noise.clone()
+ else:
+ input_video = self.preprocess_images(input_video)
+ input_video = torch.stack(input_video, dim=2)
+ latents = self.vae_encoder.encode_video(input_video, **tiler_kwargs, progress_bar=progress_bar_cmd).to(dtype=self.torch_dtype)
+ latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
+ if not tiled: latents = latents.to(self.device)
+
+ # Encode prompt
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+ if cfg_scale != 1.0:
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
+
+ # Extra input
+ extra_input = self.prepare_extra_input(latents)
+
+ # Denoise
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Classifier-free guidance
+ noise_pred_posi = self.dit(
+ latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
+ )
+ if cfg_scale != 1.0:
+ noise_pred_nega = self.dit(
+ latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # DDIM
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ # Update progress bar
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode video
+ video = self.vae_decoder.decode_video(latents.to("cpu"), **tiler_kwargs, progress_bar=progress_bar_cmd)
+ video = self.tensor2video(video[0])
+
+ return video
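+
+# Example usage (sketch; assumes a ModelManager with the CogVideo models already loaded):
+#   pipe = CogVideoPipeline.from_model_manager(model_manager)
+#   video = pipe(prompt="a corgi running on the beach", num_frames=49, seed=0)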
diff --git a/diffsynth/pipelines/dancer.py b/diffsynth/pipelines/dancer.py
new file mode 100644
index 0000000000000000000000000000000000000000..593b57c8363f94e312debf7c7f69bf6decdb7dbd
--- /dev/null
+++ b/diffsynth/pipelines/dancer.py
@@ -0,0 +1,236 @@
+import torch
+from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel
+from ..models.sd_unet import PushBlock, PopBlock
+from ..controlnets import MultiControlNetManager
+
+
+def lets_dance(
+ unet: SDUNet,
+ motion_modules: SDMotionModel = None,
+ controlnet: MultiControlNetManager = None,
+ sample = None,
+ timestep = None,
+ encoder_hidden_states = None,
+ ipadapter_kwargs_list = {},
+ controlnet_frames = None,
+ unet_batch_size = 1,
+ controlnet_batch_size = 1,
+ cross_frame_attention = False,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ device = "cuda",
+ vram_limit_level = 0,
+):
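+ # vram_limit_level >= 1 keeps the ControlNet residuals and the UNet skip activations (res_stack)
+ # on the CPU until they are consumed, trading speed for lower VRAM usage.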
+ # 0. Text embedding alignment (only for video processing)
+ if encoder_hidden_states.shape[0] != sample.shape[0]:
+ encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
+
+ # 1. ControlNet
+ # This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
+ # I leave it here because I intend to do something interesting with the ControlNets.
+ controlnet_insert_block_id = 30
+ if controlnet is not None and controlnet_frames is not None:
+ res_stacks = []
+ # process controlnet frames with batch
+ for batch_id in range(0, sample.shape[0], controlnet_batch_size):
+ batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
+ res_stack = controlnet(
+ sample[batch_id: batch_id_],
+ timestep,
+ encoder_hidden_states[batch_id: batch_id_],
+ controlnet_frames[:, batch_id: batch_id_],
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
+ )
+ if vram_limit_level >= 1:
+ res_stack = [res.cpu() for res in res_stack]
+ res_stacks.append(res_stack)
+ # concat the residual
+ additional_res_stack = []
+ for i in range(len(res_stacks[0])):
+ res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
+ additional_res_stack.append(res)
+ else:
+ additional_res_stack = None
+
+ # 2. time
+ time_emb = unet.time_proj(timestep).to(sample.dtype)
+ time_emb = unet.time_embedding(time_emb)
+
+ # 3. pre-process
+ height, width = sample.shape[2], sample.shape[3]
+ hidden_states = unet.conv_in(sample)
+ text_emb = encoder_hidden_states
+ res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states]
+
+ # 4. blocks
+ for block_id, block in enumerate(unet.blocks):
+ # 4.1 UNet
+ if isinstance(block, PushBlock):
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+ if vram_limit_level>=1:
+ res_stack[-1] = res_stack[-1].cpu()
+ elif isinstance(block, PopBlock):
+ if vram_limit_level>=1:
+ res_stack[-1] = res_stack[-1].to(device)
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+ else:
+ hidden_states_input = hidden_states
+ hidden_states_output = []
+ for batch_id in range(0, sample.shape[0], unet_batch_size):
+ batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
+ hidden_states, _, _, _ = block(
+ hidden_states_input[batch_id: batch_id_],
+ time_emb,
+ text_emb[batch_id: batch_id_],
+ res_stack,
+ cross_frame_attention=cross_frame_attention,
+ ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
+ )
+ hidden_states_output.append(hidden_states)
+ hidden_states = torch.concat(hidden_states_output, dim=0)
+ # 4.2 AnimateDiff
+ if motion_modules is not None:
+ if block_id in motion_modules.call_block_id:
+ motion_module_id = motion_modules.call_block_id[block_id]
+ hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
+ hidden_states, time_emb, text_emb, res_stack,
+ batch_size=1
+ )
+ # 4.3 ControlNet
+ if block_id == controlnet_insert_block_id and additional_res_stack is not None:
+ hidden_states += additional_res_stack.pop().to(device)
+ if vram_limit_level>=1:
+ res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
+ else:
+ res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
+
+ # 5. output
+ hidden_states = unet.conv_norm_out(hidden_states)
+ hidden_states = unet.conv_act(hidden_states)
+ hidden_states = unet.conv_out(hidden_states)
+
+ return hidden_states
+
+
+
+
+def lets_dance_xl(
+ unet: SDXLUNet,
+ motion_modules: SDXLMotionModel = None,
+ controlnet: MultiControlNetManager = None,
+ sample = None,
+ add_time_id = None,
+ add_text_embeds = None,
+ timestep = None,
+ encoder_hidden_states = None,
+ ipadapter_kwargs_list = {},
+ controlnet_frames = None,
+ unet_batch_size = 1,
+ controlnet_batch_size = 1,
+ cross_frame_attention = False,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ device = "cuda",
+ vram_limit_level = 0,
+):
+ # 0. Text embedding alignment (only for video processing)
+ if encoder_hidden_states.shape[0] != sample.shape[0]:
+ encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
+ if add_text_embeds.shape[0] != sample.shape[0]:
+ add_text_embeds = add_text_embeds.repeat(sample.shape[0], 1)
+
+ # 1. ControlNet
+ controlnet_insert_block_id = 22
+ if controlnet is not None and controlnet_frames is not None:
+ res_stacks = []
+ # process controlnet frames with batch
+ for batch_id in range(0, sample.shape[0], controlnet_batch_size):
+ batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
+ res_stack = controlnet(
+ sample[batch_id: batch_id_],
+ timestep,
+ encoder_hidden_states[batch_id: batch_id_],
+ controlnet_frames[:, batch_id: batch_id_],
+ add_time_id=add_time_id,
+ add_text_embeds=add_text_embeds,
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
+ unet=unet, # for Kolors, some modules in ControlNets will be replaced.
+ )
+ if vram_limit_level >= 1:
+ res_stack = [res.cpu() for res in res_stack]
+ res_stacks.append(res_stack)
+ # concat the residual
+ additional_res_stack = []
+ for i in range(len(res_stacks[0])):
+ res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
+ additional_res_stack.append(res)
+ else:
+ additional_res_stack = None
+
+ # 2. time
+ t_emb = unet.time_proj(timestep).to(sample.dtype)
+ t_emb = unet.time_embedding(t_emb)
+
+ time_embeds = unet.add_time_proj(add_time_id)
+ time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
+ add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(sample.dtype)
+ add_embeds = unet.add_time_embedding(add_embeds)
+
+ time_emb = t_emb + add_embeds
+
+ # 3. pre-process
+ height, width = sample.shape[2], sample.shape[3]
+ hidden_states = unet.conv_in(sample)
+ text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states)
+ res_stack = [hidden_states]
+
+ # 4. blocks
+ for block_id, block in enumerate(unet.blocks):
+ # 4.1 UNet
+ if isinstance(block, PushBlock):
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+ if vram_limit_level>=1:
+ res_stack[-1] = res_stack[-1].cpu()
+ elif isinstance(block, PopBlock):
+ if vram_limit_level>=1:
+ res_stack[-1] = res_stack[-1].to(device)
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+ else:
+ hidden_states_input = hidden_states
+ hidden_states_output = []
+ for batch_id in range(0, sample.shape[0], unet_batch_size):
+ batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
+ hidden_states, _, _, _ = block(
+ hidden_states_input[batch_id: batch_id_],
+ time_emb[batch_id: batch_id_],
+ text_emb[batch_id: batch_id_],
+ res_stack,
+ cross_frame_attention=cross_frame_attention,
+ ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
+ )
+ hidden_states_output.append(hidden_states)
+ hidden_states = torch.concat(hidden_states_output, dim=0)
+ # 4.2 AnimateDiff
+ if motion_modules is not None:
+ if block_id in motion_modules.call_block_id:
+ motion_module_id = motion_modules.call_block_id[block_id]
+ hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
+ hidden_states, time_emb, text_emb, res_stack,
+ batch_size=1
+ )
+ # 4.3 ControlNet
+ if block_id == controlnet_insert_block_id and additional_res_stack is not None:
+ hidden_states += additional_res_stack.pop().to(device)
+ res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
+
+ # 5. output
+ hidden_states = unet.conv_norm_out(hidden_states)
+ hidden_states = unet.conv_act(hidden_states)
+ hidden_states = unet.conv_out(hidden_states)
+
+ return hidden_states
\ No newline at end of file
diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..55a84c06d81bfa6a1a663a08db375de12a3aea57
--- /dev/null
+++ b/diffsynth/pipelines/flux_image.py
@@ -0,0 +1,823 @@
+from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
+from ..models.step1x_connector import Qwen2Connector
+from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
+from ..prompters import FluxPrompter
+from ..schedulers import FlowMatchScheduler
+from .base import BasePipeline
+from typing import List
+import torch
+from tqdm import tqdm
+import numpy as np
+from PIL import Image
+from ..models.tiler import FastTileWorker
+from transformers import SiglipVisionModel
+from copy import deepcopy
+from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
+from ..models.flux_dit import RMSNorm
+from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
+
+
+class FluxImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
+ self.scheduler = FlowMatchScheduler()
+ self.prompter = FluxPrompter()
+ # models
+ self.text_encoder_1: SD3TextEncoder1 = None
+ self.text_encoder_2: FluxTextEncoder2 = None
+ self.dit: FluxDiT = None
+ self.vae_decoder: FluxVAEDecoder = None
+ self.vae_encoder: FluxVAEEncoder = None
+ self.controlnet: FluxMultiControlNetManager = None
+ self.ipadapter: FluxIpAdapter = None
+ self.ipadapter_image_encoder: SiglipVisionModel = None
+ self.infinityou_processor: InfinitYou = None
+ self.qwenvl = None
+ self.step1x_connector: Qwen2Connector = None
+ self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder', 'qwenvl', 'step1x_connector']
+
+
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
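+ # Wrap submodules so their weights can stay offloaded on the CPU in their original dtype and are
+ # moved/cast to the computation device on demand; num_persistent_param_in_dit caps how many DiT
+ # parameters remain onloaded on the GPU.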
+ if self.text_encoder_1 is not None:
+ dtype = next(iter(self.text_encoder_1.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder_1,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Embedding: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.text_encoder_2 is not None:
+ dtype = next(iter(self.text_encoder_2.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder_2,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Embedding: AutoWrappedModule,
+ T5LayerNorm: AutoWrappedModule,
+ T5DenseActDense: AutoWrappedModule,
+ T5DenseGatedActDense: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.dit is not None:
+ dtype = next(iter(self.dit.parameters())).dtype
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ RMSNorm: AutoWrappedModule,
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cuda",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ max_num_param=num_persistent_param_in_dit,
+ overflow_module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.vae_decoder is not None:
+ dtype = next(iter(self.vae_decoder.parameters())).dtype
+ enable_vram_management(
+ self.vae_decoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.GroupNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.vae_encoder is not None:
+ dtype = next(iter(self.vae_encoder.parameters())).dtype
+ enable_vram_management(
+ self.vae_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.GroupNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ self.enable_cpu_offload()
+
+
+ def denoising_model(self):
+ return self.dit
+
+
+ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[]):
+ self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
+ self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
+ self.dit = model_manager.fetch_model("flux_dit")
+ self.vae_decoder = model_manager.fetch_model("flux_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
+ self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+ self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes)
+
+ # ControlNets
+ controlnet_units = []
+ for config in controlnet_config_units:
+ controlnet_unit = ControlNetUnit(
+ Annotator(config.processor_id, device=self.device, skip_processor=config.skip_processor),
+ model_manager.fetch_model("flux_controlnet", config.model_path),
+ config.scale
+ )
+ controlnet_units.append(controlnet_unit)
+ self.controlnet = FluxMultiControlNetManager(controlnet_units)
+
+ # IP-Adapters
+ self.ipadapter = model_manager.fetch_model("flux_ipadapter")
+ self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
+
+ # InfiniteYou
+ self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
+ if self.image_proj_model is not None:
+ self.infinityou_processor = InfinitYou(device=self.device)
+
+ # Step1x
+ self.qwenvl = model_manager.fetch_model("qwenvl")
+ self.step1x_connector = model_manager.fetch_model("step1x_connector")
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
+ pipe = FluxImagePipeline(
+ device=model_manager.device if device is None else device,
+ torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype,
+ )
+ pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
+ return pipe
+
+
+ def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ return image
+
+
+ def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
+ if self.text_encoder_1 is not None and self.text_encoder_2 is not None:
+ prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
+ prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
+ )
+ return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
+ else:
+ return {}
+
+
+ def prepare_extra_input(self, latents=None, guidance=1.0):
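+        # image_ids hold the per-patch positional indices consumed by the DiT's rotary
+        # position embedding; guidance is the embedded guidance scale broadcast over the batch.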
+ latent_image_ids = self.dit.prepare_image_ids(latents)
+ guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
+ return {"image_ids": latent_image_ids, "guidance": guidance}
+
+
+ def apply_controlnet_mask_on_latents(self, latents, mask):
+ mask = (self.preprocess_image(mask) + 1) / 2
+ mask = mask.mean(dim=1, keepdim=True)
+ mask = mask.to(dtype=self.torch_dtype, device=self.device)
+ mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
+ latents = torch.concat([latents, mask], dim=1)
+ return latents
+
+
+ def apply_controlnet_mask_on_image(self, image, mask):
+ mask = mask.resize(image.size)
+ mask = self.preprocess_image(mask).mean(dim=[0, 1])
+ image = np.array(image)
+ image[mask > 0] = 0
+ image = Image.fromarray(image)
+ return image
+
+
+ def prepare_controlnet_input(self, controlnet_image, controlnet_inpaint_mask, tiler_kwargs):
+ if isinstance(controlnet_image, Image.Image):
+ controlnet_image = [controlnet_image] * len(self.controlnet.processors)
+
+ controlnet_frames = []
+ for i in range(len(self.controlnet.processors)):
+ # image annotator
+ image = self.controlnet.process_image(controlnet_image[i], processor_id=i)[0]
+ if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
+ image = self.apply_controlnet_mask_on_image(image, controlnet_inpaint_mask)
+
+ # image to tensor
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+
+ # vae encoder
+ image = self.encode_image(image, **tiler_kwargs)
+ if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
+ image = self.apply_controlnet_mask_on_latents(image, controlnet_inpaint_mask)
+
+ # store it
+ controlnet_frames.append(image)
+ return controlnet_frames
+
+
+ def prepare_ipadapter_inputs(self, images, height=384, width=384):
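+        # resample=3 is PIL's BICUBIC filter; the IP-Adapter image encoder is fed
+        # fixed-size 384x384 RGB inputs.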
+ images = [image.convert("RGB").resize((width, height), resample=3) for image in images]
+ images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images]
+ return torch.cat(images, dim=0)
+
+
+ def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.):
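+        # Blend the model prediction with the "velocity" that would reconstruct the
+        # reference latents: foreground regions take the prediction as-is, background
+        # regions mix it in with background_weight and are renormalized.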
+ # inpaint noise
+ inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id]
+ # merge noise
+ weight = torch.ones_like(inpaint_noise)
+ inpaint_noise[fg_mask] = pred_noise[fg_mask]
+ inpaint_noise[bg_mask] += pred_noise[bg_mask] * background_weight
+ weight[bg_mask] += background_weight
+ inpaint_noise /= weight
+ return inpaint_noise
+
+
+ def preprocess_masks(self, masks, height, width, dim):
+ out_masks = []
+ for mask in masks:
+ mask = self.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
+ mask = mask.repeat(1, dim, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+ out_masks.append(mask)
+ return out_masks
+
+
+ def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, enable_eligen_inpaint=False):
+ fg_mask, bg_mask = None, None
+ if enable_eligen_inpaint:
+ masks_ = deepcopy(entity_masks)
+ fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_])
+ fg_masks = (fg_masks > 0).float()
+ fg_mask = fg_masks.sum(dim=0, keepdim=True).repeat(1, 16, 1, 1) > 0
+ bg_mask = ~fg_mask
+ entity_masks = self.preprocess_masks(entity_masks, height//8, width//8, 1)
+ entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
+ entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0)
+ return entity_prompts, entity_masks, fg_mask, bg_mask
+
+
+ def prepare_latents(self, input_image, height, width, seed, tiled, tile_size, tile_stride):
+ if input_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
+ input_latents = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ input_latents = None
+ return latents, input_latents
+
+
+ def prepare_ipadapter(self, ipadapter_images, ipadapter_scale):
+ if ipadapter_images is not None:
+ self.load_models_to_device(['ipadapter_image_encoder'])
+ ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images)
+ ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output
+ self.load_models_to_device(['ipadapter'])
+ ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
+ ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
+ else:
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
+ return ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega
+
+
+ def prepare_controlnet(self, controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative):
+ if controlnet_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ controlnet_kwargs_posi = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)}
+ if len(masks) > 0 and controlnet_inpaint_mask is not None:
+ print("The controlnet_inpaint_mask will be overridden by masks.")
+ local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks]
+ else:
+ local_controlnet_kwargs = None
+ else:
+ controlnet_kwargs_posi, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks)
+ controlnet_kwargs_nega = controlnet_kwargs_posi if enable_controlnet_on_negative else {}
+ return controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs
+
+
+ def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale):
+ if eligen_entity_masks is not None:
+ entity_prompt_emb_posi, entity_masks_posi, fg_mask, bg_mask = self.prepare_entity_inputs(eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint)
+ if enable_eligen_on_negative and cfg_scale != 1.0:
+ entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks_posi.shape[1], 1, 1)
+ entity_masks_nega = entity_masks_posi
+ else:
+ entity_prompt_emb_nega, entity_masks_nega = None, None
+ else:
+ entity_prompt_emb_posi, entity_masks_posi, entity_prompt_emb_nega, entity_masks_nega = None, None, None, None
+ fg_mask, bg_mask = None, None
+ eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi}
+ eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega}
+ return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
+
+
+ def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
+ # Extend prompt
+ self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
+ prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
+
+ # Encode prompts
+ prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
+ prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
+ return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
+
+
+ def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance, height, width):
+ if self.infinityou_processor is not None and id_image is not None:
+ return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
+ else:
+ return {}, controlnet_image
+
+
+ def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, tiled=False, tile_size=64, tile_stride=32):
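+        # Flex-style DiTs expect 196 input channels after patchification: 64 from the noisy
+        # latents plus 4 * 33 from the packed condition (16 inpaint latents + 1 inpaint mask
+        # + 16 control latents), hence the input_dim == 196 check. The control branch is
+        # swapped for zeros once the schedule passes flex_control_stop.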
+ if self.dit.input_dim == 196:
+ if flex_inpaint_image is None:
+ flex_inpaint_image = torch.zeros_like(latents)
+ else:
+ flex_inpaint_image = self.preprocess_image(flex_inpaint_image).to(device=self.device, dtype=self.torch_dtype)
+ flex_inpaint_image = self.encode_image(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ if flex_inpaint_mask is None:
+ flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
+ else:
+ flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
+ flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype)
+ flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
+ flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)
+ if flex_control_image is None:
+ flex_control_image = torch.zeros_like(latents)
+ else:
+ flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype)
+ flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength
+ flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
+ flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
+ flex_control_stop_timestep = self.scheduler.timesteps[int(flex_control_stop * (len(self.scheduler.timesteps) - 1))]
+ flex_kwargs = {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
+ else:
+ flex_kwargs = {}
+ return flex_kwargs
+
+
+ def prepare_step1x_kwargs(self, prompt, negative_prompt, image):
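+        # Step1x editing conditions on Qwen-VL embeddings of (caption, reference image)
+        # plus the VAE latents of the reference image; index 0 holds the positive prompt
+        # and index 1 the negative prompt.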
+ if image is None:
+ return {}, {}
+ self.load_models_to_device(["qwenvl", "vae_encoder"])
+ captions = [prompt, negative_prompt]
+ ref_images = [image, image]
+ embs, masks = self.qwenvl(captions, ref_images)
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+ image = self.encode_image(image)
+ return {"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image}, {"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ # Prompt
+ prompt,
+ negative_prompt="",
+ cfg_scale=1.0,
+ embedded_guidance=3.5,
+ t5_sequence_length=512,
+ # Image
+ input_image=None,
+ denoising_strength=1.0,
+ height=1024,
+ width=1024,
+ seed=None,
+ # Steps
+ num_inference_steps=30,
+ # local prompts
+ local_prompts=(),
+ masks=(),
+ mask_scales=(),
+ # ControlNet
+ controlnet_image=None,
+ controlnet_inpaint_mask=None,
+ enable_controlnet_on_negative=False,
+ # IP-Adapter
+ ipadapter_images=None,
+ ipadapter_scale=1.0,
+ # EliGen
+ eligen_entity_prompts=None,
+ eligen_entity_masks=None,
+ enable_eligen_on_negative=False,
+ enable_eligen_inpaint=False,
+ # InfiniteYou
+ infinityou_id_image=None,
+ infinityou_guidance=1.0,
+ # Flex
+ flex_inpaint_image=None,
+ flex_inpaint_mask=None,
+ flex_control_image=None,
+ flex_control_strength=0.5,
+ flex_control_stop=0.5,
+ # Step1x
+ step1x_reference_image=None,
+ # TeaCache
+ tea_cache_l1_thresh=None,
+ # Tile
+ tiled=False,
+ tile_size=128,
+ tile_stride=64,
+ # Progress bar
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
+
+ # Prompt
+ prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)
+
+ # Extra input
+ extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
+
+ # InfiniteYou
+ infiniteyou_kwargs, controlnet_image = self.prepare_infinite_you(infinityou_id_image, controlnet_image, infinityou_guidance, height, width)
+
+ # Entity control
+ eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
+
+ # IP-Adapter
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale)
+
+ # ControlNets
+ controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
+
+ # Flex
+ flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength=flex_control_strength, flex_control_stop=flex_control_stop, **tiler_kwargs)
+
+ # Step1x
+ step1x_kwargs_posi, step1x_kwargs_nega = self.prepare_step1x_kwargs(prompt, negative_prompt, image=step1x_reference_image)
+
+ # TeaCache
+ tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
+
+ # Denoise
+ self.load_models_to_device(['dit', 'controlnet', 'step1x_connector'])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Positive side
+ inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
+ dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector,
+ hidden_states=latents, timestep=timestep,
+ **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_posi,
+ )
+ noise_pred_posi = self.control_noise_via_local_prompts(
+ prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
+ special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs
+ )
+
+ # Inpaint
+ if enable_eligen_inpaint:
+ noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)
+
+ # Classifier-free guidance
+ if cfg_scale != 1.0:
+ # Negative side
+ noise_pred_nega = lets_dance_flux(
+ dit=self.dit, controlnet=self.controlnet, step1x_connector=self.step1x_connector,
+ hidden_states=latents, timestep=timestep,
+ **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs, **step1x_kwargs_nega,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # Iterate
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ self.load_models_to_device(['vae_decoder'])
+ image = self.decode_image(latents, **tiler_kwargs)
+
+ # Offload all models
+ self.load_models_to_device([])
+ return image
+
+
+
+class InfinitYou:
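+    """Identity conditioning for InfiniteYou: detects the largest face in the ID image,
+    extracts an ArcFace embedding, and projects it for the InfiniteYou branch."""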
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+ from facexlib.recognition import init_recognition_model
+ from insightface.app import FaceAnalysis
+ self.device = device
+ self.torch_dtype = torch_dtype
+ insightface_root_path = 'models/InfiniteYou/insightface'
+ self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ self.app_640.prepare(ctx_id=0, det_size=(640, 640))
+ self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ self.app_320.prepare(ctx_id=0, det_size=(320, 320))
+ self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ self.app_160.prepare(ctx_id=0, det_size=(160, 160))
+ self.arcface_model = init_recognition_model('arcface', device=self.device)
+
+ def _detect_face(self, id_image_cv2):
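+        # Cascade over detector input sizes (640 -> 320 -> 160) so small or low-resolution
+        # faces are still picked up.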
+ face_info = self.app_640.get(id_image_cv2)
+ if len(face_info) > 0:
+ return face_info
+ face_info = self.app_320.get(id_image_cv2)
+ if len(face_info) > 0:
+ return face_info
+ face_info = self.app_160.get(id_image_cv2)
+ return face_info
+
+ def extract_arcface_bgr_embedding(self, in_image, landmark):
+ from insightface.utils import face_align
+ arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
+ arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
+ arc_face_image = 2 * arc_face_image - 1
+ arc_face_image = arc_face_image.contiguous().to(self.device)
+ face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
+ return face_emb
+
+ def prepare_infinite_you(self, model, id_image, controlnet_image, infinityou_guidance, height, width):
+ import cv2
+ if id_image is None:
+ return {'id_emb': None}, controlnet_image
+ id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
+ face_info = self._detect_face(id_image_cv2)
+ if len(face_info) == 0:
+ raise ValueError('No face detected in the input ID image')
+        landmark = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]['kps']  # use only the largest detected face
+ id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
+ id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
+ if controlnet_image is None:
+ controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
+ infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
+ return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}, controlnet_image
+
+
+class TeaCache:
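+    """Skips DiT blocks on timesteps where the first block's modulated input barely changes,
+    re-applying the cached residual from the last fully computed step instead."""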
+ def __init__(self, num_inference_steps, rel_l1_thresh):
+ self.num_inference_steps = num_inference_steps
+ self.step = 0
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.rel_l1_thresh = rel_l1_thresh
+ self.previous_residual = None
+ self.previous_hidden_states = None
+
+ def check(self, dit: FluxDiT, hidden_states, conditioning):
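+        # Always compute on the first and last step; otherwise accumulate a rescaled
+        # relative-L1 change of the modulated input and skip the blocks while the
+        # accumulated value stays below rel_l1_thresh. Returns True when the cached
+        # residual should be reused.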
+ inp = hidden_states.clone()
+ temb_ = conditioning.clone()
+ modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_)
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ else:
+ coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
+ rescale_func = np.poly1d(coefficients)
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+ should_calc = False
+ else:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = modulated_inp
+ self.step += 1
+ if self.step == self.num_inference_steps:
+ self.step = 0
+ if should_calc:
+ self.previous_hidden_states = hidden_states.clone()
+ return not should_calc
+
+ def store(self, hidden_states):
+ self.previous_residual = hidden_states - self.previous_hidden_states
+ self.previous_hidden_states = None
+
+ def update(self, hidden_states):
+ hidden_states = hidden_states + self.previous_residual
+ return hidden_states
+
+
+def lets_dance_flux(
+ dit: FluxDiT,
+ controlnet: FluxMultiControlNetManager = None,
+ step1x_connector: Qwen2Connector = None,
+ hidden_states=None,
+ timestep=None,
+ prompt_emb=None,
+ pooled_prompt_emb=None,
+ guidance=None,
+ text_ids=None,
+ image_ids=None,
+ controlnet_frames=None,
+ tiled=False,
+ tile_size=128,
+ tile_stride=64,
+ entity_prompt_emb=None,
+ entity_masks=None,
+ ipadapter_kwargs_list={},
+ id_emb=None,
+ infinityou_guidance=None,
+ flex_condition=None,
+ flex_uncondition=None,
+ flex_control_stop_timestep=None,
+ step1x_llm_embedding=None,
+ step1x_mask=None,
+ step1x_reference_latents=None,
+ tea_cache: TeaCache = None,
+ **kwargs
+):
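+    # One DiT forward pass for a single denoising step, with optional ControlNet, EliGen
+    # entity masks, IP-Adapter, InfiniteYou, Flex and Step1x branches. When tiled, the
+    # latent is split into overlapping spatial tiles and this function recurses per tile.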
+ if tiled:
+ def flux_forward_fn(hl, hr, wl, wr):
+ tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
+ return lets_dance_flux(
+ dit=dit,
+ controlnet=controlnet,
+ hidden_states=hidden_states[:, :, hl: hr, wl: wr],
+ timestep=timestep,
+ prompt_emb=prompt_emb,
+ pooled_prompt_emb=pooled_prompt_emb,
+ guidance=guidance,
+ text_ids=text_ids,
+ image_ids=None,
+ controlnet_frames=tiled_controlnet_frames,
+ tiled=False,
+ **kwargs
+ )
+ return FastTileWorker().tiled_forward(
+ flux_forward_fn,
+ hidden_states,
+ tile_size=tile_size,
+ tile_stride=tile_stride,
+ tile_device=hidden_states.device,
+ tile_dtype=hidden_states.dtype
+ )
+
+
+ # ControlNet
+ if controlnet is not None and controlnet_frames is not None:
+ controlnet_extra_kwargs = {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "prompt_emb": prompt_emb,
+ "pooled_prompt_emb": pooled_prompt_emb,
+ "guidance": guidance,
+ "text_ids": text_ids,
+ "image_ids": image_ids,
+ "tiled": tiled,
+ "tile_size": tile_size,
+ "tile_stride": tile_stride,
+ }
+ if id_emb is not None:
+ controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
+ controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
+ controlnet_res_stack, controlnet_single_res_stack = controlnet(
+ controlnet_frames, **controlnet_extra_kwargs
+ )
+
+ # Flex
+ if flex_condition is not None:
+ if timestep.tolist()[0] >= flex_control_stop_timestep:
+ hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
+ else:
+ hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
+
+ # Step1x
+ if step1x_llm_embedding is not None:
+ prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask)
+ text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device)
+
+ if image_ids is None:
+ image_ids = dit.prepare_image_ids(hidden_states)
+
+ conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb)
+ if dit.guidance_embedder is not None:
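+        # The embedded guidance scale goes through a timestep-style embedder, so it is
+        # multiplied by 1000 to land in the same numeric range as the timesteps.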
+ guidance = guidance * 1000
+ conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
+
+ height, width = hidden_states.shape[-2:]
+ hidden_states = dit.patchify(hidden_states)
+
+ # Step1x
+ if step1x_reference_latents is not None:
+ step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents)
+ step1x_reference_latents = dit.patchify(step1x_reference_latents)
+ image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2)
+ hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1)
+
+ hidden_states = dit.x_embedder(hidden_states)
+
+ if entity_prompt_emb is not None and entity_masks is not None:
+ prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, 16)
+ else:
+ prompt_emb = dit.context_embedder(prompt_emb)
+ image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+ attention_mask = None
+
+ # TeaCache
+ if tea_cache is not None:
+ tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
+ else:
+ tea_cache_update = False
+
+ if tea_cache_update:
+ hidden_states = tea_cache.update(hidden_states)
+ else:
+ # Joint Blocks
+ for block_id, block in enumerate(dit.blocks):
+ hidden_states, prompt_emb = block(
+ hidden_states,
+ prompt_emb,
+ conditioning,
+ image_rotary_emb,
+ attention_mask,
+ ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
+ )
+ # ControlNet
+ if controlnet is not None and controlnet_frames is not None:
+ hidden_states = hidden_states + controlnet_res_stack[block_id]
+
+ # Single Blocks
+ hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
+ num_joint_blocks = len(dit.blocks)
+ for block_id, block in enumerate(dit.single_blocks):
+ hidden_states, prompt_emb = block(
+ hidden_states,
+ prompt_emb,
+ conditioning,
+ image_rotary_emb,
+ attention_mask,
+ ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
+ )
+ # ControlNet
+ if controlnet is not None and controlnet_frames is not None:
+ hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
+ hidden_states = hidden_states[:, prompt_emb.shape[1]:]
+
+ if tea_cache is not None:
+ tea_cache.store(hidden_states)
+
+ hidden_states = dit.final_norm_out(hidden_states, conditioning)
+ hidden_states = dit.final_proj_out(hidden_states)
+
+ # Step1x
+ if step1x_reference_latents is not None:
+ hidden_states = hidden_states[:, :hidden_states.shape[1] // 2]
+
+ hidden_states = dit.unpatchify(hidden_states, height, width)
+
+ return hidden_states
diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cf73a65b49e3fa6ab7e424e31cff8680dc65806
--- /dev/null
+++ b/diffsynth/pipelines/flux_image_new.py
@@ -0,0 +1,1318 @@
+import torch, warnings, glob, os, types
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from einops import repeat, reduce, rearrange
+from typing import Optional, Union
+from typing_extensions import Literal
+from dataclasses import dataclass
+from modelscope import snapshot_download
+
+from ..schedulers import FlowMatchScheduler
+from ..prompters import FluxPrompter
+from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEncoder2, FluxDiT, FluxVAEEncoder, FluxVAEDecoder
+from ..models.step1x_connector import Qwen2Connector
+from ..models.flux_controlnet import FluxControlNet
+from ..models.flux_ipadapter import FluxIpAdapter
+from ..models.flux_value_control import MultiValueEncoder
+from ..models.flux_infiniteyou import InfiniteYouImageProjector
+from ..models.flux_lora_encoder import FluxLoRAEncoder, LoRALayerBlock
+from ..models.tiler import FastTileWorker
+from ..models.nexus_gen import NexusGenAutoregressiveModel
+from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger
+from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
+from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher, FluxLoRAFuser
+
+from ..models.flux_dit import RMSNorm
+from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
+
+def visualize_inputs(inputs: dict):
+    # Debug helper: dump the current prompt, the input image, and a rough RGB view of
+    # the first three latent channels to disk.
+    print(f"### prompt = {inputs.get('prompt', '')} ###")
+    inputs['input_image'].save("debug_input_image.png")
+    img_vae = inputs['input_latents'].to(torch.float32).detach().cpu().numpy().transpose(0, 2, 3, 1)[0, :, :, :3]
+    img_vae = np.clip((img_vae + 1) * 127.5, 0, 255).astype(np.uint8)
+    Image.fromarray(img_vae).save("debug_input_image_vae.png")
+    print("### input image saved to debug_input_image.png ###")
+
+@dataclass
+class ControlNetInput:
+ controlnet_id: int = 0
+ scale: float = 1.0
+ start: float = 1.0
+ end: float = 0.0
+ image: Image.Image = None
+ inpaint_mask: Image.Image = None
+ processor_id: str = None
+
+
+
+class MultiControlNet(torch.nn.Module):
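+    """Runs several FluxControlNets and sums their residual stacks. Each ControlNetInput
+    scales its residuals and is only active while the remaining-progress fraction lies
+    within [end, start] (progress counts down from 1 to 0 over the denoising steps)."""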
+ def __init__(self, models: list[FluxControlNet]):
+ super().__init__()
+ self.models = torch.nn.ModuleList(models)
+
+ def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs):
+ model = self.models[controlnet_input.controlnet_id]
+ res_stack, single_res_stack = model(
+ controlnet_conditioning=conditioning,
+ processor_id=controlnet_input.processor_id,
+ **kwargs
+ )
+ res_stack = [res * controlnet_input.scale for res in res_stack]
+ single_res_stack = [res * controlnet_input.scale for res in single_res_stack]
+ return res_stack, single_res_stack
+
+ def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs):
+ res_stack, single_res_stack = None, None
+ for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
+ progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
+ if progress > controlnet_input.start or progress < controlnet_input.end:
+ continue
+ res_stack_, single_res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)
+ if res_stack is None:
+ res_stack = res_stack_
+ single_res_stack = single_res_stack_
+ else:
+ res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
+ single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
+ return res_stack, single_res_stack
+
+
+
+class FluxImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+ super().__init__(
+ device=device, torch_dtype=torch_dtype,
+ height_division_factor=16, width_division_factor=16,
+ )
+ self.scheduler = FlowMatchScheduler()
+ self.prompter = FluxPrompter()
+ self.text_encoder_1: SD3TextEncoder1 = None
+ self.text_encoder_2: FluxTextEncoder2 = None
+ self.dit: FluxDiT = None
+ self.vae_decoder: FluxVAEDecoder = None
+ self.vae_encoder: FluxVAEEncoder = None
+ self.controlnet: MultiControlNet = None
+ self.ipadapter: FluxIpAdapter = None
+ self.ipadapter_image_encoder = None
+ self.qwenvl = None
+ self.step1x_connector: Qwen2Connector = None
+ self.nexus_gen: NexusGenAutoregressiveModel = None
+ self.nexus_gen_generation_adapter: NexusGenAdapter = None
+ self.nexus_gen_editing_adapter: NexusGenImageEmbeddingMerger = None
+ self.value_controller: MultiValueEncoder = None
+ self.infinityou_processor: InfinitYou = None
+ self.image_proj_model: InfiniteYouImageProjector = None
+ self.lora_patcher: FluxLoraPatcher = None
+ self.lora_encoder: FluxLoRAEncoder = None
+ self.unit_runner = PipelineUnitRunner()
+ self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher")
+ self.units = [
+ FluxImageUnit_ShapeChecker(),
+ FluxImageUnit_NoiseInitializer(),
+ FluxImageUnit_PromptEmbedder(),
+ FluxImageUnit_InputImageEmbedder(),
+ FluxImageUnit_ImageIDs(),
+ FluxImageUnit_EmbeddedGuidanceEmbedder(),
+ FluxImageUnit_Kontext(),
+ FluxImageUnit_InfiniteYou(),
+ FluxImageUnit_ControlNet(),
+ FluxImageUnit_IPAdapter(),
+ FluxImageUnit_EntityControl(),
+ FluxImageUnit_NexusGen(),
+ FluxImageUnit_TeaCache(),
+ FluxImageUnit_Flex(),
+ FluxImageUnit_Step1x(),
+ FluxImageUnit_ValueControl(),
+ FluxImageUnit_LoRAEncode(),
+ ]
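+        # Each unit consumes named entries from the shared / positive / negative input
+        # dicts and writes derived tensors back before the denoising loop starts.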
+ self.model_fn = model_fn_flux_image
+
+
+ def load_lora(
+ self,
+ module: torch.nn.Module,
+ lora_config: Union[ModelConfig, str] = None,
+ alpha=1,
+ hotload=False,
+ state_dict=None,
+ ):
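+        # With hotload=True the converted LoRA A/B matrices are appended to each matching
+        # AutoWrappedLinear layer and applied at runtime (removable via clear_lora);
+        # otherwise the loader applies them to the module's base weights.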
+ if state_dict is None:
+ if isinstance(lora_config, str):
+ lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
+ else:
+ lora_config.download_if_necessary()
+ lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
+ else:
+ lora = state_dict
+ loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
+ lora = loader.convert_state_dict(lora)
+ if hotload:
+            for name, submodule in module.named_modules():
+                if isinstance(submodule, AutoWrappedLinear):
+                    lora_a_name = f'{name}.lora_A.default.weight'
+                    lora_b_name = f'{name}.lora_B.default.weight'
+                    if lora_a_name in lora and lora_b_name in lora:
+                        submodule.lora_A_weights.append(lora[lora_a_name] * alpha)
+                        submodule.lora_B_weights.append(lora[lora_b_name])
+ else:
+ loader.load(module, lora, alpha=alpha)
+
+
+ def load_loras(
+ self,
+ module: torch.nn.Module,
+ lora_configs: list[Union[ModelConfig, str]],
+ alpha=1,
+ hotload=False,
+ extra_fused_lora=False,
+ ):
+ for lora_config in lora_configs:
+ self.load_lora(module, lora_config, hotload=hotload, alpha=alpha)
+ if extra_fused_lora:
+ lora_fuser = FluxLoRAFuser(device="cuda", torch_dtype=torch.bfloat16)
+ fused_lora = lora_fuser(lora_configs)
+ self.load_lora(module, state_dict=fused_lora, hotload=hotload, alpha=alpha)
+
+
+ def clear_lora(self):
+ for name, module in self.named_modules():
+ if isinstance(module, AutoWrappedLinear):
+ if hasattr(module, "lora_A_weights"):
+ module.lora_A_weights.clear()
+ if hasattr(module, "lora_B_weights"):
+ module.lora_B_weights.clear()
+
+
+ def training_loss(self, **inputs):
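+        # Flow-matching objective: sample a random timestep, noise the clean latents,
+        # predict the target velocity, and weight the MSE by the scheduler's
+        # per-timestep weight.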
+ timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
+ timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
+
+ inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
+ training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
+
+ noise_pred = self.model_fn(**inputs, timestep=timestep)
+
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
+ loss = loss * self.scheduler.training_weight(timestep)
+ return loss
+
+
+ def _enable_vram_management_with_default_config(self, model, vram_limit):
+ if model is not None:
+ dtype = next(iter(model.parameters())).dtype
+ enable_vram_management(
+ model,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Embedding: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.GroupNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ LoRALayerBlock: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+
+
+ def enable_lora_magic(self):
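+        # Wrap every DiT Linear as AutoWrappedLinear (kept on-device) so LoRA weights can
+        # be hot-loaded per layer; if a LoRA patcher is present, attach its per-layer
+        # merger modules.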
+ if self.dit is not None:
+ if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
+ dtype = next(iter(self.dit.parameters())).dtype
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device=self.device,
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=None,
+ )
+ if self.lora_patcher is not None:
+ for name, module in self.dit.named_modules():
+ if isinstance(module, AutoWrappedLinear):
+ merger_name = name.replace(".", "___")
+ if merger_name in self.lora_patcher.model_dict:
+ module.lora_merger = self.lora_patcher.model_dict[merger_name]
+
+
+ def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
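+        # Either cap the number of persistent DiT parameters, or work against a VRAM
+        # budget (auto-detected when vram_limit is None, minus a safety vram_buffer).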
+ self.vram_management_enabled = True
+ if num_persistent_param_in_dit is not None:
+ vram_limit = None
+ else:
+ if vram_limit is None:
+ vram_limit = self.get_vram()
+ vram_limit = vram_limit - vram_buffer
+
+ # Default config
+ default_vram_management_models = ["text_encoder_1", "vae_decoder", "vae_encoder", "controlnet", "image_proj_model", "ipadapter", "lora_patcher", "value_controller", "step1x_connector", "lora_encoder"]
+ for model_name in default_vram_management_models:
+ self._enable_vram_management_with_default_config(getattr(self, model_name), vram_limit)
+
+ # Special config
+ if self.text_encoder_2 is not None:
+ from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
+ dtype = next(iter(self.text_encoder_2.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder_2,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Embedding: AutoWrappedModule,
+ T5LayerNorm: AutoWrappedModule,
+ T5DenseActDense: AutoWrappedModule,
+ T5DenseGatedActDense: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ if self.dit is not None:
+ dtype = next(iter(self.dit.parameters())).dtype
+ device = "cpu" if vram_limit is not None else self.device
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ RMSNorm: AutoWrappedModule,
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ max_num_param=num_persistent_param_in_dit,
+ overflow_module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ if self.ipadapter_image_encoder is not None:
+ from transformers.models.siglip.modeling_siglip import SiglipVisionEmbeddings, SiglipEncoder, SiglipMultiheadAttentionPoolingHead
+ dtype = next(iter(self.ipadapter_image_encoder.parameters())).dtype
+ enable_vram_management(
+ self.ipadapter_image_encoder,
+ module_map = {
+ SiglipVisionEmbeddings: AutoWrappedModule,
+ SiglipEncoder: AutoWrappedModule,
+ SiglipMultiheadAttentionPoolingHead: AutoWrappedModule,
+ torch.nn.MultiheadAttention: AutoWrappedModule,
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ if self.qwenvl is not None:
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+ Qwen2_5_VisionPatchEmbed, Qwen2_5_VLVisionBlock, Qwen2_5_VLPatchMerger,
+ Qwen2_5_VLDecoderLayer, Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
+ )
+ dtype = next(iter(self.qwenvl.parameters())).dtype
+ enable_vram_management(
+ self.qwenvl,
+ module_map = {
+ Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
+ Qwen2_5_VLVisionBlock: AutoWrappedModule,
+ Qwen2_5_VLPatchMerger: AutoWrappedModule,
+ Qwen2_5_VLDecoderLayer: AutoWrappedModule,
+ Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
+ Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
+ Qwen2RMSNorm: AutoWrappedModule,
+ torch.nn.Embedding: AutoWrappedModule,
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+
+
+ @staticmethod
+ def from_pretrained(
+ torch_dtype: torch.dtype = torch.bfloat16,
+ device: Union[str, torch.device] = "cuda",
+ model_configs: list[ModelConfig] = [],
+ nexus_gen_processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"),
+ ):
+ # Download and load models
+ model_manager = ModelManager()
+ for model_config in model_configs:
+ model_config.download_if_necessary()
+ model_manager.load_model(
+ model_config.path,
+ device=model_config.offload_device or device,
+ torch_dtype=model_config.offload_dtype or torch_dtype
+ )
+
+ # Initialize pipeline
+ pipe = FluxImagePipeline(device=device, torch_dtype=torch_dtype)
+ pipe.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
+ pipe.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
+ pipe.dit = model_manager.fetch_model("flux_dit")
+ pipe.vae_decoder = model_manager.fetch_model("flux_vae_decoder")
+ pipe.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
+ pipe.prompter.fetch_models(pipe.text_encoder_1, pipe.text_encoder_2)
+ pipe.ipadapter = model_manager.fetch_model("flux_ipadapter")
+ pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
+ pipe.qwenvl = model_manager.fetch_model("qwenvl")
+ pipe.step1x_connector = model_manager.fetch_model("step1x_connector")
+ pipe.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
+ if pipe.image_proj_model is not None:
+ pipe.infinityou_processor = InfinitYou(device=device)
+ pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher")
+ pipe.lora_encoder = model_manager.fetch_model("flux_lora_encoder")
+ pipe.nexus_gen = model_manager.fetch_model("nexus_gen_llm")
+ pipe.nexus_gen_generation_adapter = model_manager.fetch_model("nexus_gen_generation_adapter")
+ pipe.nexus_gen_editing_adapter = model_manager.fetch_model("nexus_gen_editing_adapter")
+ if nexus_gen_processor_config is not None and pipe.nexus_gen is not None:
+ nexus_gen_processor_config.download_if_necessary()
+ pipe.nexus_gen.load_processor(nexus_gen_processor_config.path)
+
+ # ControlNet
+ controlnets = []
+ for model_name, model in zip(model_manager.model_name, model_manager.model):
+ if model_name == "flux_controlnet":
+ controlnets.append(model)
+ if len(controlnets) > 0:
+ pipe.controlnet = MultiControlNet(controlnets)
+
+ # Value Controller
+ value_controllers = []
+ for model_name, model in zip(model_manager.model_name, model_manager.model):
+ if model_name == "flux_value_controller":
+ value_controllers.append(model)
+ if len(value_controllers) > 0:
+ pipe.value_controller = MultiValueEncoder(value_controllers)
+
+ return pipe
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ # Prompt
+ prompt: str,
+ negative_prompt: str = "",
+ cfg_scale: float = 1.0,
+ embedded_guidance: float = 3.5,
+ t5_sequence_length: int = 512,
+ # Image
+ input_image: Image.Image = None,
+ denoising_strength: float = 1.0,
+ # Shape
+ height: int = 1024,
+ width: int = 1024,
+ # Randomness
+ seed: int = None,
+ rand_device: str = "cpu",
+ # Scheduler
+ sigma_shift: float = None,
+ # Steps
+ num_inference_steps: int = 30,
+ # local prompts
+ multidiffusion_prompts=(),
+ multidiffusion_masks=(),
+ multidiffusion_scales=(),
+ # Kontext
+ kontext_images: Union[list[Image.Image], Image.Image] = None,
+ # ControlNet
+ controlnet_inputs: list[ControlNetInput] = None,
+ # IP-Adapter
+ ipadapter_images: Union[list[Image.Image], Image.Image] = None,
+ ipadapter_scale: float = 1.0,
+ # EliGen
+ eligen_entity_prompts: list[str] = None,
+ eligen_entity_masks: list[Image.Image] = None,
+ eligen_enable_on_negative: bool = False,
+ eligen_enable_inpaint: bool = False,
+ # InfiniteYou
+ infinityou_id_image: Image.Image = None,
+ infinityou_guidance: float = 1.0,
+ # Flex
+ flex_inpaint_image: Image.Image = None,
+ flex_inpaint_mask: Image.Image = None,
+ flex_control_image: Image.Image = None,
+ flex_control_strength: float = 0.5,
+ flex_control_stop: float = 0.5,
+ # Value Controller
+ value_controller_inputs: Union[list[float], float] = None,
+ # Step1x
+ step1x_reference_image: Image.Image = None,
+ # NexusGen
+ nexus_gen_reference_image: Image.Image = None,
+ # LoRA Encoder
+ lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None,
+ lora_encoder_scale: float = 1.0,
+ # TeaCache
+ tea_cache_l1_thresh: float = None,
+ # Tile
+ tiled: bool = False,
+ tile_size: int = 128,
+ tile_stride: int = 64,
+ # Progress bar
+ progress_bar_cmd = tqdm,
+ ):
+ # Scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+
+ inputs_posi = {
+ "prompt": prompt,
+ }
+ inputs_nega = {
+ "negative_prompt": negative_prompt,
+ }
+ inputs_shared = {
+ "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance, "t5_sequence_length": t5_sequence_length,
+ "input_image": input_image, "denoising_strength": denoising_strength,
+ "height": height, "width": width,
+ "seed": seed, "rand_device": rand_device,
+ "sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
+ "multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
+ "kontext_images": kontext_images,
+ "controlnet_inputs": controlnet_inputs,
+ "ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,
+ "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint,
+ "infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance,
+ "flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop,
+ "value_controller_inputs": value_controller_inputs,
+ "step1x_reference_image": step1x_reference_image,
+ "nexus_gen_reference_image": nexus_gen_reference_image,
+ "lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale,
+ "tea_cache_l1_thresh": tea_cache_l1_thresh,
+ "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
+ "progress_bar_cmd": progress_bar_cmd,
+ }
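+        # Run the pipeline units once; each reads what it needs from the dicts above and
+        # writes back derived inputs (latents, prompt embeddings, conditionings, ...).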
+ for unit in self.units:
+ inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+
+ # Denoise
+ self.load_models_to_device(self.in_iteration_models)
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+
+ # Inference
+ noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
+ if cfg_scale != 1.0:
+ noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # Scheduler
+ inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
+
+ # Decode
+ self.load_models_to_device(['vae_decoder'])
+ image = self.vae_decoder(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ self.load_models_to_device([])
+
+ return image
+
+
+
+class FluxImageUnit_ShapeChecker(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=("height", "width"))
+
+ def process(self, pipe: FluxImagePipeline, height, width):
+ height, width = pipe.check_resize_height_width(height, width)
+ return {"height": height, "width": width}
+
+
+
+class FluxImageUnit_NoiseInitializer(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=("height", "width", "seed", "rand_device"))
+
+ def process(self, pipe: FluxImagePipeline, height, width, seed, rand_device):
+ noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device)
+ return {"noise": noise}
+
+
+
+class FluxImageUnit_InputImageEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
+ onload_model_names=("vae_encoder",)
+ )
+
+ def process(self, pipe: FluxImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
+ if input_image is None:
+ return {"latents": noise, "input_latents": None}
+ pipe.load_models_to_device(['vae_encoder'])
+ image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
+ input_latents = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ if pipe.scheduler.training:
+ return {"latents": noise, "input_latents": input_latents}
+ else:
+ latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
+ return {"latents": latents, "input_latents": None}
+
+
+
+class FluxImageUnit_PromptEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ seperate_cfg=True,
+ input_params_posi={"prompt": "prompt", "positive": "positive"},
+ input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
+ input_params=("t5_sequence_length",),
+ onload_model_names=("text_encoder_1", "text_encoder_2")
+ )
+
+ def process(self, pipe: FluxImagePipeline, prompt, t5_sequence_length, positive) -> dict:
+ if pipe.text_encoder_1 is not None and pipe.text_encoder_2 is not None:
+ prompt_emb, pooled_prompt_emb, text_ids = pipe.prompter.encode_prompt(
+ prompt, device=pipe.device, positive=positive, t5_sequence_length=t5_sequence_length
+ )
+ return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
+ else:
+ return {}
+
+
+class FluxImageUnit_ImageIDs(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=("latents",))
+
+ def process(self, pipe: FluxImagePipeline, latents):
+ latent_image_ids = pipe.dit.prepare_image_ids(latents)
+ return {"image_ids": latent_image_ids}
+
+
+
+class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=("embedded_guidance", "latents"))
+
+ def process(self, pipe: FluxImagePipeline, embedded_guidance, latents):
+ guidance = torch.Tensor([embedded_guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
+ return {"guidance": guidance}
+
+
+
+class FluxImageUnit_Kontext(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=("kontext_images", "tiled", "tile_size", "tile_stride"))
+
+ def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride):
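+        # Kontext reference images are VAE-encoded and patchified, and their image ids get
+        # a first coordinate of 1 so positional embeddings separate them from the tokens
+        # of the image being generated.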
+ if kontext_images is None:
+ return {}
+ if not isinstance(kontext_images, list):
+ kontext_images = [kontext_images]
+
+ kontext_latents = []
+ kontext_image_ids = []
+ for kontext_image in kontext_images:
+ kontext_image = pipe.preprocess_image(kontext_image)
+ kontext_latent = pipe.vae_encoder(kontext_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image_ids = pipe.dit.prepare_image_ids(kontext_latent)
+ image_ids[..., 0] = 1
+ kontext_image_ids.append(image_ids)
+ kontext_latent = pipe.dit.patchify(kontext_latent)
+ kontext_latents.append(kontext_latent)
+ kontext_latents = torch.concat(kontext_latents, dim=1)
+ kontext_image_ids = torch.concat(kontext_image_ids, dim=-2)
+ return {"kontext_latents": kontext_latents, "kontext_image_ids": kontext_image_ids}
+
+
+
+class FluxImageUnit_ControlNet(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
+ onload_model_names=("vae_encoder",)
+ )
+
+ def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
+ mask = (pipe.preprocess_image(mask) + 1) / 2
+ mask = mask.mean(dim=1, keepdim=True)
+ mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
+ latents = torch.concat([latents, mask], dim=1)
+ return latents
+
+ def apply_controlnet_mask_on_image(self, pipe, image, mask):
+ mask = mask.resize(image.size)
+ mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
+ image = np.array(image)
+ image[mask > 0] = 0
+ image = Image.fromarray(image)
+ return image
+
+ def process(self, pipe: FluxImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
+ if controlnet_inputs is None:
+ return {}
+ pipe.load_models_to_device(['vae_encoder'])
+ conditionings = []
+ for controlnet_input in controlnet_inputs:
+ image = controlnet_input.image
+ if controlnet_input.inpaint_mask is not None:
+ image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
+
+ image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
+ image = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+ if controlnet_input.inpaint_mask is not None:
+ image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
+ conditionings.append(image)
+ return {"controlnet_conditionings": conditionings}
+
+
+
+class FluxImageUnit_IPAdapter(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ take_over=True,
+ onload_model_names=("ipadapter_image_encoder", "ipadapter")
+ )
+
+ def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
+ ipadapter_images, ipadapter_scale = inputs_shared.get("ipadapter_images", None), inputs_shared.get("ipadapter_scale", 1.0)
+ if ipadapter_images is None:
+ return inputs_shared, inputs_posi, inputs_nega
+ if not isinstance(ipadapter_images, list):
+ ipadapter_images = [ipadapter_images]
+
+ pipe.load_models_to_device(self.onload_model_names)
+ images = [image.convert("RGB").resize((384, 384), resample=3) for image in ipadapter_images]
+ images = [pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) for image in images]
+ ipadapter_images = torch.cat(images, dim=0)
+ ipadapter_image_encoding = pipe.ipadapter_image_encoder(ipadapter_images).pooler_output
+
+ inputs_posi.update({"ipadapter_kwargs_list": pipe.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)})
+ if inputs_shared.get("cfg_scale", 1.0) != 1.0:
+ inputs_nega.update({"ipadapter_kwargs_list": pipe.ipadapter(torch.zeros_like(ipadapter_image_encoding))})
+ return inputs_shared, inputs_posi, inputs_nega
+
+
+
+class FluxImageUnit_EntityControl(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ take_over=True,
+ onload_model_names=("text_encoder_1", "text_encoder_2")
+ )
+
+ def preprocess_masks(self, pipe, masks, height, width, dim):
+ out_masks = []
+ for mask in masks:
+ mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
+ mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype)
+ out_masks.append(mask)
+ return out_masks
+
+ def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height, t5_sequence_length=512):
+ entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1)
+ entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
+
+ prompt_emb, _, _ = pipe.prompter.encode_prompt(
+ entity_prompts, device=pipe.device, t5_sequence_length=t5_sequence_length
+ )
+ return prompt_emb.unsqueeze(0), entity_masks
+
+ def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_on_negative, cfg_scale):
+ entity_prompt_emb_posi, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length)
+ if enable_eligen_on_negative and cfg_scale != 1.0:
+ entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks_posi.shape[1], 1, 1)
+ entity_masks_nega = entity_masks_posi
+ else:
+ entity_prompt_emb_nega, entity_masks_nega = None, None
+ eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi}
+ eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega}
+ return eligen_kwargs_posi, eligen_kwargs_nega
+
+ def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
+ eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None)
+ if eligen_entity_prompts is None or eligen_entity_masks is None:
+ return inputs_shared, inputs_posi, inputs_nega
+ pipe.load_models_to_device(self.onload_model_names)
+ eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False)
+ eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega,
+ eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"],
+ inputs_shared["t5_sequence_length"], eligen_enable_on_negative, inputs_shared["cfg_scale"])
+ inputs_posi.update(eligen_kwargs_posi)
+ if inputs_shared.get("cfg_scale", 1.0) != 1.0:
+ inputs_nega.update(eligen_kwargs_nega)
+ return inputs_shared, inputs_posi, inputs_nega
+
+
+class FluxImageUnit_NexusGen(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ take_over=True,
+ onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"),
+ )
+
+ def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
+ if pipe.nexus_gen is None:
+ return inputs_shared, inputs_posi, inputs_nega
+ pipe.load_models_to_device(self.onload_model_names)
+ if inputs_shared.get("nexus_gen_reference_image", None) is None:
+ assert pipe.nexus_gen_generation_adapter is not None, "NexusGen requires a generation adapter to be set."
+ embed = pipe.nexus_gen(inputs_posi["prompt"])[0].unsqueeze(0)
+ inputs_posi["prompt_emb"] = pipe.nexus_gen_generation_adapter(embed)
+ inputs_posi['text_ids'] = torch.zeros(embed.shape[0], embed.shape[1], 3).to(device=pipe.device, dtype=pipe.torch_dtype)
+ else:
+ assert pipe.nexus_gen_editing_adapter is not None, "NexusGen requires an editing adapter to be set."
+ embed, ref_embed, grids = pipe.nexus_gen(inputs_posi["prompt"], inputs_shared["nexus_gen_reference_image"])
+ embeds_grid = grids[1:2].to(device=pipe.device, dtype=torch.long)
+ ref_embeds_grid = grids[0:1].to(device=pipe.device, dtype=torch.long)
+
+ inputs_posi["prompt_emb"] = pipe.nexus_gen_editing_adapter(embed.unsqueeze(0), embeds_grid, ref_embed.unsqueeze(0), ref_embeds_grid)
+ inputs_posi["text_ids"] = self.get_editing_text_ids(
+ inputs_shared["latents"],
+ embeds_grid[0][1].item(), embeds_grid[0][2].item(),
+ ref_embeds_grid[0][1].item(), ref_embeds_grid[0][2].item(),
+ )
+ return inputs_shared, inputs_posi, inputs_nega
+
+
+ def get_editing_text_ids(self, latents, target_embed_height, target_embed_width, ref_embed_height, ref_embed_width):
+ # prepare text ids for target and reference embeddings
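+        # Target-embedding tokens keep 0 in the first id channel, reference tokens get 1;
+        # row/column positions of both grids are rescaled so they align with the latent token grid.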
+ batch_size, height, width = latents.shape[0], target_embed_height, target_embed_width
+ embed_ids = torch.zeros(height // 2, width // 2, 3)
+ scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width
+ embed_ids[..., 1] = embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height
+ embed_ids[..., 2] = embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width
+ embed_ids = embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3)
+ embed_text_ids = embed_ids.to(device=latents.device, dtype=latents.dtype)
+
+ batch_size, height, width = latents.shape[0], ref_embed_height, ref_embed_width
+ ref_embed_ids = torch.zeros(height // 2, width // 2, 3)
+ scale_factor_height, scale_factor_width = latents.shape[-2] / height, latents.shape[-1] / width
+ ref_embed_ids[..., 0] = ref_embed_ids[..., 0] + 1.0
+ ref_embed_ids[..., 1] = ref_embed_ids[..., 1] + torch.arange(height // 2)[:, None] * scale_factor_height
+ ref_embed_ids[..., 2] = ref_embed_ids[..., 2] + torch.arange(width // 2)[None, :] * scale_factor_width
+ ref_embed_ids = ref_embed_ids[None, :].repeat(batch_size, 1, 1, 1).reshape(batch_size, height // 2 * width // 2, 3)
+ ref_embed_text_ids = ref_embed_ids.to(device=latents.device, dtype=latents.dtype)
+
+ text_ids = torch.cat([embed_text_ids, ref_embed_text_ids], dim=1)
+ return text_ids
+
+
+class FluxImageUnit_Step1x(PipelineUnit):
+ def __init__(self):
+        super().__init__(take_over=True, onload_model_names=("qwenvl", "vae_encoder"))
+
+ def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict):
+        image = inputs_shared.get("step1x_reference_image", None)
+ if image is None:
+ return inputs_shared, inputs_posi, inputs_nega
+ else:
+ pipe.load_models_to_device(self.onload_model_names)
+ prompt = inputs_posi["prompt"]
+ nega_prompt = inputs_nega["negative_prompt"]
+ captions = [prompt, nega_prompt]
+ ref_images = [image, image]
+ embs, masks = pipe.qwenvl(captions, ref_images)
+ image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
+ image = pipe.vae_encoder(image)
+ inputs_posi.update({"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image})
+ if inputs_shared.get("cfg_scale", 1) != 1:
+ inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image})
+ return inputs_shared, inputs_posi, inputs_nega
+
+
+class FluxImageUnit_TeaCache(PipelineUnit):
+ def __init__(self):
+        super().__init__(input_params=("num_inference_steps", "tea_cache_l1_thresh"))
+
+ def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh):
+ if tea_cache_l1_thresh is None:
+ return {}
+ else:
+ return {"tea_cache": TeaCache(num_inference_steps=num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh)}
+
+class FluxImageUnit_Flex(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"),
+ onload_model_names=("vae_encoder",)
+ )
+
+ def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride):
+ if pipe.dit.input_dim == 196:
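+            # Assumption: an input dim of 196 marks a Flex-style DiT, i.e. a 2x2 patchify of 49 latent channels:
+            # the 16 denoised channels plus the 33-channel condition built below (16 inpaint + 1 mask + 16 control).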
+ if flex_control_stop is None:
+ flex_control_stop = 1
+ pipe.load_models_to_device(self.onload_model_names)
+ if flex_inpaint_image is None:
+ flex_inpaint_image = torch.zeros_like(latents)
+ else:
+ flex_inpaint_image = pipe.preprocess_image(flex_inpaint_image).to(device=pipe.device, dtype=pipe.torch_dtype)
+ flex_inpaint_image = pipe.vae_encoder(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ if flex_inpaint_mask is None:
+ flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
+ else:
+ flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
+ flex_inpaint_mask = pipe.preprocess_image(flex_inpaint_mask).to(device=pipe.device, dtype=pipe.torch_dtype)
+ flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
+ flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)
+ if flex_control_image is None:
+ flex_control_image = torch.zeros_like(latents)
+ else:
+ flex_control_image = pipe.preprocess_image(flex_control_image).to(device=pipe.device, dtype=pipe.torch_dtype)
+ flex_control_image = pipe.vae_encoder(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength
+ flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
+ flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
+ flex_control_stop_timestep = pipe.scheduler.timesteps[int(flex_control_stop * (len(pipe.scheduler.timesteps) - 1))]
+ return {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
+ else:
+ return {}
+
+
+
+class FluxImageUnit_InfiniteYou(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("infinityou_id_image", "infinityou_guidance"),
+ onload_model_names=("infinityou_processor",)
+ )
+
+ def process(self, pipe: FluxImagePipeline, infinityou_id_image, infinityou_guidance):
+ pipe.load_models_to_device("infinityou_processor")
+ if infinityou_id_image is not None:
+ return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance, pipe.device)
+ else:
+ return {}
+
+
+
+class FluxImageUnit_ValueControl(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ seperate_cfg=True,
+ input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
+ input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
+ input_params=("value_controller_inputs",),
+ onload_model_names=("value_controller",)
+ )
+
+ def add_to_text_embedding(self, prompt_emb, text_ids, value_emb):
+ prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
+ extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
+ text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
+ return prompt_emb, text_ids
+
+ def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs):
+ if value_controller_inputs is None:
+ return {}
+ if not isinstance(value_controller_inputs, list):
+ value_controller_inputs = [value_controller_inputs]
+ value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device)
+ pipe.load_models_to_device(["value_controller"])
+ value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype)
+ value_emb = value_emb.unsqueeze(0)
+ prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb)
+ return {"prompt_emb": prompt_emb, "text_ids": text_ids}
+
+
+
+class InfinitYou(torch.nn.Module):
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+ super().__init__()
+ from facexlib.recognition import init_recognition_model
+ from insightface.app import FaceAnalysis
+ self.device = device
+ self.torch_dtype = torch_dtype
+ insightface_root_path = 'models/ByteDance/InfiniteYou/supports/insightface'
+ self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ self.app_640.prepare(ctx_id=0, det_size=(640, 640))
+ self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ self.app_320.prepare(ctx_id=0, det_size=(320, 320))
+ self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ self.app_160.prepare(ctx_id=0, det_size=(160, 160))
+ self.arcface_model = init_recognition_model('arcface', device=self.device).to(torch_dtype)
+
+ def _detect_face(self, id_image_cv2):
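+        # Fall back to progressively smaller detection sizes so small or tightly cropped faces are still found.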
+ face_info = self.app_640.get(id_image_cv2)
+ if len(face_info) > 0:
+ return face_info
+ face_info = self.app_320.get(id_image_cv2)
+ if len(face_info) > 0:
+ return face_info
+ face_info = self.app_160.get(id_image_cv2)
+ return face_info
+
+ def extract_arcface_bgr_embedding(self, in_image, landmark, device):
+ from insightface.utils import face_align
+ arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
+ arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
+ arc_face_image = 2 * arc_face_image - 1
+ arc_face_image = arc_face_image.contiguous().to(device=device, dtype=self.torch_dtype)
+ face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
+ return face_emb
+
+ def prepare_infinite_you(self, model, id_image, infinityou_guidance, device):
+ import cv2
+ if id_image is None:
+ return {'id_emb': None}
+ id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
+ face_info = self._detect_face(id_image_cv2)
+ if len(face_info) == 0:
+ raise ValueError('No face detected in the input ID image')
+ landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
+ id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark, device)
+ id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
+ infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=device, dtype=self.torch_dtype)
+ return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}
+
+
+
+class FluxImageUnit_LoRAEncode(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ take_over=True,
+ onload_model_names=("lora_encoder",)
+ )
+
+ def parse_lora_encoder_inputs(self, lora_encoder_inputs):
+ if not isinstance(lora_encoder_inputs, list):
+ lora_encoder_inputs = [lora_encoder_inputs]
+ lora_configs = []
+ for lora_encoder_input in lora_encoder_inputs:
+ if isinstance(lora_encoder_input, str):
+ lora_encoder_input = ModelConfig(path=lora_encoder_input)
+ lora_encoder_input.download_if_necessary()
+ lora_configs.append(lora_encoder_input)
+ return lora_configs
+
+ def load_lora(self, lora_config, dtype, device):
+ loader = FluxLoRALoader(torch_dtype=dtype, device=device)
+ lora = load_state_dict(lora_config.path, torch_dtype=dtype, device=device)
+ lora = loader.convert_state_dict(lora)
+ return lora
+
+ def lora_embedding(self, pipe, lora_encoder_inputs):
+ lora_emb = []
+ for lora_config in self.parse_lora_encoder_inputs(lora_encoder_inputs):
+ lora = self.load_lora(lora_config, pipe.torch_dtype, pipe.device)
+ lora_emb.append(pipe.lora_encoder(lora))
+ lora_emb = torch.concat(lora_emb, dim=1)
+ return lora_emb
+
+ def add_to_text_embedding(self, prompt_emb, text_ids, lora_emb):
+ prompt_emb = torch.concat([prompt_emb, lora_emb], dim=1)
+ extra_text_ids = torch.zeros((lora_emb.shape[0], lora_emb.shape[1], 3), device=lora_emb.device, dtype=lora_emb.dtype)
+ text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
+ return prompt_emb, text_ids
+
+ def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
+ if inputs_shared.get("lora_encoder_inputs", None) is None:
+ return inputs_shared, inputs_posi, inputs_nega
+
+ # Encode
+ pipe.load_models_to_device(["lora_encoder"])
+ lora_encoder_inputs = inputs_shared["lora_encoder_inputs"]
+ lora_emb = self.lora_embedding(pipe, lora_encoder_inputs)
+
+ # Scale
+ lora_encoder_scale = inputs_shared.get("lora_encoder_scale", None)
+ if lora_encoder_scale is not None:
+ lora_emb = lora_emb * lora_encoder_scale
+
+ # Add to prompt embedding
+ inputs_posi["prompt_emb"], inputs_posi["text_ids"] = self.add_to_text_embedding(
+ inputs_posi["prompt_emb"], inputs_posi["text_ids"], lora_emb)
+ return inputs_shared, inputs_posi, inputs_nega
+
+
+
+class TeaCache:
+ def __init__(self, num_inference_steps, rel_l1_thresh):
+ self.num_inference_steps = num_inference_steps
+ self.step = 0
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.rel_l1_thresh = rel_l1_thresh
+ self.previous_residual = None
+ self.previous_hidden_states = None
+
+ def check(self, dit: FluxDiT, hidden_states, conditioning):
+ inp = hidden_states.clone()
+ temb_ = conditioning.clone()
+ modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_)
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ else:
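+            # Polynomial (presumably fitted offline, following the TeaCache method) that maps the relative L1
+            # change of the modulated input to an estimated change of the block output.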
+ coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
+ rescale_func = np.poly1d(coefficients)
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+ should_calc = False
+ else:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = modulated_inp
+ self.step += 1
+ if self.step == self.num_inference_steps:
+ self.step = 0
+ if should_calc:
+ self.previous_hidden_states = hidden_states.clone()
+ return not should_calc
+
+ def store(self, hidden_states):
+ self.previous_residual = hidden_states - self.previous_hidden_states
+ self.previous_hidden_states = None
+
+ def update(self, hidden_states):
+ hidden_states = hidden_states + self.previous_residual
+ return hidden_states
+
+
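+# How TeaCache is wired into model_fn_flux_image below (illustrative sketch; the threshold value is an assumption):
+#     tea_cache = TeaCache(num_inference_steps=30, rel_l1_thresh=0.3)
+#     if tea_cache.check(dit, hidden_states, conditioning):  # True -> the transformer blocks can be skipped
+#         hidden_states = tea_cache.update(hidden_states)    # reuse the cached residual
+#     else:
+#         ...                                                 # full forward, then tea_cache.store(hidden_states)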
+def model_fn_flux_image(
+ dit: FluxDiT,
+ controlnet=None,
+ step1x_connector=None,
+ latents=None,
+ timestep=None,
+ prompt_emb=None,
+ pooled_prompt_emb=None,
+ guidance=None,
+ text_ids=None,
+ image_ids=None,
+ kontext_latents=None,
+ kontext_image_ids=None,
+ controlnet_inputs=None,
+ controlnet_conditionings=None,
+ tiled=False,
+ tile_size=128,
+ tile_stride=64,
+ entity_prompt_emb=None,
+ entity_masks=None,
+ ipadapter_kwargs_list={},
+ id_emb=None,
+ infinityou_guidance=None,
+ flex_condition=None,
+ flex_uncondition=None,
+ flex_control_stop_timestep=None,
+ step1x_llm_embedding=None,
+ step1x_mask=None,
+ step1x_reference_latents=None,
+ tea_cache: TeaCache = None,
+ progress_id=0,
+ num_inference_steps=1,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ **kwargs
+):
+ if tiled:
+ def flux_forward_fn(hl, hr, wl, wr):
+ tiled_controlnet_conditionings = [f[:, :, hl: hr, wl: wr] for f in controlnet_conditionings] if controlnet_conditionings is not None else None
+ return model_fn_flux_image(
+ dit=dit,
+ controlnet=controlnet,
+ latents=latents[:, :, hl: hr, wl: wr],
+ timestep=timestep,
+ prompt_emb=prompt_emb,
+ pooled_prompt_emb=pooled_prompt_emb,
+ guidance=guidance,
+ text_ids=text_ids,
+ image_ids=None,
+ controlnet_inputs=controlnet_inputs,
+ controlnet_conditionings=tiled_controlnet_conditionings,
+ tiled=False,
+ **kwargs
+ )
+ return FastTileWorker().tiled_forward(
+ flux_forward_fn,
+ latents,
+ tile_size=tile_size,
+ tile_stride=tile_stride,
+ tile_device=latents.device,
+ tile_dtype=latents.dtype
+ )
+
+ hidden_states = latents
+
+ # ControlNet
+ if controlnet is not None and controlnet_conditionings is not None:
+ controlnet_extra_kwargs = {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "prompt_emb": prompt_emb,
+ "pooled_prompt_emb": pooled_prompt_emb,
+ "guidance": guidance,
+ "text_ids": text_ids,
+ "image_ids": image_ids,
+ "controlnet_inputs": controlnet_inputs,
+ "tiled": tiled,
+ "tile_size": tile_size,
+ "tile_stride": tile_stride,
+ "progress_id": progress_id,
+ "num_inference_steps": num_inference_steps,
+ }
+ if id_emb is not None:
+ controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
+ controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
+ controlnet_res_stack, controlnet_single_res_stack = controlnet(
+ controlnet_conditionings, **controlnet_extra_kwargs
+ )
+
+ # Flex
+ if flex_condition is not None:
+ if timestep.tolist()[0] >= flex_control_stop_timestep:
+ hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
+ else:
+ hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
+
+ # Step1x
+ if step1x_llm_embedding is not None:
+ prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask)
+ text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device)
+
+ if image_ids is None:
+ image_ids = dit.prepare_image_ids(hidden_states)
+
+ conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb)
+ if dit.guidance_embedder is not None:
+ guidance = guidance * 1000
+ conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
+
+ height, width = hidden_states.shape[-2:]
+ hidden_states = dit.patchify(hidden_states)
+
+ # Kontext
+ if kontext_latents is not None:
+ image_ids = torch.concat([image_ids, kontext_image_ids], dim=-2)
+ hidden_states = torch.concat([hidden_states, kontext_latents], dim=1)
+
+ # Step1x
+ if step1x_reference_latents is not None:
+ step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents)
+ step1x_reference_latents = dit.patchify(step1x_reference_latents)
+ image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2)
+ hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1)
+
+ hidden_states = dit.x_embedder(hidden_states)
+
+ # EliGen
+ if entity_prompt_emb is not None and entity_masks is not None:
+ prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, latents.shape[1])
+ else:
+ prompt_emb = dit.context_embedder(prompt_emb)
+ image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+ attention_mask = None
+
+ # TeaCache
+ if tea_cache is not None:
+ tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
+ else:
+ tea_cache_update = False
+
+ if tea_cache_update:
+ hidden_states = tea_cache.update(hidden_states)
+ else:
+ # Joint Blocks
+ for block_id, block in enumerate(dit.blocks):
+ hidden_states, prompt_emb = gradient_checkpoint_forward(
+ block,
+ use_gradient_checkpointing,
+ use_gradient_checkpointing_offload,
+ hidden_states,
+ prompt_emb,
+ conditioning,
+ image_rotary_emb,
+ attention_mask,
+ ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None),
+ )
+ # ControlNet
+ if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None:
+ if kontext_latents is None:
+ hidden_states = hidden_states + controlnet_res_stack[block_id]
+ else:
+ hidden_states[:, :-kontext_latents.shape[1]] = hidden_states[:, :-kontext_latents.shape[1]] + controlnet_res_stack[block_id]
+
+ # Single Blocks
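+            # The single-stream blocks run on the concatenated [text, image] token sequence;
+            # the text tokens are stripped off again after the loop.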
+ hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
+ num_joint_blocks = len(dit.blocks)
+ for block_id, block in enumerate(dit.single_blocks):
+ hidden_states, prompt_emb = gradient_checkpoint_forward(
+ block,
+ use_gradient_checkpointing,
+ use_gradient_checkpointing_offload,
+ hidden_states,
+ prompt_emb,
+ conditioning,
+ image_rotary_emb,
+ attention_mask,
+ ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
+ )
+ # ControlNet
+ if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None:
+ if kontext_latents is None:
+ hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
+ else:
+ hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] = hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] + controlnet_single_res_stack[block_id]
+ hidden_states = hidden_states[:, prompt_emb.shape[1]:]
+
+ if tea_cache is not None:
+ tea_cache.store(hidden_states)
+
+ hidden_states = dit.final_norm_out(hidden_states, conditioning)
+ hidden_states = dit.final_proj_out(hidden_states)
+
+ # Step1x
+ if step1x_reference_latents is not None:
+ hidden_states = hidden_states[:, :hidden_states.shape[1] // 2]
+
+ # Kontext
+ if kontext_latents is not None:
+ hidden_states = hidden_states[:, :-kontext_latents.shape[1]]
+
+ hidden_states = dit.unpatchify(hidden_states, height, width)
+
+ return hidden_states
diff --git a/diffsynth/pipelines/hunyuan_image.py b/diffsynth/pipelines/hunyuan_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c6f6d5dedc6aac50b06a9f10701f7f8ab33117f
--- /dev/null
+++ b/diffsynth/pipelines/hunyuan_image.py
@@ -0,0 +1,288 @@
+from ..models.hunyuan_dit import HunyuanDiT
+from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
+from ..models.sdxl_vae_encoder import SDXLVAEEncoder
+from ..models.sdxl_vae_decoder import SDXLVAEDecoder
+from ..models import ModelManager
+from ..prompters import HunyuanDiTPrompter
+from ..schedulers import EnhancedDDIMScheduler
+from .base import BasePipeline
+import torch
+from tqdm import tqdm
+import numpy as np
+
+
+
+class ImageSizeManager:
+ def __init__(self):
+ pass
+
+
+ def _to_tuple(self, x):
+ if isinstance(x, int):
+ return x, x
+ else:
+ return x
+
+
+ def get_fill_resize_and_crop(self, src, tgt):
+ th, tw = self._to_tuple(tgt)
+ h, w = self._to_tuple(src)
+
+        tr = th / tw # aspect ratio of the base resolution
+        r = h / w # aspect ratio of the target resolution
+
+ # resize
+ if r > tr:
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+            resize_height = int(round(tw / w * h)) # scale the target resolution down to match the base resolution
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+ def get_meshgrid(self, start, *args):
+ if len(args) == 0:
+ # start is grid_size
+ num = self._to_tuple(start)
+ start = (0, 0)
+ stop = num
+ elif len(args) == 1:
+ # start is start, args[0] is stop, step is 1
+ start = self._to_tuple(start)
+ stop = self._to_tuple(args[0])
+ num = (stop[0] - start[0], stop[1] - start[1])
+ elif len(args) == 2:
+ # start is start, args[0] is stop, args[1] is num
+            start = self._to_tuple(start) # top-left corner, e.g. (12, 0)
+            stop = self._to_tuple(args[0]) # bottom-right corner, e.g. (20, 32)
+            num = self._to_tuple(args[1]) # target size, e.g. (32, 124)
+ else:
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
+
+        grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) # e.g. 32 evenly spaced values over [12, 20), or 124 values over [0, 32)
+ grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0) # [2, W, H]
+ return grid
+
+
+ def get_2d_rotary_pos_embed(self, embed_dim, start, *args, use_real=True):
+ grid = self.get_meshgrid(start, *args) # [2, H, w]
+        grid = grid.reshape([2, 1, *grid.shape[1:]]) # a sampling grid whose resolution matches the target resolution
+ pos_embed = self.get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+ return pos_embed
+
+
+ def get_2d_rotary_pos_embed_from_grid(self, embed_dim, grid, use_real=False):
+ assert embed_dim % 4 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
+ emb_w = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
+
+ if use_real:
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
+ return cos, sin
+ else:
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
+ return emb
+
+
+ def get_1d_rotary_pos_embed(self, dim: int, pos, theta: float = 10000.0, use_real=False):
+ if isinstance(pos, int):
+ pos = np.arange(pos)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
+ t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
+ freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
+ if use_real:
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
+ return freqs_cis
+
+
+ def calc_rope(self, height, width):
+ patch_size = 2
+ head_size = 88
+ th = height // 8 // patch_size
+ tw = width // 8 // patch_size
+ base_size = 512 // 8 // patch_size
+ start, stop = self.get_fill_resize_and_crop((th, tw), base_size)
+ sub_args = [start, stop, (th, tw)]
+ rope = self.get_2d_rotary_pos_embed(head_size, *sub_args)
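+        # Example (illustrative): height = width = 1024 gives th = tw = 64 latent patches,
+        # so `rope` is a (cos, sin) pair, each of shape (64 * 64, 88) = (4096, head_size).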
+ return rope
+
+
+
+class HunyuanDiTImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
+ self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
+ self.prompter = HunyuanDiTPrompter()
+ self.image_size_manager = ImageSizeManager()
+ # models
+ self.text_encoder: HunyuanDiTCLIPTextEncoder = None
+ self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
+ self.dit: HunyuanDiT = None
+ self.vae_decoder: SDXLVAEDecoder = None
+ self.vae_encoder: SDXLVAEEncoder = None
+ self.model_names = ['text_encoder', 'text_encoder_t5', 'dit', 'vae_decoder', 'vae_encoder']
+
+
+ def denoising_model(self):
+ return self.dit
+
+
+ def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
+ # Main models
+ self.text_encoder = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
+ self.text_encoder_t5 = model_manager.fetch_model("hunyuan_dit_t5_text_encoder")
+ self.dit = model_manager.fetch_model("hunyuan_dit")
+ self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
+ self.prompter.fetch_models(self.text_encoder, self.text_encoder_t5)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
+ pipe = HunyuanDiTImagePipeline(
+ device=model_manager.device if device is None else device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+ pipe.fetch_models(model_manager, prompt_refiner_classes)
+ return pipe
+
+
+ def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ return image
+
+
+ def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=1, positive=True):
+ text_emb, text_emb_mask, text_emb_t5, text_emb_mask_t5 = self.prompter.encode_prompt(
+ prompt,
+ clip_skip=clip_skip,
+ clip_skip_2=clip_skip_2,
+ positive=positive,
+ device=self.device
+ )
+ return {
+ "text_emb": text_emb,
+ "text_emb_mask": text_emb_mask,
+ "text_emb_t5": text_emb_t5,
+ "text_emb_mask_t5": text_emb_mask_t5
+ }
+
+
+ def prepare_extra_input(self, latents=None, tiled=False, tile_size=64, tile_stride=32):
+ batch_size, height, width = latents.shape[0], latents.shape[2] * 8, latents.shape[3] * 8
+ if tiled:
+ height, width = tile_size * 16, tile_size * 16
+ image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
+ freqs_cis_img = self.image_size_manager.calc_rope(height, width)
+ image_meta_size = torch.stack([image_meta_size] * batch_size)
+ return {
+ "size_emb": image_meta_size,
+ "freq_cis_img": (freqs_cis_img[0].to(dtype=self.torch_dtype, device=self.device), freqs_cis_img[1].to(dtype=self.torch_dtype, device=self.device)),
+ "tiled": tiled,
+ "tile_size": tile_size,
+ "tile_stride": tile_stride
+ }
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ local_prompts=[],
+ masks=[],
+ mask_scales=[],
+ negative_prompt="",
+ cfg_scale=7.5,
+ clip_skip=1,
+ clip_skip_2=1,
+ input_image=None,
+ reference_strengths=[0.4],
+ denoising_strength=1.0,
+ height=1024,
+ width=1024,
+ num_inference_steps=20,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ if input_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = noise.clone()
+
+ # Encode prompts
+ self.load_models_to_device(['text_encoder', 'text_encoder_t5'])
+ prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
+ if cfg_scale != 1.0:
+            prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
+ prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
+
+ # Prepare positional id
+ extra_input = self.prepare_extra_input(latents, tiled, tile_size)
+
+ # Denoise
+ self.load_models_to_device(['dit'])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
+
+ # Positive side
+ inference_callback = lambda prompt_emb_posi: self.dit(latents, timestep=timestep, **prompt_emb_posi, **extra_input)
+ noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
+
+ if cfg_scale != 1.0:
+ # Negative side
+ noise_pred_nega = self.dit(
+ latents, timestep=timestep, **prompt_emb_nega, **extra_input,
+ )
+ # Classifier-free guidance
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ self.load_models_to_device(['vae_decoder'])
+ image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+ # Offload all models
+ self.load_models_to_device([])
+ return image
diff --git a/diffsynth/pipelines/hunyuan_video.py b/diffsynth/pipelines/hunyuan_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8a0411e155f293e86a2b64073fa8b25af3d83d5
--- /dev/null
+++ b/diffsynth/pipelines/hunyuan_video.py
@@ -0,0 +1,395 @@
+from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder
+from ..models.hunyuan_video_dit import HunyuanVideoDiT
+from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
+from ..schedulers.flow_match import FlowMatchScheduler
+from .base import BasePipeline
+from ..prompters import HunyuanVideoPrompter
+import torch
+import torchvision.transforms as transforms
+from einops import rearrange
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class HunyuanVideoPipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
+ self.prompter = HunyuanVideoPrompter()
+ self.text_encoder_1: SD3TextEncoder1 = None
+ self.text_encoder_2: HunyuanVideoLLMEncoder = None
+ self.dit: HunyuanVideoDiT = None
+ self.vae_decoder: HunyuanVideoVAEDecoder = None
+ self.vae_encoder: HunyuanVideoVAEEncoder = None
+ self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
+ self.vram_management = False
+
+
+ def enable_vram_management(self):
+ self.vram_management = True
+ self.enable_cpu_offload()
+ self.text_encoder_2.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
+ self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
+
+
+ def fetch_models(self, model_manager: ModelManager):
+ self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
+ self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
+ self.dit = model_manager.fetch_model("hunyuan_video_dit")
+ self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("hunyuan_video_vae_encoder")
+ self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, enable_vram_management=True):
+ if device is None: device = model_manager.device
+ if torch_dtype is None: torch_dtype = model_manager.torch_dtype
+ pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
+ pipe.fetch_models(model_manager)
+ if enable_vram_management:
+ pipe.enable_vram_management()
+ return pipe
+
+ def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
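+        # Enumerates all patch-aligned crop sizes whose total patch count stays within
+        # (base_size / patch_size) ** 2 and whose aspect ratio does not exceed max_ratio.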
+ num_patches = round((base_size / patch_size)**2)
+ assert max_ratio >= 1.0
+ crop_size_list = []
+ wp, hp = num_patches, 1
+ while wp > 0:
+ if max(wp, hp) / min(wp, hp) <= max_ratio:
+ crop_size_list.append((wp * patch_size, hp * patch_size))
+ if (hp + 1) * wp <= num_patches:
+ hp += 1
+ else:
+ wp -= 1
+ return crop_size_list
+
+
+ def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
+ aspect_ratio = float(height) / float(width)
+ closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
+ closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
+ return buckets[closest_ratio_id], float(closest_ratio)
+
+
+ def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"):
+ if i2v_resolution == "720p":
+ bucket_hw_base_size = 960
+ elif i2v_resolution == "540p":
+ bucket_hw_base_size = 720
+ elif i2v_resolution == "360p":
+ bucket_hw_base_size = 480
+ else:
+ raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
+ origin_size = semantic_images[0].size
+
+ crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
+ aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
+ closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
+ ref_image_transform = transforms.Compose([
+ transforms.Resize(closest_size),
+ transforms.CenterCrop(closest_size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5])
+ ])
+
+ semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
+ semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)
+ target_height, target_width = closest_size
+ return semantic_image_pixel_values, target_height, target_width
+
+
+ def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None):
+ prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
+ prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length, images=input_images
+ )
+ return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
+
+
+ def prepare_extra_input(self, latents=None, guidance=1.0):
+ freqs_cos, freqs_sin = self.dit.prepare_freqs(latents)
+ guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
+ return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance}
+
+
+ def tensor2video(self, frames):
+ frames = rearrange(frames, "C T H W -> T H W C")
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+
+ def encode_video(self, frames, tile_size=(17, 30, 30), tile_stride=(12, 20, 20)):
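+        # Tile sizes/strides are given in latent units; convert them to frame/pixel units
+        # (temporal factor 4, plus one initial frame; spatial factor 8) before tiled VAE encoding.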
+ tile_size = ((tile_size[0] - 1) * 4 + 1, tile_size[1] * 8, tile_size[2] * 8)
+ tile_stride = (tile_stride[0] * 4, tile_stride[1] * 8, tile_stride[2] * 8)
+ latents = self.vae_encoder.encode_video(frames, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ input_video=None,
+ input_images=None,
+ i2v_resolution="720p",
+ i2v_stability=True,
+ denoising_strength=1.0,
+ seed=None,
+ rand_device=None,
+ height=720,
+ width=1280,
+ num_frames=129,
+ embedded_guidance=6.0,
+ cfg_scale=1.0,
+ num_inference_steps=30,
+ tea_cache_l1_thresh=None,
+ tile_size=(17, 30, 30),
+ tile_stride=(12, 20, 20),
+ step_processor=None,
+ progress_bar_cmd=lambda x: x,
+ progress_bar_st=None,
+ ):
+ # Tiler parameters
+ tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # encoder input images
+ if input_images is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution)
+ with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True):
+ image_latents = self.vae_encoder(image_pixel_values)
+
+ # Initialize noise
+ rand_device = self.device if rand_device is None else rand_device
+ noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
+ if input_video is not None:
+ self.load_models_to_device(['vae_encoder'])
+ input_video = self.preprocess_images(input_video)
+ input_video = torch.stack(input_video, dim=2)
+ latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ elif input_images is not None and i2v_stability:
+ noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device)
+ t = torch.tensor([0.999]).to(device=self.device)
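+            # i2v_stability: initialize from an almost-pure-noise blend with the (temporally repeated)
+            # image latent, which tends to stabilize image-to-video results.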
+ latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t)
+ latents = latents.to(dtype=image_latents.dtype)
+ else:
+ latents = noise
+
+ # Encode prompts
+ # current mllm does not support vram_management
+ self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"])
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images)
+ if cfg_scale != 1.0:
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
+
+ # Extra input
+ extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
+
+ # TeaCache
+ tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
+
+ # Denoise
+ self.load_models_to_device([] if self.vram_management else ["dit"])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+ print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
+
+ forward_func = lets_dance_hunyuan_video
+ if input_images is not None:
+ latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2)
+ forward_func = lets_dance_hunyuan_video_i2v
+
+ # Inference
+ with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
+ noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
+ if cfg_scale != 1.0:
+ noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # (Experimental feature, may be removed in the future)
+ if step_processor is not None:
+ self.load_models_to_device(['vae_decoder'])
+ rendered_frames = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents, to_final=True)
+ rendered_frames = self.vae_decoder.decode_video(rendered_frames, **tiler_kwargs)
+ rendered_frames = self.tensor2video(rendered_frames[0])
+ rendered_frames = step_processor(rendered_frames, original_frames=input_video)
+ self.load_models_to_device(['vae_encoder'])
+ rendered_frames = self.preprocess_images(rendered_frames)
+ rendered_frames = torch.stack(rendered_frames, dim=2)
+ target_latents = self.encode_video(rendered_frames).to(dtype=self.torch_dtype, device=self.device)
+ noise_pred = self.scheduler.return_to_timestep(self.scheduler.timesteps[progress_id], latents, target_latents)
+ self.load_models_to_device([] if self.vram_management else ["dit"])
+
+ # Scheduler
+ if input_images is not None:
+ latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :])
+ latents = torch.concat([image_latents, latents], dim=2)
+ else:
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ # Decode
+ self.load_models_to_device(['vae_decoder'])
+ frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
+ self.load_models_to_device([])
+ frames = self.tensor2video(frames[0])
+
+ return frames
+
+
+
+class TeaCache:
+ def __init__(self, num_inference_steps, rel_l1_thresh):
+ self.num_inference_steps = num_inference_steps
+ self.step = 0
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.rel_l1_thresh = rel_l1_thresh
+ self.previous_residual = None
+ self.previous_hidden_states = None
+
+ def check(self, dit: HunyuanVideoDiT, img, vec):
+ img_ = img.clone()
+ vec_ = vec.clone()
+ img_mod1_shift, img_mod1_scale, _, _, _, _ = dit.double_blocks[0].component_a.mod(vec_).chunk(6, dim=-1)
+ normed_inp = dit.double_blocks[0].component_a.norm1(img_)
+ modulated_inp = normed_inp * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ else:
+ coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
+ rescale_func = np.poly1d(coefficients)
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+ should_calc = False
+ else:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = modulated_inp
+ self.step += 1
+ if self.step == self.num_inference_steps:
+ self.step = 0
+ if should_calc:
+ self.previous_hidden_states = img.clone()
+ return not should_calc
+
+ def store(self, hidden_states):
+ self.previous_residual = hidden_states - self.previous_hidden_states
+ self.previous_hidden_states = None
+
+ def update(self, hidden_states):
+ hidden_states = hidden_states + self.previous_residual
+ return hidden_states
+
+
+
+def lets_dance_hunyuan_video(
+ dit: HunyuanVideoDiT,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ prompt_emb: torch.Tensor = None,
+ text_mask: torch.Tensor = None,
+ pooled_prompt_emb: torch.Tensor = None,
+ freqs_cos: torch.Tensor = None,
+ freqs_sin: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ tea_cache: TeaCache = None,
+ **kwargs
+):
+ B, C, T, H, W = x.shape
+
+ vec = dit.time_in(t, dtype=torch.float32) + dit.vector_in(pooled_prompt_emb) + dit.guidance_in(guidance * 1000, dtype=torch.float32)
+ img = dit.img_in(x)
+ txt = dit.txt_in(prompt_emb, t, text_mask)
+
+ # TeaCache
+ if tea_cache is not None:
+ tea_cache_update = tea_cache.check(dit, img, vec)
+ else:
+ tea_cache_update = False
+
+ if tea_cache_update:
+        print("TeaCache: skipping the transformer forward pass for this step.")
+ img = tea_cache.update(img)
+ else:
+ split_token = int(text_mask.sum(dim=1))
+ txt_len = int(txt.shape[1])
+ for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
+ img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token)
+
+ x = torch.concat([img, txt], dim=1)
+ for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
+ x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token)
+ img = x[:, :-txt_len]
+
+ if tea_cache is not None:
+ tea_cache.store(img)
+ img = dit.final_layer(img, vec)
+ img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
+ return img
+
+
+def lets_dance_hunyuan_video_i2v(
+ dit: HunyuanVideoDiT,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ prompt_emb: torch.Tensor = None,
+ text_mask: torch.Tensor = None,
+ pooled_prompt_emb: torch.Tensor = None,
+ freqs_cos: torch.Tensor = None,
+ freqs_sin: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ tea_cache: TeaCache = None,
+ **kwargs
+):
+ B, C, T, H, W = x.shape
+    # Uncomment the line below to match the official implementation
+ # guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16)
+ vec = dit.time_in(t, dtype=torch.bfloat16)
+ vec_2 = dit.vector_in(pooled_prompt_emb)
+ vec = vec + vec_2
+ vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16)
+
+ token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16)
+ tr_token = (H // 2) * (W // 2)
+ token_replace_vec = token_replace_vec + vec_2
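+    # Token replacement for I2V: the first (H // 2) * (W // 2) image tokens (the conditioning frame)
+    # are presumably modulated with this timestep-0 vector, i.e. treated as already clean.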
+
+ img = dit.img_in(x)
+ txt = dit.txt_in(prompt_emb, t, text_mask)
+
+ # TeaCache
+ if tea_cache is not None:
+ tea_cache_update = tea_cache.check(dit, img, vec)
+ else:
+ tea_cache_update = False
+
+ if tea_cache_update:
+        print("TeaCache: skipping the transformer forward pass for this step.")
+ img = tea_cache.update(img)
+ else:
+ split_token = int(text_mask.sum(dim=1))
+ txt_len = int(txt.shape[1])
+ for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
+ img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token)
+
+ x = torch.concat([img, txt], dim=1)
+ for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
+ x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token)
+ img = x[:, :-txt_len]
+
+ if tea_cache is not None:
+ tea_cache.store(img)
+ img = dit.final_layer(img, vec)
+ img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
+ return img
diff --git a/diffsynth/pipelines/omnigen_image.py b/diffsynth/pipelines/omnigen_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddb2ae656639550084b7143fe690186602c0387d
--- /dev/null
+++ b/diffsynth/pipelines/omnigen_image.py
@@ -0,0 +1,289 @@
+from ..models.omnigen import OmniGenTransformer
+from ..models.sdxl_vae_encoder import SDXLVAEEncoder
+from ..models.sdxl_vae_decoder import SDXLVAEDecoder
+from ..models.model_manager import ModelManager
+from ..prompters.omnigen_prompter import OmniGenPrompter
+from ..schedulers import FlowMatchScheduler
+from .base import BasePipeline
+from typing import Optional, Dict, Any, Tuple, List
+from transformers.cache_utils import DynamicCache
+import torch, os
+from tqdm import tqdm
+
+
+
+class OmniGenCache(DynamicCache):
+    def __init__(self, num_tokens_for_img: int, offload_kv_cache: bool = False) -> None:
+ if not torch.cuda.is_available():
+            print("No available GPU. offload_kv_cache will be set to False, which leads to high memory usage and longer runtime when multiple images are given as input.")
+ offload_kv_cache = False
+ raise RuntimeError("OffloadedCache can only be used with a GPU")
+ super().__init__()
+ self.original_device = []
+ self.prefetch_stream = torch.cuda.Stream()
+ self.num_tokens_for_img = num_tokens_for_img
+ self.offload_kv_cache = offload_kv_cache
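+        # With offload_kv_cache enabled, cached key/value tensors live on the CPU; __getitem__
+        # asynchronously prefetches the next layer to the GPU while evicting the previous one,
+        # so only a couple of layers occupy GPU memory at any time.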
+
+ def prefetch_layer(self, layer_idx: int):
+ "Starts prefetching the next layer cache"
+ if layer_idx < len(self):
+ with torch.cuda.stream(self.prefetch_stream):
+ # Prefetch next layer tensors to GPU
+ device = self.original_device[layer_idx]
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
+
+
+ def evict_previous_layer(self, layer_idx: int):
+ "Moves the previous layer cache to the CPU"
+ if len(self) > 2:
+ # We do it on the default stream so it occurs after all earlier computations on these tensors are done
+ if layer_idx == 0:
+ prev_layer_idx = -1
+ else:
+ prev_layer_idx = (layer_idx - 1) % len(self)
+ self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
+ self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
+
+
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+ "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
+ if layer_idx < len(self):
+ if self.offload_kv_cache:
+ # Evict the previous layer if necessary
+ torch.cuda.current_stream().synchronize()
+ self.evict_previous_layer(layer_idx)
+ # Load current layer cache to its original device if not already there
+ original_device = self.original_device[layer_idx]
+ # self.prefetch_stream.synchronize(original_device)
+ torch.cuda.synchronize(self.prefetch_stream)
+ key_tensor = self.key_cache[layer_idx]
+ value_tensor = self.value_cache[layer_idx]
+
+ # Prefetch the next layer
+ self.prefetch_layer((layer_idx + 1) % len(self))
+ else:
+ key_tensor = self.key_cache[layer_idx]
+ value_tensor = self.value_cache[layer_idx]
+ return (key_tensor, value_tensor)
+ else:
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+ Parameters:
+ key_states (`torch.Tensor`):
+ The new key states to cache.
+ value_states (`torch.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ # Update the cache
+ if len(self.key_cache) < layer_idx:
+ raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
+ elif len(self.key_cache) == layer_idx:
+ # only cache the states for condition tokens
+ key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
+ value_states = value_states[..., :-(self.num_tokens_for_img+1), :]
+
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += key_states.shape[-2]
+
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ self.original_device.append(key_states.device)
+ if self.offload_kv_cache:
+ self.evict_previous_layer(layer_idx)
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ else:
+ # only cache the states for condition tokens
+ key_tensor, value_tensor = self[layer_idx]
+ k = torch.cat([key_tensor, key_states], dim=-2)
+ v = torch.cat([value_tensor, value_states], dim=-2)
+ return k, v
+
+
+
+class OmnigenImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = FlowMatchScheduler(num_train_timesteps=1, shift=1, inverse_timesteps=True, sigma_min=0, sigma_max=1)
+ # models
+ self.vae_decoder: SDXLVAEDecoder = None
+ self.vae_encoder: SDXLVAEEncoder = None
+ self.transformer: OmniGenTransformer = None
+ self.prompter: OmniGenPrompter = None
+ self.model_names = ['transformer', 'vae_decoder', 'vae_encoder']
+
+
+ def denoising_model(self):
+ return self.transformer
+
+
+ def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
+ # Main models
+ self.transformer, model_path = model_manager.fetch_model("omnigen_transformer", require_model_path=True)
+ self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
+ self.prompter = OmniGenPrompter.from_pretrained(os.path.dirname(model_path))
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
+ pipe = OmnigenImagePipeline(
+ device=model_manager.device if device is None else device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+        pipe.fetch_models(model_manager, prompt_refiner_classes=prompt_refiner_classes)
+ return pipe
+
+
+ def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def encode_images(self, images, tiled=False, tile_size=64, tile_stride=32):
+ latents = [self.encode_image(image.to(device=self.device), tiled, tile_size, tile_stride).to(self.torch_dtype) for image in images]
+ return latents
+
+
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ return image
+
+
+ def encode_prompt(self, prompt, clip_skip=1, positive=True):
+ prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
+ return {"encoder_hidden_states": prompt_emb}
+
+
+ def prepare_extra_input(self, latents=None):
+ return {}
+
+
+ def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
+ if isinstance(position_ids, list):
+ for i in range(len(position_ids)):
+ position_ids[i] = position_ids[i][:, -(num_tokens_for_img+1):]
+ else:
+ position_ids = position_ids[:, -(num_tokens_for_img+1):]
+ return position_ids
+
+
+ def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
+ if isinstance(attention_mask, list):
+ return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
+ return attention_mask[..., -(num_tokens_for_img+1):, :]
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ reference_images=[],
+ cfg_scale=2.0,
+ image_cfg_scale=2.0,
+ use_kv_cache=True,
+ offload_kv_cache=True,
+ input_image=None,
+ denoising_strength=1.0,
+ height=1024,
+ width=1024,
+ num_inference_steps=20,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ if input_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
+ latents = self.encode_image(image, **tiler_kwargs)
+ noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ latents = latents.repeat(3, 1, 1, 1)
+
+ # Encode prompts
+ input_data = self.prompter(prompt, reference_images, height=height, width=width, use_img_cfg=True, separate_cfg_input=True, use_input_image_size_as_output=False)
+
+ # Encode images
+ reference_latents = [self.encode_images(images, **tiler_kwargs) for images in input_data['input_pixel_values']]
+
+ # Pack all parameters
+ model_kwargs = dict(input_ids=[input_ids.to(self.device) for input_ids in input_data['input_ids']],
+ input_img_latents=reference_latents,
+ input_image_sizes=input_data['input_image_sizes'],
+ attention_mask=[attention_mask.to(self.device) for attention_mask in input_data["attention_mask"]],
+ position_ids=[position_ids.to(self.device) for position_ids in input_data["position_ids"]],
+ cfg_scale=cfg_scale,
+ img_cfg_scale=image_cfg_scale,
+ use_img_cfg=True,
+ use_kv_cache=use_kv_cache,
+ offload_model=False,
+ )
+
+ # Denoise
+ self.load_models_to_device(['transformer'])
+ cache = [OmniGenCache(latents.size(-1)*latents.size(-2) // 4, offload_kv_cache) for _ in range(len(model_kwargs['input_ids']))] if use_kv_cache else None
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).repeat(latents.shape[0]).to(self.device)
+
+ # Forward
+ noise_pred, cache = self.transformer.forward_with_separate_cfg(latents, timestep, past_key_values=cache, **model_kwargs)
+
+ # Scheduler
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ # Update KV cache
+ if progress_id == 0 and use_kv_cache:
+ num_tokens_for_img = latents.size(-1)*latents.size(-2) // 4
+ if isinstance(cache, list):
+ model_kwargs['input_ids'] = [None] * len(cache)
+ else:
+ model_kwargs['input_ids'] = None
+ model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
+ model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ del cache
+ self.load_models_to_device(['vae_decoder'])
+ image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+ # offload all models
+ self.load_models_to_device([])
+ return image
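+
+# A minimal, hypothetical usage sketch for this pipeline (the model paths are
+# placeholders; the OmniGen transformer and an SDXL VAE must be loaded first):
+#
+#     model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
+#     model_manager.load_models(["models/OmniGen-v1/model.safetensors", "models/OmniGen-v1/vae"])
+#     pipe = OmnigenImagePipeline.from_model_manager(model_manager)
+#     image = pipe(prompt="a photo of a cat", height=1024, width=1024, num_inference_steps=30, seed=0)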
diff --git a/diffsynth/pipelines/pipeline_runner.py b/diffsynth/pipelines/pipeline_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b842f9bd7b25edca1c9951e67ebe5c364deca81
--- /dev/null
+++ b/diffsynth/pipelines/pipeline_runner.py
@@ -0,0 +1,105 @@
+import os, torch, json
+from .sd_video import ModelManager, SDVideoPipeline, ControlNetConfigUnit
+from ..processors.sequencial_processor import SequencialProcessor
+from ..data import VideoData, save_frames, save_video
+
+
+
+class SDVideoPipelineRunner:
+ def __init__(self, in_streamlit=False):
+ self.in_streamlit = in_streamlit
+
+
+ def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
+ # Load models
+ model_manager = ModelManager(torch_dtype=torch.float16, device=device)
+ model_manager.load_models(model_list)
+ pipe = SDVideoPipeline.from_model_manager(
+ model_manager,
+ [
+ ControlNetConfigUnit(
+ processor_id=unit["processor_id"],
+ model_path=unit["model_path"],
+ scale=unit["scale"]
+ ) for unit in controlnet_units
+ ]
+ )
+ textual_inversion_paths = []
+ for file_name in os.listdir(textual_inversion_folder):
+            if file_name.endswith((".pt", ".bin", ".pth", ".safetensors")):
+ textual_inversion_paths.append(os.path.join(textual_inversion_folder, file_name))
+ pipe.prompter.load_textual_inversions(textual_inversion_paths)
+ return model_manager, pipe
+
+
+ def load_smoother(self, model_manager, smoother_configs):
+ smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
+ return smoother
+
+
+ def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
+ torch.manual_seed(seed)
+ if self.in_streamlit:
+ import streamlit as st
+ progress_bar_st = st.progress(0.0)
+ output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
+ progress_bar_st.progress(1.0)
+ else:
+ output_video = pipe(**pipeline_inputs, smoother=smoother)
+ model_manager.to("cpu")
+ return output_video
+
+
+ def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
+ video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
+ if start_frame_id is None:
+ start_frame_id = 0
+ if end_frame_id is None:
+ end_frame_id = len(video)
+ frames = [video[i] for i in range(start_frame_id, end_frame_id)]
+ return frames
+
+
+ def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
+ pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
+ pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
+ pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
+ if len(data["controlnet_frames"]) > 0:
+ pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
+ return pipeline_inputs
+
+
+ def save_output(self, video, output_folder, fps, config):
+ os.makedirs(output_folder, exist_ok=True)
+ save_frames(video, os.path.join(output_folder, "frames"))
+ save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
+ config["pipeline"]["pipeline_inputs"]["input_frames"] = []
+ config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
+ with open(os.path.join(output_folder, "config.json"), 'w') as file:
+ json.dump(config, file, indent=4)
+
+
+ def run(self, config):
+ if self.in_streamlit:
+ import streamlit as st
+ if self.in_streamlit: st.markdown("Loading videos ...")
+ config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
+ if self.in_streamlit: st.markdown("Loading videos ... done!")
+ if self.in_streamlit: st.markdown("Loading models ...")
+ model_manager, pipe = self.load_pipeline(**config["models"])
+ if self.in_streamlit: st.markdown("Loading models ... done!")
+ if "smoother_configs" in config:
+ if self.in_streamlit: st.markdown("Loading smoother ...")
+ smoother = self.load_smoother(model_manager, config["smoother_configs"])
+ if self.in_streamlit: st.markdown("Loading smoother ... done!")
+ else:
+ smoother = None
+ if self.in_streamlit: st.markdown("Synthesizing videos ...")
+ output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
+ if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
+ if self.in_streamlit: st.markdown("Saving videos ...")
+ self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
+ if self.in_streamlit: st.markdown("Saving videos ... done!")
+ if self.in_streamlit: st.markdown("Finished!")
+        video_file = open(os.path.join(config["data"]["output_folder"], "video.mp4"), 'rb')
+ if self.in_streamlit: st.video(video_file.read())
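+
+# The runner consumes a nested config dict whose keys mirror the method signatures above.
+# A hypothetical sketch of its shape (all values are placeholders):
+#
+#     config = {
+#         "models": {"model_list": [...], "textual_inversion_folder": "models/textual_inversion",
+#                    "device": "cuda", "lora_alphas": [], "controlnet_units": []},
+#         "data": {"input_frames": {...}, "controlnet_frames": [],
+#                  "output_folder": "output", "fps": 30},
+#         "pipeline": {"seed": 0, "pipeline_inputs": {...}},
+#     }
+#     SDVideoPipelineRunner(in_streamlit=False).run(config)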
diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..e949b56e02492db9ca8b514a18c106c108c9efd0
--- /dev/null
+++ b/diffsynth/pipelines/qwen_image.py
@@ -0,0 +1,861 @@
+import torch
+from PIL import Image
+from typing import Union
+from tqdm import tqdm
+from einops import rearrange
+import numpy as np
+
+from ..models import ModelManager, load_state_dict
+from ..models.qwen_image_dit import QwenImageDiT
+from ..models.qwen_image_text_encoder import QwenImageTextEncoder
+from ..models.qwen_image_vae import QwenImageVAE
+from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
+from ..schedulers import FlowMatchScheduler
+from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
+from ..lora import GeneralLoRALoader
+from .flux_image_new import ControlNetInput
+
+from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
+
+
+class QwenImageBlockwiseMultiControlNet(torch.nn.Module):
+ def __init__(self, models: list[QwenImageBlockWiseControlNet]):
+ super().__init__()
+ if not isinstance(models, list):
+ models = [models]
+ self.models = torch.nn.ModuleList(models)
+
+ def preprocess(self, controlnet_inputs: list[ControlNetInput], conditionings: list[torch.Tensor], **kwargs):
+ processed_conditionings = []
+ for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
+ conditioning = rearrange(conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
+ model_output = self.models[controlnet_input.controlnet_id].process_controlnet_conditioning(conditioning)
+ processed_conditionings.append(model_output)
+ return processed_conditionings
+
+ def blockwise_forward(self, image, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, block_id, **kwargs):
+ res = 0
+ for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
+ progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
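+            # `progress` counts down from 1 at the first denoising step to 0 at the last;
+            # this ControlNet is applied only while end <= progress <= start.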
+ if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):
+ continue
+ model_output = self.models[controlnet_input.controlnet_id].blockwise_forward(image, conditioning, block_id)
+ res = res + model_output * controlnet_input.scale
+ return res
+
+
+class QwenImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+ super().__init__(
+ device=device, torch_dtype=torch_dtype,
+ height_division_factor=16, width_division_factor=16,
+ )
+ from transformers import Qwen2Tokenizer, Qwen2VLProcessor
+
+ self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
+ self.text_encoder: QwenImageTextEncoder = None
+ self.dit: QwenImageDiT = None
+ self.vae: QwenImageVAE = None
+ self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None
+ self.tokenizer: Qwen2Tokenizer = None
+ self.processor: Qwen2VLProcessor = None
+ self.unit_runner = PipelineUnitRunner()
+ self.in_iteration_models = ("dit", "blockwise_controlnet")
+ self.units = [
+ QwenImageUnit_ShapeChecker(),
+ QwenImageUnit_NoiseInitializer(),
+ QwenImageUnit_InputImageEmbedder(),
+ QwenImageUnit_Inpaint(),
+ QwenImageUnit_EditImageEmbedder(),
+ QwenImageUnit_ContextImageEmbedder(),
+ QwenImageUnit_PromptEmbedder(),
+ QwenImageUnit_EntityControl(),
+ QwenImageUnit_BlockwiseControlNet(),
+ ]
+ self.model_fn = model_fn_qwen_image
+
+
+ def load_lora(
+ self,
+ module: torch.nn.Module,
+ lora_config: Union[ModelConfig, str] = None,
+ alpha=1,
+ hotload=False,
+ state_dict=None,
+ ):
+ if state_dict is None:
+ if isinstance(lora_config, str):
+ lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
+ else:
+ lora_config.download_if_necessary()
+ lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
+ else:
+ lora = state_dict
+ if hotload:
+            for name, submodule in module.named_modules():
+                if isinstance(submodule, AutoWrappedLinear):
+                    lora_a_name = f'{name}.lora_A.default.weight'
+                    lora_b_name = f'{name}.lora_B.default.weight'
+                    if lora_a_name in lora and lora_b_name in lora:
+                        submodule.lora_A_weights.append(lora[lora_a_name] * alpha)
+                        submodule.lora_B_weights.append(lora[lora_b_name])
+ else:
+ loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
+ loader.load(module, lora, alpha=alpha)
+
+
+ def clear_lora(self):
+ for name, module in self.named_modules():
+ if isinstance(module, AutoWrappedLinear):
+ if hasattr(module, "lora_A_weights"):
+ module.lora_A_weights.clear()
+ if hasattr(module, "lora_B_weights"):
+ module.lora_B_weights.clear()
+
+
+ def enable_lora_magic(self):
+ if self.dit is not None:
+ if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
+ dtype = next(iter(self.dit.parameters())).dtype
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device=self.device,
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=None,
+ )
+
+
+ def training_loss(self, **inputs):
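+        # Flow-matching training step: sample a random timestep, noise the clean latents,
+        # and regress the scheduler's training target with a timestep-dependent weight.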
+ timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
+ timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
+
+ noise = torch.randn_like(inputs["input_latents"])
+ inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
+ training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
+
+ noise_pred = self.model_fn(**inputs, timestep=timestep)
+
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
+ loss = loss * self.scheduler.training_weight(timestep)
+ return loss
+
+
+ def direct_distill_loss(self, **inputs):
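+        # Runs the full multi-step inference loop and regresses the final latents against
+        # the reference `input_latents` (a direct distillation objective).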
+ self.scheduler.set_timesteps(inputs["num_inference_steps"])
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
+ for progress_id, timestep in enumerate(self.scheduler.timesteps):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+ noise_pred = self.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
+ inputs["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
+ loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
+ return loss
+
+
+ def _enable_fp8_lora_training(self, dtype):
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
+ from ..models.qwen_image_dit import RMSNorm
+ from ..models.qwen_image_vae import QwenImageRMS_norm
+ module_map = {
+ RMSNorm: AutoWrappedModule,
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.Embedding: AutoWrappedModule,
+ Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
+ Qwen2RMSNorm: AutoWrappedModule,
+ Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
+ Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
+ QwenImageRMS_norm: AutoWrappedModule,
+ }
+ model_config = dict(
+ offload_dtype=dtype,
+ offload_device="cuda",
+ onload_dtype=dtype,
+ onload_device="cuda",
+ computation_dtype=self.torch_dtype,
+ computation_device="cuda",
+ )
+ if self.text_encoder is not None:
+ enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
+ if self.dit is not None:
+ enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
+ if self.vae is not None:
+ enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
+
+
+ def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, auto_offload=True, enable_dit_fp8_computation=False):
+ self.vram_management_enabled = True
+ if vram_limit is None and auto_offload:
+ vram_limit = self.get_vram()
+ if vram_limit is not None:
+ vram_limit = vram_limit - vram_buffer
+
+ if self.text_encoder is not None:
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
+ dtype = next(iter(self.text_encoder.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Embedding: AutoWrappedModule,
+ Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
+ Qwen2RMSNorm: AutoWrappedModule,
+ Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
+ Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ if self.dit is not None:
+ from ..models.qwen_image_dit import RMSNorm
+ dtype = next(iter(self.dit.parameters())).dtype
+ device = "cpu" if vram_limit is not None else self.device
+ if not enable_dit_fp8_computation:
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ RMSNorm: AutoWrappedModule,
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ else:
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=device,
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ if self.vae is not None:
+ from ..models.qwen_image_vae import QwenImageRMS_norm
+ dtype = next(iter(self.vae.parameters())).dtype
+ enable_vram_management(
+ self.vae,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.Conv2d: AutoWrappedModule,
+ QwenImageRMS_norm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ if self.blockwise_controlnet is not None:
+ enable_vram_management(
+ self.blockwise_controlnet,
+ module_map = {
+ RMSNorm: AutoWrappedModule,
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+
+
+ @staticmethod
+ def from_pretrained(
+ torch_dtype: torch.dtype = torch.bfloat16,
+ device: Union[str, torch.device] = "cuda",
+ model_configs: list[ModelConfig] = [],
+ tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
+ processor_config: ModelConfig = None,
+ ):
+ # Download and load models
+ model_manager = ModelManager()
+ for model_config in model_configs:
+ model_config.download_if_necessary()
+ model_manager.load_model(
+ model_config.path,
+ device=model_config.offload_device or device,
+ torch_dtype=model_config.offload_dtype or torch_dtype
+ )
+
+ # Initialize pipeline
+ pipe = QwenImagePipeline(device=device, torch_dtype=torch_dtype)
+ pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
+ pipe.dit = model_manager.fetch_model("qwen_image_dit")
+ pipe.vae = model_manager.fetch_model("qwen_image_vae")
+ pipe.blockwise_controlnet = QwenImageBlockwiseMultiControlNet(model_manager.fetch_model("qwen_image_blockwise_controlnet", index="all"))
+ if tokenizer_config is not None and pipe.text_encoder is not None:
+ tokenizer_config.download_if_necessary()
+ from transformers import Qwen2Tokenizer
+ pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path)
+ if processor_config is not None:
+ processor_config.download_if_necessary()
+ from transformers import Qwen2VLProcessor
+ pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path)
+ return pipe
+
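+    # Hypothetical usage sketch for `from_pretrained` (the file patterns are illustrative
+    # placeholders and are not verified against the published checkpoints):
+    #
+    #     pipe = QwenImagePipeline.from_pretrained(
+    #         torch_dtype=torch.bfloat16, device="cuda",
+    #         model_configs=[
+    #             ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/"),
+    #             ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/"),
+    #             ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/"),
+    #         ],
+    #     )
+    #     image = pipe(prompt="a cup of coffee on a wooden table", seed=0)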
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ # Prompt
+ prompt: str,
+ negative_prompt: str = "",
+ cfg_scale: float = 4.0,
+ # Image
+ input_image: Image.Image = None,
+ denoising_strength: float = 1.0,
+ # Inpaint
+ inpaint_mask: Image.Image = None,
+ inpaint_blur_size: int = None,
+ inpaint_blur_sigma: float = None,
+ # Shape
+ height: int = 1328,
+ width: int = 1328,
+ # Randomness
+ seed: int = None,
+ rand_device: str = "cpu",
+ # Steps
+ num_inference_steps: int = 30,
+ exponential_shift_mu: float = None,
+ # Blockwise ControlNet
+ blockwise_controlnet_inputs: list[ControlNetInput] = None,
+ # EliGen
+ eligen_entity_prompts: list[str] = None,
+ eligen_entity_masks: list[Image.Image] = None,
+ eligen_enable_on_negative: bool = False,
+ # Qwen-Image-Edit
+ edit_image: Image.Image = None,
+ edit_image_auto_resize: bool = True,
+ edit_rope_interpolation: bool = False,
+ # In-context control
+ context_image: Image.Image = None,
+ # FP8
+ enable_fp8_attention: bool = False,
+ # Tile
+ tiled: bool = False,
+ tile_size: int = 128,
+ tile_stride: int = 64,
+ # Progress bar
+ progress_bar_cmd = tqdm,
+ ):
+ # Scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu)
+
+ # Parameters
+ inputs_posi = {
+ "prompt": prompt,
+ }
+ inputs_nega = {
+ "negative_prompt": negative_prompt,
+ }
+ inputs_shared = {
+ "cfg_scale": cfg_scale,
+ "input_image": input_image, "denoising_strength": denoising_strength,
+ "inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma,
+ "height": height, "width": width,
+ "seed": seed, "rand_device": rand_device,
+ "enable_fp8_attention": enable_fp8_attention,
+ "num_inference_steps": num_inference_steps,
+ "blockwise_controlnet_inputs": blockwise_controlnet_inputs,
+ "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
+ "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
+ "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
+ "context_image": context_image,
+ }
+ for unit in self.units:
+ inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+
+ # Denoise
+ self.load_models_to_device(self.in_iteration_models)
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+
+ # Inference
+ noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
+ if cfg_scale != 1.0:
+ noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # Scheduler
+ inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
+
+ # Decode
+ self.load_models_to_device(['vae'])
+ image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ self.load_models_to_device([])
+
+ return image
+
+
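+# Each `QwenImageUnit_*` below is a `PipelineUnit` that reads a subset of the shared /
+# positive / negative input dicts assembled in `QwenImagePipeline.__call__`, loads the
+# models it needs via `onload_model_names`, and returns the extra tensors (latents,
+# prompt embeddings, control conditionings, ...) consumed by `model_fn_qwen_image`.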
+
+class QwenImageUnit_ShapeChecker(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=("height", "width"))
+
+ def process(self, pipe: QwenImagePipeline, height, width):
+ height, width = pipe.check_resize_height_width(height, width)
+ return {"height": height, "width": width}
+
+
+
+class QwenImageUnit_NoiseInitializer(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=("height", "width", "seed", "rand_device"))
+
+ def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
+ noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
+ return {"noise": noise}
+
+
+
+class QwenImageUnit_InputImageEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
+ if input_image is None:
+ return {"latents": noise, "input_latents": None}
+ pipe.load_models_to_device(['vae'])
+ image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
+ input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ if pipe.scheduler.training:
+ return {"latents": noise, "input_latents": input_latents}
+ else:
+ latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
+ return {"latents": latents, "input_latents": input_latents}
+
+
+
+class QwenImageUnit_Inpaint(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"),
+ )
+
+ def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma):
+ if inpaint_mask is None:
+ return {}
+ inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 8, height // 8)), min_value=0, max_value=1)
+ inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True)
+ if inpaint_blur_size is not None and inpaint_blur_sigma is not None:
+ from torchvision.transforms import GaussianBlur
+ blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma)
+ inpaint_mask = blur(inpaint_mask)
+ return {"inpaint_mask": inpaint_mask}
+
+
+class QwenImageUnit_PromptEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ seperate_cfg=True,
+ input_params_posi={"prompt": "prompt"},
+ input_params_nega={"prompt": "negative_prompt"},
+ input_params=("edit_image",),
+ onload_model_names=("text_encoder",)
+ )
+
+ def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+ return split_result
+
+ def calculate_dimensions(self, target_area, ratio):
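+        # e.g. target_area=384*384 and ratio=16/9 give (512, 288); both sides are snapped to multiples of 32.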
+ import math
+ width = math.sqrt(target_area * ratio)
+ height = width / ratio
+ width = round(width / 32) * 32
+ height = round(height / 32) * 32
+ return width, height
+
+ def resize_image(self, image, target_area=384*384):
+ width, height = self.calculate_dimensions(target_area, image.size[0] / image.size[1])
+ return image.resize((width, height))
+
+ def encode_prompt(self, pipe: QwenImagePipeline, prompt):
+ template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ drop_idx = 34
+ txt = [template.format(e) for e in prompt]
+ model_inputs = pipe.tokenizer(txt, max_length=4096+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
+ if model_inputs.input_ids.shape[1] >= 1024:
+            print(f"Warning: QwenImage was trained on prompts of up to 512 tokens. The current prompt requires {model_inputs.input_ids.shape[1] - drop_idx} tokens, which may lead to unpredictable behavior.")
+ hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, output_hidden_states=True,)[-1]
+ split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ return split_hidden_states
+
+ def encode_prompt_edit(self, pipe: QwenImagePipeline, prompt, edit_image):
+ template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+ drop_idx = 64
+ txt = [template.format(e) for e in prompt]
+ model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
+ hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
+ split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ return split_hidden_states
+
+ def encode_prompt_edit_multi(self, pipe: QwenImagePipeline, prompt, edit_image):
+ template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ drop_idx = 64
+ img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
+ base_img_prompt = "".join([img_prompt_template.format(i + 1) for i in range(len(edit_image))])
+ txt = [template.format(base_img_prompt + e) for e in prompt]
+ edit_image = [self.resize_image(image) for image in edit_image]
+ model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
+ hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
+ split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ return split_hidden_states
+
+ def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
+ if pipe.text_encoder is not None:
+ prompt = [prompt]
+ if edit_image is None:
+ split_hidden_states = self.encode_prompt(pipe, prompt)
+ elif isinstance(edit_image, Image.Image):
+ split_hidden_states = self.encode_prompt_edit(pipe, prompt, edit_image)
+ else:
+ split_hidden_states = self.encode_prompt_edit_multi(pipe, prompt, edit_image)
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
+ encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
+ prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
+ return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
+ else:
+ return {}
+
+
+class QwenImageUnit_EntityControl(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ take_over=True,
+ onload_model_names=("text_encoder",)
+ )
+
+ def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+ return split_result
+
+ def get_prompt_emb(self, pipe: QwenImagePipeline, prompt) -> dict:
+ if pipe.text_encoder is not None:
+ prompt = [prompt]
+ template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ drop_idx = 34
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = pipe.tokenizer(txt, max_length=1024+drop_idx, padding=True, truncation=True, return_tensors="pt").to(pipe.device)
+ hidden_states = pipe.text_encoder(input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True,)[-1]
+
+ split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
+ encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
+ prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
+ return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
+ else:
+ return {}
+
+ def preprocess_masks(self, pipe, masks, height, width, dim):
+ out_masks = []
+ for mask in masks:
+ mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
+ mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype)
+ out_masks.append(mask)
+ return out_masks
+
+ def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height):
+ entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1)
+ entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
+ prompt_embs, prompt_emb_masks = [], []
+ for entity_prompt in entity_prompts:
+ prompt_emb_dict = self.get_prompt_emb(pipe, entity_prompt)
+ prompt_embs.append(prompt_emb_dict['prompt_emb'])
+ prompt_emb_masks.append(prompt_emb_dict['prompt_emb_mask'])
+ return prompt_embs, prompt_emb_masks, entity_masks
+
+ def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, enable_eligen_on_negative, cfg_scale):
+ entity_prompt_emb_posi, entity_prompt_emb_posi_mask, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height)
+ if enable_eligen_on_negative and cfg_scale != 1.0:
+ entity_prompt_emb_nega = [prompt_emb_nega['prompt_emb']] * len(entity_prompt_emb_posi)
+ entity_prompt_emb_nega_mask = [prompt_emb_nega['prompt_emb_mask']] * len(entity_prompt_emb_posi)
+ entity_masks_nega = entity_masks_posi
+ else:
+ entity_prompt_emb_nega, entity_prompt_emb_nega_mask, entity_masks_nega = None, None, None
+ eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi, "entity_prompt_emb_mask": entity_prompt_emb_posi_mask}
+ eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega, "entity_prompt_emb_mask": entity_prompt_emb_nega_mask}
+ return eligen_kwargs_posi, eligen_kwargs_nega
+
+ def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):
+ eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None)
+ if eligen_entity_prompts is None or eligen_entity_masks is None or len(eligen_entity_prompts) == 0 or len(eligen_entity_masks) == 0:
+ return inputs_shared, inputs_posi, inputs_nega
+ pipe.load_models_to_device(self.onload_model_names)
+ eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False)
+ eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega,
+ eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"],
+ eligen_enable_on_negative, inputs_shared["cfg_scale"])
+ inputs_posi.update(eligen_kwargs_posi)
+ if inputs_shared.get("cfg_scale", 1.0) != 1.0:
+ inputs_nega.update(eligen_kwargs_nega)
+ return inputs_shared, inputs_posi, inputs_nega
+
+
+
+class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("blockwise_controlnet_inputs", "tiled", "tile_size", "tile_stride"),
+ onload_model_names=("vae",)
+ )
+
+ def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
+ mask = (pipe.preprocess_image(mask) + 1) / 2
+ mask = mask.mean(dim=1, keepdim=True)
+ mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
+ latents = torch.concat([latents, mask], dim=1)
+ return latents
+
+ def apply_controlnet_mask_on_image(self, pipe, image, mask):
+ mask = mask.resize(image.size)
+ mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
+ image = np.array(image)
+ image[mask > 0] = 0
+ image = Image.fromarray(image)
+ return image
+
+ def process(self, pipe: QwenImagePipeline, blockwise_controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
+ if blockwise_controlnet_inputs is None:
+ return {}
+ pipe.load_models_to_device(self.onload_model_names)
+ conditionings = []
+ for controlnet_input in blockwise_controlnet_inputs:
+ image = controlnet_input.image
+ if controlnet_input.inpaint_mask is not None:
+ image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
+
+ image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
+ image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+ if controlnet_input.inpaint_mask is not None:
+ image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
+ conditionings.append(image)
+
+ return {"blockwise_controlnet_conditioning": conditionings}
+
+
+class QwenImageUnit_EditImageEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("edit_image", "tiled", "tile_size", "tile_stride", "edit_image_auto_resize"),
+ onload_model_names=("vae",)
+ )
+
+
+ def calculate_dimensions(self, target_area, ratio):
+ import math
+ width = math.sqrt(target_area * ratio)
+ height = width / ratio
+ width = round(width / 32) * 32
+ height = round(height / 32) * 32
+ return width, height
+
+
+ def edit_image_auto_resize(self, edit_image):
+ calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
+ return edit_image.resize((calculated_width, calculated_height))
+
+
+ def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride, edit_image_auto_resize=False):
+ if edit_image is None:
+ return {}
+ pipe.load_models_to_device(['vae'])
+ if isinstance(edit_image, Image.Image):
+ resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image
+ edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
+ edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ else:
+ resized_edit_image, edit_latents = [], []
+ for image in edit_image:
+ if edit_image_auto_resize:
+ image = self.edit_image_auto_resize(image)
+ resized_edit_image.append(image)
+ image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
+ latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ edit_latents.append(latents)
+ return {"edit_latents": edit_latents, "edit_image": resized_edit_image}
+
+
+class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride):
+ if context_image is None:
+ return {}
+ pipe.load_models_to_device(['vae'])
+ context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype)
+ context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return {"context_latents": context_latents}
+
+
+def model_fn_qwen_image(
+ dit: QwenImageDiT = None,
+ blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None,
+ latents=None,
+ timestep=None,
+ prompt_emb=None,
+ prompt_emb_mask=None,
+ height=None,
+ width=None,
+ blockwise_controlnet_conditioning=None,
+ blockwise_controlnet_inputs=None,
+ progress_id=0,
+ num_inference_steps=1,
+ entity_prompt_emb=None,
+ entity_prompt_emb_mask=None,
+ entity_masks=None,
+ edit_latents=None,
+ context_latents=None,
+ enable_fp8_attention=False,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ edit_rope_interpolation=False,
+ **kwargs
+):
+ img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
+ txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
+ timestep = timestep / 1000
+
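+    # Pack latents into 2x2 spatial patches: (B, C, H, W) -> (B, H/2 * W/2, C*4) tokens.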
+ image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
+ image_seq_len = image.shape[1]
+
+ if context_latents is not None:
+ img_shapes += [(context_latents.shape[0], context_latents.shape[2]//2, context_latents.shape[3]//2)]
+ context_image = rearrange(context_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=context_latents.shape[2]//2, W=context_latents.shape[3]//2, P=2, Q=2)
+ image = torch.cat([image, context_image], dim=1)
+ if edit_latents is not None:
+ edit_latents_list = edit_latents if isinstance(edit_latents, list) else [edit_latents]
+ img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]
+ edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]
+ image = torch.cat([image] + edit_image, dim=1)
+
+ image = dit.img_in(image)
+ conditioning = dit.time_text_embed(timestep, image.dtype)
+
+ if entity_prompt_emb is not None:
+ text, image_rotary_emb, attention_mask = dit.process_entity_masks(
+ latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask,
+ entity_masks, height, width, image, img_shapes,
+ )
+ else:
+ text = dit.txt_in(dit.txt_norm(prompt_emb))
+ if edit_rope_interpolation:
+ image_rotary_emb = dit.pos_embed.forward_sampling(img_shapes, txt_seq_lens, device=latents.device)
+ else:
+ image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
+ attention_mask = None
+
+ if blockwise_controlnet_conditioning is not None:
+ blockwise_controlnet_conditioning = blockwise_controlnet.preprocess(
+ blockwise_controlnet_inputs, blockwise_controlnet_conditioning)
+
+ for block_id, block in enumerate(dit.transformer_blocks):
+ text, image = gradient_checkpoint_forward(
+ block,
+ use_gradient_checkpointing,
+ use_gradient_checkpointing_offload,
+ image=image,
+ text=text,
+ temb=conditioning,
+ image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
+ enable_fp8_attention=enable_fp8_attention,
+ )
+ if blockwise_controlnet_conditioning is not None:
+ image_slice = image[:, :image_seq_len].clone()
+ controlnet_output = blockwise_controlnet.blockwise_forward(
+ image=image_slice, conditionings=blockwise_controlnet_conditioning,
+ controlnet_inputs=blockwise_controlnet_inputs, block_id=block_id,
+ progress_id=progress_id, num_inference_steps=num_inference_steps,
+ )
+ image[:, :image_seq_len] = image_slice + controlnet_output
+
+ image = dit.norm_out(image, conditioning)
+ image = dit.proj_out(image)
+ image = image[:, :image_seq_len]
+
+ latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
+ return latents
diff --git a/diffsynth/pipelines/sd3_image.py b/diffsynth/pipelines/sd3_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6098739b2701d59958ef3fa85b0dc96b5ffe86a
--- /dev/null
+++ b/diffsynth/pipelines/sd3_image.py
@@ -0,0 +1,147 @@
+from ..models import ModelManager, SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEDecoder, SD3VAEEncoder
+from ..prompters import SD3Prompter
+from ..schedulers import FlowMatchScheduler
+from .base import BasePipeline
+import torch
+from tqdm import tqdm
+
+
+
+class SD3ImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
+ self.scheduler = FlowMatchScheduler()
+ self.prompter = SD3Prompter()
+ # models
+ self.text_encoder_1: SD3TextEncoder1 = None
+ self.text_encoder_2: SD3TextEncoder2 = None
+ self.text_encoder_3: SD3TextEncoder3 = None
+ self.dit: SD3DiT = None
+ self.vae_decoder: SD3VAEDecoder = None
+ self.vae_encoder: SD3VAEEncoder = None
+ self.model_names = ['text_encoder_1', 'text_encoder_2', 'text_encoder_3', 'dit', 'vae_decoder', 'vae_encoder']
+
+
+ def denoising_model(self):
+ return self.dit
+
+
+ def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
+ self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
+ self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2")
+ self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
+ self.dit = model_manager.fetch_model("sd3_dit")
+ self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sd3_vae_encoder")
+ self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2, self.text_encoder_3)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
+ pipe = SD3ImagePipeline(
+ device=model_manager.device if device is None else device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+ pipe.fetch_models(model_manager, prompt_refiner_classes)
+ return pipe
+
+
+ def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ return image
+
+
+ def encode_prompt(self, prompt, positive=True, t5_sequence_length=77):
+ prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
+ prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
+ )
+ return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
+
+
+ def prepare_extra_input(self, latents=None):
+ return {}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ local_prompts=[],
+ masks=[],
+ mask_scales=[],
+ negative_prompt="",
+ cfg_scale=7.5,
+ input_image=None,
+ denoising_strength=1.0,
+ height=1024,
+ width=1024,
+ num_inference_steps=20,
+ t5_sequence_length=77,
+ tiled=False,
+ tile_size=128,
+ tile_stride=64,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ if input_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
+ latents = self.encode_image(image, **tiler_kwargs)
+ noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+
+ # Encode prompts
+ self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3'])
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True, t5_sequence_length=t5_sequence_length)
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
+ prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
+
+ # Denoise
+ self.load_models_to_device(['dit'])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Classifier-free guidance
+ inference_callback = lambda prompt_emb_posi: self.dit(
+ latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs,
+ )
+ noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
+ noise_pred_nega = self.dit(
+ latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+
+            # Scheduler
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ self.load_models_to_device(['vae_decoder'])
+ image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+ # offload all models
+ self.load_models_to_device([])
+ return image
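+
+# A minimal, hypothetical usage sketch (the checkpoint path is a placeholder):
+#
+#     model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
+#     model_manager.load_models(["models/stable_diffusion_3/sd3_medium_incl_clips.safetensors"])
+#     pipe = SD3ImagePipeline.from_model_manager(model_manager)
+#     image = pipe(prompt="a photo of a cat", num_inference_steps=28, cfg_scale=4.5, seed=0)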
diff --git a/diffsynth/pipelines/sd_image.py b/diffsynth/pipelines/sd_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..c22c3fe69578f28925be900036bf21afeb750f17
--- /dev/null
+++ b/diffsynth/pipelines/sd_image.py
@@ -0,0 +1,191 @@
+from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
+from ..models.model_manager import ModelManager
+from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
+from ..prompters import SDPrompter
+from ..schedulers import EnhancedDDIMScheduler
+from .base import BasePipeline
+from .dancer import lets_dance
+from typing import List
+import torch
+from tqdm import tqdm
+
+
+
+class SDImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = EnhancedDDIMScheduler()
+ self.prompter = SDPrompter()
+ # models
+ self.text_encoder: SDTextEncoder = None
+ self.unet: SDUNet = None
+ self.vae_decoder: SDVAEDecoder = None
+ self.vae_encoder: SDVAEEncoder = None
+ self.controlnet: MultiControlNetManager = None
+ self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
+ self.ipadapter: SDIpAdapter = None
+ self.model_names = ['text_encoder', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
+
+
+ def denoising_model(self):
+ return self.unet
+
+
+ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
+ # Main models
+ self.text_encoder = model_manager.fetch_model("sd_text_encoder")
+ self.unet = model_manager.fetch_model("sd_unet")
+ self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
+ self.prompter.fetch_models(self.text_encoder)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+ # ControlNets
+ controlnet_units = []
+ for config in controlnet_config_units:
+ controlnet_unit = ControlNetUnit(
+ Annotator(config.processor_id, device=self.device),
+ model_manager.fetch_model("sd_controlnet", config.model_path),
+ config.scale
+ )
+ controlnet_units.append(controlnet_unit)
+ self.controlnet = MultiControlNetManager(controlnet_units)
+
+ # IP-Adapters
+ self.ipadapter = model_manager.fetch_model("sd_ipadapter")
+ self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
+ pipe = SDImagePipeline(
+ device=model_manager.device if device is None else device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+        pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=prompt_refiner_classes)
+ return pipe
+
+
+ def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ return image
+
+
+ def encode_prompt(self, prompt, clip_skip=1, positive=True):
+ prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
+ return {"encoder_hidden_states": prompt_emb}
+
+
+ def prepare_extra_input(self, latents=None):
+ return {}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ local_prompts=[],
+ masks=[],
+ mask_scales=[],
+ negative_prompt="",
+ cfg_scale=7.5,
+ clip_skip=1,
+ input_image=None,
+ ipadapter_images=None,
+ ipadapter_scale=1.0,
+ controlnet_image=None,
+ denoising_strength=1.0,
+ height=512,
+ width=512,
+ num_inference_steps=20,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ if input_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
+ latents = self.encode_image(image, **tiler_kwargs)
+ noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+
+ # Encode prompts
+ self.load_models_to_device(['text_encoder'])
+ prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
+ prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
+ prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, positive=True) for prompt_local in local_prompts]
+
+ # IP-Adapter
+ if ipadapter_images is not None:
+ self.load_models_to_device(['ipadapter_image_encoder'])
+ ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
+ self.load_models_to_device(['ipadapter'])
+ ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
+ ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
+ else:
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
+
+ # Prepare ControlNets
+ if controlnet_image is not None:
+ self.load_models_to_device(['controlnet'])
+ controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
+ controlnet_image = controlnet_image.unsqueeze(1)
+ controlnet_kwargs = {"controlnet_frames": controlnet_image}
+ else:
+ controlnet_kwargs = {"controlnet_frames": None}
+
+ # Denoise
+ self.load_models_to_device(['controlnet', 'unet'])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Classifier-free guidance
+ inference_callback = lambda prompt_emb_posi: lets_dance(
+ self.unet, motion_modules=None, controlnet=self.controlnet,
+ sample=latents, timestep=timestep,
+ **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
+ device=self.device,
+ )
+ noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
+ noise_pred_nega = lets_dance(
+ self.unet, motion_modules=None, controlnet=self.controlnet,
+ sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
+ device=self.device,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+
+ # DDIM
+ latents = self.scheduler.step(noise_pred, timestep, latents)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ self.load_models_to_device(['vae_decoder'])
+ image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+        # Offload all models
+ self.load_models_to_device([])
+ return image
diff --git a/diffsynth/pipelines/sd_video.py b/diffsynth/pipelines/sd_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..4337beb4f7a2d4a08c5955fdbd5f528ea328b39e
--- /dev/null
+++ b/diffsynth/pipelines/sd_video.py
@@ -0,0 +1,269 @@
+from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder, SDMotionModel
+from ..models.model_manager import ModelManager
+from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
+from ..prompters import SDPrompter
+from ..schedulers import EnhancedDDIMScheduler
+from .sd_image import SDImagePipeline
+from .dancer import lets_dance
+from typing import List
+import torch
+from tqdm import tqdm
+
+
+
+def lets_dance_with_long_video(
+ unet: SDUNet,
+ motion_modules: SDMotionModel = None,
+ controlnet: MultiControlNetManager = None,
+ sample = None,
+ timestep = None,
+ encoder_hidden_states = None,
+ ipadapter_kwargs_list = {},
+ controlnet_frames = None,
+ unet_batch_size = 1,
+ controlnet_batch_size = 1,
+ cross_frame_attention = False,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ device="cuda",
+ animatediff_batch_size=16,
+ animatediff_stride=8,
+):
+ num_frames = sample.shape[0]
+ hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
+
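+    # Slide a window of animatediff_batch_size frames across the video with stride animatediff_stride.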
+ for batch_id in range(0, num_frames, animatediff_stride):
+ batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
+
+ # process this batch
+ hidden_states_batch = lets_dance(
+ unet, motion_modules, controlnet,
+ sample[batch_id: batch_id_].to(device),
+ timestep,
+ encoder_hidden_states,
+ ipadapter_kwargs_list=ipadapter_kwargs_list,
+ controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
+ unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
+ cross_frame_attention=cross_frame_attention,
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, device=device
+ ).cpu()
+
+ # update hidden_states
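+        # Frames in the overlap region are blended with a triangular weight that peaks at the window centre,
+        # so consecutive windows cross-fade instead of producing visible seams.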
+ for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
+ bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
+ hidden_states, num = hidden_states_output[i]
+ hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
+ hidden_states_output[i] = (hidden_states, num + bias)
+
+ if batch_id_ == num_frames:
+ break
+
+ # output
+ hidden_states = torch.stack([h for h, _ in hidden_states_output])
+ return hidden_states
+
+
+
+class SDVideoPipeline(SDImagePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
+ self.prompter = SDPrompter()
+ # models
+ self.text_encoder: SDTextEncoder = None
+ self.unet: SDUNet = None
+ self.vae_decoder: SDVAEDecoder = None
+ self.vae_encoder: SDVAEEncoder = None
+ self.controlnet: MultiControlNetManager = None
+ self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
+ self.ipadapter: SDIpAdapter = None
+ self.motion_modules: SDMotionModel = None
+
+
+ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
+ # Main models
+ self.text_encoder = model_manager.fetch_model("sd_text_encoder")
+ self.unet = model_manager.fetch_model("sd_unet")
+ self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
+ self.prompter.fetch_models(self.text_encoder)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+ # ControlNets
+ controlnet_units = []
+ for config in controlnet_config_units:
+ controlnet_unit = ControlNetUnit(
+ Annotator(config.processor_id, device=self.device),
+ model_manager.fetch_model("sd_controlnet", config.model_path),
+ config.scale
+ )
+ controlnet_units.append(controlnet_unit)
+ self.controlnet = MultiControlNetManager(controlnet_units)
+
+ # IP-Adapters
+ self.ipadapter = model_manager.fetch_model("sd_ipadapter")
+ self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
+
+ # Motion Modules
+ self.motion_modules = model_manager.fetch_model("sd_motion_modules")
+ if self.motion_modules is None:
+ self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
+ pipe = SDVideoPipeline(
+ device=model_manager.device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+ pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
+ return pipe
+
+
+ def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
+ images = [
+ self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ for frame_id in range(latents.shape[0])
+ ]
+ return images
+
+
+ def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
+ latents = []
+ for image in processed_images:
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+ latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ latents.append(latent.cpu())
+ latents = torch.concat(latents, dim=0)
+ return latents
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ cfg_scale=7.5,
+ clip_skip=1,
+ num_frames=None,
+ input_frames=None,
+ ipadapter_images=None,
+ ipadapter_scale=1.0,
+ controlnet_frames=None,
+ denoising_strength=1.0,
+ height=512,
+ width=512,
+ num_inference_steps=20,
+ animatediff_batch_size = 16,
+ animatediff_stride = 8,
+ unet_batch_size = 1,
+ controlnet_batch_size = 1,
+ cross_frame_attention = False,
+ smoother=None,
+ smoother_progress_ids=[],
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters, batch size ...
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+ other_kwargs = {
+ "animatediff_batch_size": animatediff_batch_size, "animatediff_stride": animatediff_stride,
+ "unet_batch_size": unet_batch_size, "controlnet_batch_size": controlnet_batch_size,
+ "cross_frame_attention": cross_frame_attention,
+ }
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
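+        # Without motion modules, reuse a single noise sample for every frame to keep frames consistent.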
+ if self.motion_modules is None:
+ noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
+ else:
+ noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
+ if input_frames is None or denoising_strength == 1.0:
+ latents = noise
+ else:
+ latents = self.encode_video(input_frames, **tiler_kwargs)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+
+ # Encode prompts
+ prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
+ prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
+
+ # IP-Adapter
+ if ipadapter_images is not None:
+ ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
+ ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
+ ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
+ else:
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
+
+ # Prepare ControlNets
+ if controlnet_frames is not None:
+ if isinstance(controlnet_frames[0], list):
+ controlnet_frames_ = []
+ for processor_id in range(len(controlnet_frames)):
+ controlnet_frames_.append(
+ torch.stack([
+ self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
+ for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
+ ], dim=1)
+ )
+ controlnet_frames = torch.concat(controlnet_frames_, dim=0)
+ else:
+ controlnet_frames = torch.stack([
+ self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
+ for controlnet_frame in progress_bar_cmd(controlnet_frames)
+ ], dim=1)
+ controlnet_kwargs = {"controlnet_frames": controlnet_frames}
+ else:
+ controlnet_kwargs = {"controlnet_frames": None}
+
+ # Denoise
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Classifier-free guidance
+ noise_pred_posi = lets_dance_with_long_video(
+ self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
+ sample=latents, timestep=timestep,
+ **prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **other_kwargs, **tiler_kwargs,
+ device=self.device,
+ )
+ noise_pred_nega = lets_dance_with_long_video(
+ self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
+ sample=latents, timestep=timestep,
+ **prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **other_kwargs, **tiler_kwargs,
+ device=self.device,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+
+ # DDIM and smoother
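+            # When a smoother is scheduled for this step: render the current prediction to frames, smooth them,
+            # re-encode the result, and recover the noise prediction that maps the latents to the smoothed target.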
+ if smoother is not None and progress_id in smoother_progress_ids:
+ rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
+ rendered_frames = self.decode_video(rendered_frames)
+ rendered_frames = smoother(rendered_frames, original_frames=input_frames)
+ target_latents = self.encode_video(rendered_frames)
+ noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
+ latents = self.scheduler.step(noise_pred, timestep, latents)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+        # Decode video
+ output_frames = self.decode_video(latents, **tiler_kwargs)
+
+ # Post-process
+ if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
+ output_frames = smoother(output_frames, original_frames=input_frames)
+
+ return output_frames
diff --git a/diffsynth/pipelines/sdxl_image.py b/diffsynth/pipelines/sdxl_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..499c4bbce707fa7cfd026c66af8c8dca3e554127
--- /dev/null
+++ b/diffsynth/pipelines/sdxl_image.py
@@ -0,0 +1,226 @@
+from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
+from ..models.kolors_text_encoder import ChatGLMModel
+from ..models.model_manager import ModelManager
+from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
+from ..prompters import SDXLPrompter, KolorsPrompter
+from ..schedulers import EnhancedDDIMScheduler
+from .base import BasePipeline
+from .dancer import lets_dance_xl
+from typing import List
+import torch
+from tqdm import tqdm
+from einops import repeat
+
+
+
+class SDXLImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = EnhancedDDIMScheduler()
+ self.prompter = SDXLPrompter()
+ # models
+ self.text_encoder: SDXLTextEncoder = None
+ self.text_encoder_2: SDXLTextEncoder2 = None
+ self.text_encoder_kolors: ChatGLMModel = None
+ self.unet: SDXLUNet = None
+ self.vae_decoder: SDXLVAEDecoder = None
+ self.vae_encoder: SDXLVAEEncoder = None
+ self.controlnet: MultiControlNetManager = None
+ self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
+ self.ipadapter: SDXLIpAdapter = None
+ self.model_names = ['text_encoder', 'text_encoder_2', 'text_encoder_kolors', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
+
+
+ def denoising_model(self):
+ return self.unet
+
+
+ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
+ # Main models
+ self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
+ self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
+ self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
+ self.unet = model_manager.fetch_model("sdxl_unet")
+ self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
+
+ # ControlNets
+ controlnet_units = []
+ for config in controlnet_config_units:
+ controlnet_unit = ControlNetUnit(
+ Annotator(config.processor_id, device=self.device),
+ model_manager.fetch_model("sdxl_controlnet", config.model_path),
+ config.scale
+ )
+ controlnet_units.append(controlnet_unit)
+ self.controlnet = MultiControlNetManager(controlnet_units)
+
+ # IP-Adapters
+ self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
+ self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
+
+ # Kolors
+ if self.text_encoder_kolors is not None:
+ print("Switch to Kolors. The prompter and scheduler will be replaced.")
+ self.prompter = KolorsPrompter()
+ self.prompter.fetch_models(self.text_encoder_kolors)
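+            # Kolors is trained with its own noise schedule, so swap in a matching DDIM scheduler.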
+ self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
+ else:
+ self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
+ pipe = SDXLImagePipeline(
+ device=model_manager.device if device is None else device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+ pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
+ return pipe
+
+
+ def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ return image
+
+
+ def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=2, positive=True):
+ add_prompt_emb, prompt_emb = self.prompter.encode_prompt(
+ prompt,
+ clip_skip=clip_skip, clip_skip_2=clip_skip_2,
+ device=self.device,
+ positive=positive,
+ )
+ return {"encoder_hidden_states": prompt_emb, "add_text_embeds": add_prompt_emb}
+
+
+ def prepare_extra_input(self, latents=None):
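+        # SDXL micro-conditioning: (original_size, crop_top_left, target_size) packed into a single time-id tensor.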
+ height, width = latents.shape[2] * 8, latents.shape[3] * 8
+ add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device).repeat(latents.shape[0])
+ return {"add_time_id": add_time_id}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ local_prompts=[],
+ masks=[],
+ mask_scales=[],
+ negative_prompt="",
+ cfg_scale=7.5,
+ clip_skip=1,
+ clip_skip_2=2,
+ input_image=None,
+ ipadapter_images=None,
+ ipadapter_scale=1.0,
+ ipadapter_use_instant_style=False,
+ controlnet_image=None,
+ denoising_strength=1.0,
+ height=1024,
+ width=1024,
+ num_inference_steps=20,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ if input_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
+ latents = self.encode_image(image, **tiler_kwargs)
+ noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+
+ # Encode prompts
+ self.load_models_to_device(['text_encoder', 'text_encoder_2', 'text_encoder_kolors'])
+ prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
+ prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
+ prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
+
+ # IP-Adapter
+ if ipadapter_images is not None:
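+            # InstantStyle-style control: apply the IP-Adapter to fewer attention layers (set_less_adapter) instead of all of them.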
+ if ipadapter_use_instant_style:
+ self.ipadapter.set_less_adapter()
+ else:
+ self.ipadapter.set_full_adapter()
+ self.load_models_to_device(['ipadapter_image_encoder'])
+ ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
+ self.load_models_to_device(['ipadapter'])
+ ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
+ ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
+ else:
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
+
+ # Prepare ControlNets
+ if controlnet_image is not None:
+ self.load_models_to_device(['controlnet'])
+ controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
+ controlnet_image = controlnet_image.unsqueeze(1)
+ controlnet_kwargs = {"controlnet_frames": controlnet_image}
+ else:
+ controlnet_kwargs = {"controlnet_frames": None}
+
+ # Prepare extra input
+ extra_input = self.prepare_extra_input(latents)
+
+ # Denoise
+ self.load_models_to_device(['controlnet', 'unet'])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Classifier-free guidance
+ inference_callback = lambda prompt_emb_posi: lets_dance_xl(
+ self.unet, motion_modules=None, controlnet=self.controlnet,
+ sample=latents, timestep=timestep, **extra_input,
+ **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
+ device=self.device,
+ )
+ noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
+
+ if cfg_scale != 1.0:
+ noise_pred_nega = lets_dance_xl(
+ self.unet, motion_modules=None, controlnet=self.controlnet,
+ sample=latents, timestep=timestep, **extra_input,
+ **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
+ device=self.device,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # DDIM
+ latents = self.scheduler.step(noise_pred, timestep, latents)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ self.load_models_to_device(['vae_decoder'])
+ image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+        # Offload all models
+ self.load_models_to_device([])
+ return image
diff --git a/diffsynth/pipelines/sdxl_video.py b/diffsynth/pipelines/sdxl_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..308590ca6a874c5803da95db1d90fced26126893
--- /dev/null
+++ b/diffsynth/pipelines/sdxl_video.py
@@ -0,0 +1,226 @@
+from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder, SDXLMotionModel
+from ..models.kolors_text_encoder import ChatGLMModel
+from ..models.model_manager import ModelManager
+from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
+from ..prompters import SDXLPrompter, KolorsPrompter
+from ..schedulers import EnhancedDDIMScheduler
+from .sdxl_image import SDXLImagePipeline
+from .dancer import lets_dance_xl
+from typing import List
+import torch
+from tqdm import tqdm
+
+
+
+class SDXLVideoPipeline(SDXLImagePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
+ self.prompter = SDXLPrompter()
+ # models
+ self.text_encoder: SDXLTextEncoder = None
+ self.text_encoder_2: SDXLTextEncoder2 = None
+ self.text_encoder_kolors: ChatGLMModel = None
+ self.unet: SDXLUNet = None
+ self.vae_decoder: SDXLVAEDecoder = None
+ self.vae_encoder: SDXLVAEEncoder = None
+ # self.controlnet: MultiControlNetManager = None (TODO)
+ self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
+ self.ipadapter: SDXLIpAdapter = None
+ self.motion_modules: SDXLMotionModel = None
+
+
+ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
+ # Main models
+ self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
+ self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
+ self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
+ self.unet = model_manager.fetch_model("sdxl_unet")
+ self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
+ self.prompter.fetch_models(self.text_encoder)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+ # ControlNets (TODO)
+
+ # IP-Adapters
+ self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
+ self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
+
+ # Motion Modules
+ self.motion_modules = model_manager.fetch_model("sdxl_motion_modules")
+ if self.motion_modules is None:
+ self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
+
+ # Kolors
+ if self.text_encoder_kolors is not None:
+ print("Switch to Kolors. The prompter will be replaced.")
+ self.prompter = KolorsPrompter()
+ self.prompter.fetch_models(self.text_encoder_kolors)
+            # The schedulers of AnimateDiff and Kolors are incompatible. When motion modules are loaded, we keep the AnimateDiff scheduler.
+ if self.motion_modules is None:
+ self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
+ else:
+ self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
+ pipe = SDXLVideoPipeline(
+ device=model_manager.device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+ pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
+ return pipe
+
+
+ def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
+ images = [
+ self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ for frame_id in range(latents.shape[0])
+ ]
+ return images
+
+
+ def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
+ latents = []
+ for image in processed_images:
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+ latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ latents.append(latent.cpu())
+ latents = torch.concat(latents, dim=0)
+ return latents
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ cfg_scale=7.5,
+ clip_skip=1,
+ num_frames=None,
+ input_frames=None,
+ ipadapter_images=None,
+ ipadapter_scale=1.0,
+ ipadapter_use_instant_style=False,
+ controlnet_frames=None,
+ denoising_strength=1.0,
+ height=512,
+ width=512,
+ num_inference_steps=20,
+ animatediff_batch_size = 16,
+ animatediff_stride = 8,
+ unet_batch_size = 1,
+ controlnet_batch_size = 1,
+ cross_frame_attention = False,
+ smoother=None,
+ smoother_progress_ids=[],
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+        # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
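+        # As in the SD video pipeline, a single noise sample is shared across frames when no motion modules are loaded.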
+ if self.motion_modules is None:
+ noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
+ else:
+ noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
+ if input_frames is None or denoising_strength == 1.0:
+ latents = noise
+ else:
+ latents = self.encode_video(input_frames, **tiler_kwargs)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+        latents = latents.to(self.device) # to be removed once long-video support is added
+
+ # Encode prompts
+ prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
+ prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
+
+ # IP-Adapter
+ if ipadapter_images is not None:
+ if ipadapter_use_instant_style:
+ self.ipadapter.set_less_adapter()
+ else:
+ self.ipadapter.set_full_adapter()
+ ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
+ ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
+ ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
+ else:
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
+
+ # Prepare ControlNets
+ if controlnet_frames is not None:
+ if isinstance(controlnet_frames[0], list):
+ controlnet_frames_ = []
+ for processor_id in range(len(controlnet_frames)):
+ controlnet_frames_.append(
+ torch.stack([
+ self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
+ for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
+ ], dim=1)
+ )
+ controlnet_frames = torch.concat(controlnet_frames_, dim=0)
+ else:
+ controlnet_frames = torch.stack([
+ self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
+ for controlnet_frame in progress_bar_cmd(controlnet_frames)
+ ], dim=1)
+ controlnet_kwargs = {"controlnet_frames": controlnet_frames}
+ else:
+ controlnet_kwargs = {"controlnet_frames": None}
+
+ # Prepare extra input
+ extra_input = self.prepare_extra_input(latents)
+
+ # Denoise
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Classifier-free guidance
+ noise_pred_posi = lets_dance_xl(
+ self.unet, motion_modules=self.motion_modules, controlnet=None,
+ sample=latents, timestep=timestep,
+ **prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **extra_input, **tiler_kwargs,
+ device=self.device,
+ )
+ noise_pred_nega = lets_dance_xl(
+ self.unet, motion_modules=self.motion_modules, controlnet=None,
+ sample=latents, timestep=timestep,
+ **prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **extra_input, **tiler_kwargs,
+ device=self.device,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+
+ # DDIM and smoother
+ if smoother is not None and progress_id in smoother_progress_ids:
+ rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
+ rendered_frames = self.decode_video(rendered_frames)
+ rendered_frames = smoother(rendered_frames, original_frames=input_frames)
+ target_latents = self.encode_video(rendered_frames)
+ noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
+ latents = self.scheduler.step(noise_pred, timestep, latents)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+        # Decode video
+ output_frames = self.decode_video(latents, **tiler_kwargs)
+
+ # Post-process
+ if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
+ output_frames = smoother(output_frames, original_frames=input_frames)
+
+ return output_frames
diff --git a/diffsynth/pipelines/step_video.py b/diffsynth/pipelines/step_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..56140178e9d6cdaf5efeca77ea061f8232836f11
--- /dev/null
+++ b/diffsynth/pipelines/step_video.py
@@ -0,0 +1,209 @@
+from ..models import ModelManager
+from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder
+from ..models.stepvideo_text_encoder import STEP1TextEncoder
+from ..models.stepvideo_dit import StepVideoModel
+from ..models.stepvideo_vae import StepVideoVAE
+from ..schedulers.flow_match import FlowMatchScheduler
+from .base import BasePipeline
+from ..prompters import StepVideoPrompter
+import torch
+from einops import rearrange
+import numpy as np
+from PIL import Image
+from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
+from transformers.models.bert.modeling_bert import BertEmbeddings
+from ..models.stepvideo_dit import RMSNorm
+from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Upsample2D, BaseGroupNorm
+
+
+
+class StepVideoPipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = FlowMatchScheduler(sigma_min=0.0, extra_one_step=True, shift=13.0, reverse_sigmas=True, num_train_timesteps=1)
+ self.prompter = StepVideoPrompter()
+ self.text_encoder_1: HunyuanDiTCLIPTextEncoder = None
+ self.text_encoder_2: STEP1TextEncoder = None
+ self.dit: StepVideoModel = None
+ self.vae: StepVideoVAE = None
+ self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae']
+
+
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
+ dtype = next(iter(self.text_encoder_1.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder_1,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ BertEmbeddings: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=torch.float32,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.text_encoder_2.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder_2,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ RMSNorm: AutoWrappedModule,
+ torch.nn.Embedding: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.dit.parameters())).dtype
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ max_num_param=num_persistent_param_in_dit,
+ overflow_module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.vae.parameters())).dtype
+ enable_vram_management(
+ self.vae,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ CausalConv: AutoWrappedModule,
+ CausalConvAfterNorm: AutoWrappedModule,
+ Upsample2D: AutoWrappedModule,
+ BaseGroupNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ self.enable_cpu_offload()
+
+
+ def fetch_models(self, model_manager: ModelManager):
+ self.text_encoder_1 = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
+ self.text_encoder_2 = model_manager.fetch_model("stepvideo_text_encoder_2")
+ self.dit = model_manager.fetch_model("stepvideo_dit")
+ self.vae = model_manager.fetch_model("stepvideo_vae")
+ self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
+ if device is None: device = model_manager.device
+ if torch_dtype is None: torch_dtype = model_manager.torch_dtype
+ pipe = StepVideoPipeline(device=device, torch_dtype=torch_dtype)
+ pipe.fetch_models(model_manager)
+ return pipe
+
+
+ def encode_prompt(self, prompt, positive=True):
+ clip_embeds, llm_embeds, llm_mask = self.prompter.encode_prompt(prompt, device=self.device, positive=positive)
+ clip_embeds = clip_embeds.to(dtype=self.torch_dtype, device=self.device)
+ llm_embeds = llm_embeds.to(dtype=self.torch_dtype, device=self.device)
+ llm_mask = llm_mask.to(dtype=self.torch_dtype, device=self.device)
+ return {"encoder_hidden_states_2": clip_embeds, "encoder_hidden_states": llm_embeds, "encoder_attention_mask": llm_mask}
+
+
+ def tensor2video(self, frames):
+ frames = rearrange(frames, "C T H W -> T H W C")
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ input_video=None,
+ denoising_strength=1.0,
+ seed=None,
+ rand_device="cpu",
+ height=544,
+ width=992,
+ num_frames=204,
+ cfg_scale=9.0,
+ num_inference_steps=30,
+ tiled=True,
+ tile_size=(34, 34),
+ tile_stride=(16, 16),
+ smooth_scale=0.6,
+ progress_bar_cmd=lambda x: x,
+ progress_bar_st=None,
+ ):
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Initialize noise
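+        # The StepVideo VAE compresses 17 frames into 3 latent frames and downsamples 16x spatially, with 64 latent channels.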
+ latents = self.generate_noise((1, max(num_frames//17*3, 1), 64, height//16, width//16), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
+
+ # Encode prompts
+ self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+ if cfg_scale != 1.0:
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
+
+ # Denoise
+ self.load_models_to_device(["dit"])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+ print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
+
+ # Inference
+ noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi)
+ if cfg_scale != 1.0:
+ noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega)
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # Scheduler
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ # Decode
+ self.load_models_to_device(['vae'])
+ frames = self.vae.decode(latents, device=self.device, smooth_scale=smooth_scale, **tiler_kwargs)
+ self.load_models_to_device([])
+ frames = self.tensor2video(frames[0])
+
+ return frames
diff --git a/diffsynth/pipelines/svd_video.py b/diffsynth/pipelines/svd_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..b71597efa73783f7e3746a2bcf6b7be5c70c360e
--- /dev/null
+++ b/diffsynth/pipelines/svd_video.py
@@ -0,0 +1,300 @@
+from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder
+from ..schedulers import ContinuousODEScheduler
+from .base import BasePipeline
+import torch
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+from einops import rearrange, repeat
+
+
+
+class SVDVideoPipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = ContinuousODEScheduler()
+ # models
+ self.image_encoder: SVDImageEncoder = None
+ self.unet: SVDUNet = None
+ self.vae_encoder: SVDVAEEncoder = None
+ self.vae_decoder: SVDVAEDecoder = None
+
+
+ def fetch_models(self, model_manager: ModelManager):
+ self.image_encoder = model_manager.fetch_model("svd_image_encoder")
+ self.unet = model_manager.fetch_model("svd_unet")
+ self.vae_encoder = model_manager.fetch_model("svd_vae_encoder")
+ self.vae_decoder = model_manager.fetch_model("svd_vae_decoder")
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, **kwargs):
+ pipe = SVDVideoPipeline(
+ device=model_manager.device,
+ torch_dtype=model_manager.torch_dtype
+ )
+ pipe.fetch_models(model_manager)
+ return pipe
+
+
+ def encode_image_with_clip(self, image):
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+ image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))
+ image = (image + 1.0) / 2.0
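+        # Normalize with the standard CLIP image-encoder mean and std.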
+ mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+ std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+ image = (image - mean) / std
+ image_emb = self.image_encoder(image)
+ return image_emb
+
+
+ def encode_image_with_vae(self, image, noise_aug_strength, seed=None):
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+ noise = self.generate_noise(image.shape, seed=seed, device=self.device, dtype=self.torch_dtype)
+ image = image + noise_aug_strength * noise
+ image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor
+ return image_emb
+
+
+ def encode_video_with_vae(self, video):
+ video = torch.concat([self.preprocess_image(frame) for frame in video], dim=0)
+ video = rearrange(video, "T C H W -> 1 C T H W")
+ video = video.to(device=self.device, dtype=self.torch_dtype)
+ latents = self.vae_encoder.encode_video(video)
+ latents = rearrange(latents[0], "C T H W -> T C H W")
+ return latents
+
+
+ def tensor2video(self, frames):
+ frames = rearrange(frames, "C T H W -> T H W C")
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+
+ def calculate_noise_pred(
+ self,
+ latents,
+ timestep,
+ add_time_id,
+ cfg_scales,
+ image_emb_vae_posi, image_emb_clip_posi,
+ image_emb_vae_nega, image_emb_clip_nega
+ ):
+ # Positive side
+ noise_pred_posi = self.unet(
+ torch.cat([latents, image_emb_vae_posi], dim=1),
+ timestep, image_emb_clip_posi, add_time_id
+ )
+ # Negative side
+ noise_pred_nega = self.unet(
+ torch.cat([latents, image_emb_vae_nega], dim=1),
+ timestep, image_emb_clip_nega, add_time_id
+ )
+
+ # Classifier-free guidance
+ noise_pred = noise_pred_nega + cfg_scales * (noise_pred_posi - noise_pred_nega)
+
+ return noise_pred
+
+
+ def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
+ if post_normalize:
+ mean, std = latents.mean(), latents.std()
+ latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
+ latents = latents * contrast_enhance_scale
+ return latents
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ input_image=None,
+ input_video=None,
+ mask_frames=[],
+ mask_frame_ids=[],
+ min_cfg_scale=1.0,
+ max_cfg_scale=3.0,
+ denoising_strength=1.0,
+ num_frames=25,
+ height=576,
+ width=1024,
+ fps=7,
+ motion_bucket_id=127,
+ noise_aug_strength=0.02,
+ num_inference_steps=20,
+ post_normalize=True,
+ contrast_enhance_scale=1.2,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
+
+ # Prepare latent tensors
+ noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ if denoising_strength == 1.0:
+ latents = noise.clone()
+ else:
+ latents = self.encode_video_with_vae(input_video)
+ latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
+
+ # Prepare mask frames
+ if len(mask_frames) > 0:
+ mask_latents = self.encode_video_with_vae(mask_frames)
+
+ # Encode image
+ image_emb_clip_posi = self.encode_image_with_clip(input_image)
+ image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
+ image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength, seed=seed), "B C H W -> (B T) C H W", T=num_frames)
+ image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi)
+
+ # Prepare classifier-free guidance
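+        # Guidance strength ramps linearly from min_cfg_scale on the first frame to max_cfg_scale on the last frame.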
+ cfg_scales = torch.linspace(min_cfg_scale, max_cfg_scale, num_frames)
+ cfg_scales = cfg_scales.reshape(num_frames, 1, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+
+ # Prepare positional id
+ add_time_id = torch.tensor([[fps-1, motion_bucket_id, noise_aug_strength]], device=self.device)
+
+ # Denoise
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+
+ # Mask frames
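+            # Pin the masked frames to their reference content by re-noising their clean latents to the current timestep.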
+ for frame_id, mask_frame_id in enumerate(mask_frame_ids):
+ latents[mask_frame_id] = self.scheduler.add_noise(mask_latents[frame_id], noise[mask_frame_id], timestep)
+
+ # Fetch model output
+ noise_pred = self.calculate_noise_pred(
+ latents, timestep, add_time_id, cfg_scales,
+ image_emb_vae_posi, image_emb_clip_posi, image_emb_vae_nega, image_emb_clip_nega
+ )
+
+ # Forward Euler
+ latents = self.scheduler.step(noise_pred, timestep, latents)
+
+ # Update progress bar
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+        # Decode video
+ latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
+ video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd)
+ video = self.tensor2video(video)
+
+ return video
+
+
+
+class SVDCLIPImageProcessor:
+ def __init__(self):
+ pass
+
+ def resize_with_antialiasing(self, input, size, interpolation="bicubic", align_corners=True):
+ h, w = input.shape[-2:]
+ factors = (h / size[0], w / size[1])
+
+ # First, we have to determine sigma
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
+ sigmas = (
+ max((factors[0] - 1.0) / 2.0, 0.001),
+ max((factors[1] - 1.0) / 2.0, 0.001),
+ )
+
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
+
+ # Make sure it is odd
+ if (ks[0] % 2) == 0:
+ ks = ks[0] + 1, ks[1]
+
+ if (ks[1] % 2) == 0:
+ ks = ks[0], ks[1] + 1
+
+ input = self._gaussian_blur2d(input, ks, sigmas)
+
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
+ return output
+
+
+ def _compute_padding(self, kernel_size):
+ """Compute padding tuple."""
+        # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
+ if len(kernel_size) < 2:
+ raise AssertionError(kernel_size)
+ computed = [k - 1 for k in kernel_size]
+
+ # for even kernels we need to do asymmetric padding :(
+ out_padding = 2 * len(kernel_size) * [0]
+
+ for i in range(len(kernel_size)):
+ computed_tmp = computed[-(i + 1)]
+
+ pad_front = computed_tmp // 2
+ pad_rear = computed_tmp - pad_front
+
+ out_padding[2 * i + 0] = pad_front
+ out_padding[2 * i + 1] = pad_rear
+
+ return out_padding
+
+
+ def _filter2d(self, input, kernel):
+ # prepare kernel
+ b, c, h, w = input.shape
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
+
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
+
+ height, width = tmp_kernel.shape[-2:]
+
+ padding_shape: list[int] = self._compute_padding([height, width])
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
+
+ # kernel and input tensor reshape to align element-wise or batch-wise params
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
+
+ # convolve the tensor with the kernel.
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
+
+ out = output.view(b, c, h, w)
+ return out
+
+
+ def _gaussian(self, window_size: int, sigma):
+ if isinstance(sigma, float):
+ sigma = torch.tensor([[sigma]])
+
+ batch_size = sigma.shape[0]
+
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
+
+ if window_size % 2 == 0:
+ x = x + 0.5
+
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
+
+ return gauss / gauss.sum(-1, keepdim=True)
+
+
+ def _gaussian_blur2d(self, input, kernel_size, sigma):
+ if isinstance(sigma, tuple):
+ sigma = torch.tensor([sigma], dtype=input.dtype)
+ else:
+ sigma = sigma.to(dtype=input.dtype)
+
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
+ bs = sigma.shape[0]
+ kernel_x = self._gaussian(kx, sigma[:, 1].view(bs, 1))
+ kernel_y = self._gaussian(ky, sigma[:, 0].view(bs, 1))
+ out_x = self._filter2d(input, kernel_x[..., None, :])
+ out = self._filter2d(out_x, kernel_y[..., None])
+
+ return out
diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..e70e0cc246e2d41ffd19056fd42060144ab6230c
--- /dev/null
+++ b/diffsynth/pipelines/wan_video.py
@@ -0,0 +1,626 @@
+import types
+from ..models import ModelManager
+from ..models.wan_video_dit import WanModel
+from ..models.wan_video_text_encoder import WanTextEncoder
+from ..models.wan_video_vae import WanVideoVAE
+from ..models.wan_video_image_encoder import WanImageEncoder
+from ..models.wan_video_vace import VaceWanModel
+from ..schedulers.flow_match import FlowMatchScheduler
+from .base import BasePipeline
+from ..prompters import WanPrompter
+import torch, os
+from einops import rearrange
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from typing import Optional
+
+from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
+from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
+from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
+from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
+from ..models.wan_video_motion_controller import WanMotionControllerModel
+
+
+
+class WanVideoPipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
+ self.text_encoder: WanTextEncoder = None
+ self.image_encoder: WanImageEncoder = None
+ self.dit: WanModel = None
+ self.vae: WanVideoVAE = None
+ self.motion_controller: WanMotionControllerModel = None
+ self.vace: VaceWanModel = None
+ self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller', 'vace']
+ self.height_division_factor = 16
+ self.width_division_factor = 16
+ self.use_unified_sequence_parallel = False
+
+
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
+ dtype = next(iter(self.text_encoder.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Embedding: AutoWrappedModule,
+ T5RelativeEmbedding: AutoWrappedModule,
+ T5LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.dit.parameters())).dtype
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ torch.nn.Conv2d: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ max_num_param=num_persistent_param_in_dit,
+ overflow_module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.vae.parameters())).dtype
+ enable_vram_management(
+ self.vae,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ RMS_norm: AutoWrappedModule,
+ CausalConv3d: AutoWrappedModule,
+ Upsample: AutoWrappedModule,
+ torch.nn.SiLU: AutoWrappedModule,
+ torch.nn.Dropout: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.image_encoder is not None:
+ dtype = next(iter(self.image_encoder.parameters())).dtype
+ enable_vram_management(
+ self.image_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.motion_controller is not None:
+ dtype = next(iter(self.motion_controller.parameters())).dtype
+ enable_vram_management(
+ self.motion_controller,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.vace is not None:
+ enable_vram_management(
+ self.vace,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ self.enable_cpu_offload()
+
+
+ def fetch_models(self, model_manager: ModelManager):
+ text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
+ if text_encoder_model_and_path is not None:
+ self.text_encoder, tokenizer_path = text_encoder_model_and_path
+ self.prompter.fetch_models(self.text_encoder)
+ self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
+ self.dit = model_manager.fetch_model("wan_video_dit")
+ self.vae = model_manager.fetch_model("wan_video_vae")
+ self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
+ self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
+ self.vace = model_manager.fetch_model("wan_video_vace")
+
+
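+ # Illustrative usage sketch for `from_model_manager` (model paths are placeholders, not shipped defaults):
+ #   model_manager = ModelManager()
+ #   model_manager.load_model("models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
+ #   model_manager.load_model("models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
+ #   model_manager.load_model("models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
+ #   pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
+ #   video = pipe(prompt="a cat running on the grass", num_frames=81, seed=0)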
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
+ if device is None: device = model_manager.device
+ if torch_dtype is None: torch_dtype = model_manager.torch_dtype
+ pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
+ pipe.fetch_models(model_manager)
+ if use_usp:
+ from xfuser.core.distributed import get_sequence_parallel_world_size
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
+
+ for block in pipe.dit.blocks:
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+ pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
+ pipe.sp_size = get_sequence_parallel_world_size()
+ pipe.use_unified_sequence_parallel = True
+ return pipe
+
+
+ def denoising_model(self):
+ return self.dit
+
+
+ def encode_prompt(self, prompt, positive=True):
+ prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
+ return {"context": prompt_emb}
+
+
+ def encode_image(self, image, end_image, num_frames, height, width, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
+ image = self.preprocess_image(image.resize((width, height))).to(self.device)
+ clip_context = self.image_encoder.encode_image([image])
+ msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
+ msk[:, 1:] = 0
+ if end_image is not None:
+ end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
+ vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
+ if self.dit.has_image_pos_emb:
+ clip_context = torch.concat([clip_context, self.image_encoder.encode_image([end_image])], dim=1)
+ msk[:, -1:] = 1
+ else:
+ vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+ msk = msk.transpose(1, 2)[0]
+
+ y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+ y = y.to(dtype=self.torch_dtype, device=self.device)
+ y = torch.concat([msk, y])
+ y = y.unsqueeze(0)
+ clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
+ y = y.to(dtype=self.torch_dtype, device=self.device)
+ return {"clip_feature": clip_context, "y": y}
+
+
+ def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ control_video = self.preprocess_images(control_video)
+ control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ return latents
+
+
+ def prepare_reference_image(self, reference_image, height, width):
+ if reference_image is not None:
+ self.load_models_to_device(["vae"])
+ reference_image = reference_image.resize((width, height))
+ reference_image = self.preprocess_images([reference_image])
+ reference_image = torch.stack(reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ reference_latents = self.vae.encode(reference_image, device=self.device)
+ return {"reference_latents": reference_latents}
+ else:
+ return {}
+
+
+ def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ if control_video is not None:
+ control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ if clip_feature is None or y is None:
+ clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
+ y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
+ else:
+ y = y[:, -16:]
+ y = torch.concat([control_latents, y], dim=1)
+ return {"clip_feature": clip_feature, "y": y}
+
+
+ def tensor2video(self, frames):
+ frames = rearrange(frames, "C T H W -> T H W C")
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+
+ def prepare_extra_input(self, latents=None):
+ return {}
+
+
+ def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return frames
+
+
+ def prepare_unified_sequence_parallel(self):
+ return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
+
+
+ def prepare_motion_bucket_id(self, motion_bucket_id):
+ motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
+ return {"motion_bucket_id": motion_bucket_id}
+
+
+ def prepare_vace_kwargs(
+ self,
+ latents,
+ vace_video=None, vace_mask=None, vace_reference_image=None, vace_scale=1.0,
+ height=480, width=832, num_frames=81,
+ seed=None, rand_device="cpu",
+ tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
+ ):
+ if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
+ self.load_models_to_device(["vae"])
+ if vace_video is None:
+ vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=self.torch_dtype, device=self.device)
+ else:
+ vace_video = self.preprocess_images(vace_video)
+ vace_video = torch.stack(vace_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+
+ if vace_mask is None:
+ vace_mask = torch.ones_like(vace_video)
+ else:
+ vace_mask = self.preprocess_images(vace_mask)
+ vace_mask = torch.stack(vace_mask, dim=2).to(dtype=self.torch_dtype, device=self.device)
+
+ inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
+ reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
+ inactive = self.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ reactive = self.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ vace_video_latents = torch.concat((inactive, reactive), dim=1)
+
+ vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
+ vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
+
+ if vace_reference_image is None:
+ pass
+ else:
+ vace_reference_image = self.preprocess_images([vace_reference_image])
+ vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ vace_reference_latents = self.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
+ vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
+ vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
+
+ noise = self.generate_noise((1, 16, 1, latents.shape[3], latents.shape[4]), seed=seed, device=rand_device, dtype=torch.float32)
+ noise = noise.to(dtype=self.torch_dtype, device=self.device)
+ latents = torch.concat((noise, latents), dim=2)
+
+ vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
+ return latents, {"vace_context": vace_context, "vace_scale": vace_scale}
+ else:
+ return latents, {"vace_context": None, "vace_scale": vace_scale}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ input_image=None,
+ end_image=None,
+ input_video=None,
+ control_video=None,
+ reference_image=None,
+ vace_video=None,
+ vace_video_mask=None,
+ vace_reference_image=None,
+ vace_scale=1.0,
+ denoising_strength=1.0,
+ seed=None,
+ rand_device="cpu",
+ height=480,
+ width=832,
+ num_frames=81,
+ cfg_scale=5.0,
+ num_inference_steps=50,
+ sigma_shift=5.0,
+ motion_bucket_id=None,
+ tiled=True,
+ tile_size=(30, 52),
+ tile_stride=(15, 26),
+ tea_cache_l1_thresh=None,
+ tea_cache_model_id="",
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ # Parameter check
+ height, width = self.check_resize_height_width(height, width)
+ if num_frames % 4 != 1:
+ num_frames = (num_frames + 2) // 4 * 4 + 1
+ print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+
+ # Initialize noise
+ noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
+ noise = noise.to(dtype=self.torch_dtype, device=self.device)
+ if input_video is not None:
+ self.load_models_to_device(['vae'])
+ input_video = self.preprocess_images(input_video)
+ input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = noise
+
+ # Encode prompts
+ self.load_models_to_device(["text_encoder"])
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+ if cfg_scale != 1.0:
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
+
+ # Encode image
+ if input_image is not None and self.image_encoder is not None:
+ self.load_models_to_device(["image_encoder", "vae"])
+ image_emb = self.encode_image(input_image, end_image, num_frames, height, width, **tiler_kwargs)
+ else:
+ image_emb = {}
+
+ # Reference image
+ reference_image_kwargs = self.prepare_reference_image(reference_image, height, width)
+
+ # ControlNet
+ if control_video is not None:
+ self.load_models_to_device(["image_encoder", "vae"])
+ image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
+
+ # Motion Controller
+ if self.motion_controller is not None and motion_bucket_id is not None:
+ motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
+ else:
+ motion_kwargs = {}
+
+ # Extra input
+ extra_input = self.prepare_extra_input(latents)
+
+ # VACE
+ latents, vace_kwargs = self.prepare_vace_kwargs(
+ latents, vace_video, vace_video_mask, vace_reference_image, vace_scale,
+ height=height, width=width, num_frames=num_frames, seed=seed, rand_device=rand_device, **tiler_kwargs
+ )
+
+ # TeaCache
+ tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+ tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+
+ # Unified Sequence Parallel
+ usp_kwargs = self.prepare_unified_sequence_parallel()
+
+ # Denoise
+ self.load_models_to_device(["dit", "motion_controller", "vace"])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+
+ # Inference
+ noise_pred_posi = model_fn_wan_video(
+ self.dit, motion_controller=self.motion_controller, vace=self.vace,
+ x=latents, timestep=timestep,
+ **prompt_emb_posi, **image_emb, **extra_input,
+ **tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
+ )
+ if cfg_scale != 1.0:
+ noise_pred_nega = model_fn_wan_video(
+ self.dit, motion_controller=self.motion_controller, vace=self.vace,
+ x=latents, timestep=timestep,
+ **prompt_emb_nega, **image_emb, **extra_input,
+ **tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # Scheduler
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ if vace_reference_image is not None:
+ latents = latents[:, :, 1:]
+
+ # Decode
+ self.load_models_to_device(['vae'])
+ frames = self.decode_video(latents, **tiler_kwargs)
+ self.load_models_to_device([])
+ frames = self.tensor2video(frames[0])
+
+ return frames
+
+
+
+class TeaCache:
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
+ self.num_inference_steps = num_inference_steps
+ self.step = 0
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.rel_l1_thresh = rel_l1_thresh
+ self.previous_residual = None
+ self.previous_hidden_states = None
+
+ self.coefficients_dict = {
+ "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
+ "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
+ "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
+ "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
+ }
+ if model_id not in self.coefficients_dict:
+ supported_model_ids = ", ".join([i for i in self.coefficients_dict])
+ raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
+ self.coefficients = self.coefficients_dict[model_id]
+
+ def check(self, dit: WanModel, x, t_mod):
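+ # Decide whether this step's DiT forward pass can be skipped. The relative L1 change of the
+ # modulated timestep embedding is rescaled by a model-specific polynomial and accumulated;
+ # once the accumulated value reaches `rel_l1_thresh`, the blocks are recomputed and the
+ # accumulator is reset. Returns True when the cached residual should be reused instead.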
+ modulated_inp = t_mod.clone()
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ else:
+ coefficients = self.coefficients
+ rescale_func = np.poly1d(coefficients)
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+ should_calc = False
+ else:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = modulated_inp
+ self.step += 1
+ if self.step == self.num_inference_steps:
+ self.step = 0
+ if should_calc:
+ self.previous_hidden_states = x.clone()
+ return not should_calc
+
+ def store(self, hidden_states):
+ self.previous_residual = hidden_states - self.previous_hidden_states
+ self.previous_hidden_states = None
+
+ def update(self, hidden_states):
+ hidden_states = hidden_states + self.previous_residual
+ return hidden_states
+
+
+
+def model_fn_wan_video(
+ dit: WanModel,
+ motion_controller: WanMotionControllerModel = None,
+ vace: VaceWanModel = None,
+ x: torch.Tensor = None,
+ timestep: torch.Tensor = None,
+ context: torch.Tensor = None,
+ clip_feature: Optional[torch.Tensor] = None,
+ y: Optional[torch.Tensor] = None,
+ reference_latents = None,
+ vace_context = None,
+ vace_scale = 1.0,
+ tea_cache: TeaCache = None,
+ use_unified_sequence_parallel: bool = False,
+ motion_bucket_id: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if use_unified_sequence_parallel:
+ import torch.distributed as dist
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size,
+ get_sp_group)
+
+ t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
+ t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
+ if motion_bucket_id is not None and motion_controller is not None:
+ t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
+ context = dit.text_embedding(context)
+
+ if dit.has_image_input:
+ x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
+ clip_embedding = dit.img_emb(clip_feature)
+ context = torch.cat([clip_embedding, context], dim=1)
+
+ x, (f, h, w) = dit.patchify(x)
+
+ # Reference image
+ if reference_latents is not None:
+ reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2)
+ x = torch.concat([reference_latents, x], dim=1)
+ f += 1
+
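+ # Assemble 3D rotary position embeddings: the per-axis frequency tables are broadcast over the
+ # (frame, height, width) token grid and flattened into a single table of length f*h*w.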
+ freqs = torch.cat([
+ dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
+
+ # TeaCache
+ if tea_cache is not None:
+ tea_cache_update = tea_cache.check(dit, x, t_mod)
+ else:
+ tea_cache_update = False
+
+ if vace_context is not None:
+ vace_hints = vace(x, vace_context, context, t_mod, freqs)
+
+ # blocks
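+ # Under unified sequence parallelism the token sequence is split across ranks (the last chunk is
+ # padded to a uniform length); each rank runs the transformer blocks on its own chunk, and the
+ # full sequence is gathered back after the output head.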
+ if use_unified_sequence_parallel:
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
+ pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
+ chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
+ x = chunks[get_sequence_parallel_rank()]
+
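+ # Either reuse the residual cached by TeaCache (skipping all DiT blocks) or run the blocks,
+ # injecting the VACE hints at their mapped layers, and store the new residual for later reuse.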
+ if tea_cache_update:
+ x = tea_cache.update(x)
+ else:
+ for block_id, block in enumerate(dit.blocks):
+ x = block(x, context, t_mod, freqs)
+ if vace_context is not None and block_id in vace.vace_layers_mapping:
+ current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
+ if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
+ current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+ current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
+ x = x + current_vace_hint * vace_scale
+ if tea_cache is not None:
+ tea_cache.store(x)
+
+ x = dit.head(x, t)
+ if use_unified_sequence_parallel:
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ x = get_sp_group().all_gather(x, dim=1)
+ x = x[:, :-pad_shape] if pad_shape > 0 else x
+ # Remove reference latents
+ if reference_latents is not None:
+ x = x[:, reference_latents.shape[1]:]
+ f -= 1
+ x = dit.unpatchify(x, (f, h, w))
+ return x
diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py
new file mode 100644
index 0000000000000000000000000000000000000000..80608388ea3cdb6d65b83f51ab08af18d3bebfbe
--- /dev/null
+++ b/diffsynth/pipelines/wan_video_new.py
@@ -0,0 +1,1758 @@
+import torch, warnings, glob, os, types
+import numpy as np
+from PIL import Image
+from einops import rearrange, repeat, reduce
+from typing import Optional, Union
+from dataclasses import dataclass
+from modelscope import snapshot_download
+from tqdm import tqdm
+from typing_extensions import Literal
+
+from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
+from ..models import ModelManager, load_state_dict
+from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
+from ..models.wan_video_dit_s2v import rope_precompute
+from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
+from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
+from ..models.wan_video_image_encoder import WanImageEncoder
+from ..models.wan_video_vace import VaceWanModel
+from ..models.wan_video_motion_controller import WanMotionControllerModel
+from ..models.wan_video_animate_adapter import WanAnimateAdapter
+from ..models.wan_video_mot import MotWanModel
+from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
+from ..schedulers.flow_match import FlowMatchScheduler
+from ..prompters import WanPrompter
+from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
+from ..lora import GeneralLoRALoader
+
+
+
+class WanVideoPipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
+ super().__init__(
+ device=device, torch_dtype=torch_dtype,
+ height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
+ )
+ self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
+ self.text_encoder: WanTextEncoder = None
+ self.image_encoder: WanImageEncoder = None
+ self.dit: WanModel = None
+ self.dit2: WanModel = None
+ self.vae: WanVideoVAE = None
+ self.motion_controller: WanMotionControllerModel = None
+ self.vace: VaceWanModel = None
+ self.vace2: VaceWanModel = None
+ self.vap: MotWanModel = None
+ self.animate_adapter: WanAnimateAdapter = None
+ self.audio_encoder = None  # checked in enable_vram_management; set by from_pretrained when a speech-to-video audio encoder is loaded
+ self.audio_processor = None  # set by from_pretrained when an audio_processor_config is given
+ self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter", "vap")
+ self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter", "vap")
+ self.unit_runner = PipelineUnitRunner()
+ self.units = [
+ WanVideoUnit_ShapeChecker(),
+ WanVideoUnit_NoiseInitializer(),
+ WanVideoUnit_PromptEmbedder(),
+ WanVideoUnit_S2V(),
+ WanVideoUnit_InputVideoEmbedder(),
+ WanVideoUnit_ImageEmbedderVAE(),
+ WanVideoUnit_ImageEmbedderCLIP(),
+ WanVideoUnit_ImageEmbedderFused(),
+ WanVideoUnit_FunControl(),
+ WanVideoUnit_FunReference(),
+ WanVideoUnit_FunCameraControl(),
+ WanVideoUnit_SpeedControl(),
+ WanVideoUnit_VACE(),
+ WanVideoPostUnit_AnimateVideoSplit(),
+ WanVideoPostUnit_AnimatePoseLatents(),
+ WanVideoPostUnit_AnimateFacePixelValues(),
+ WanVideoPostUnit_AnimateInpaint(),
+ WanVideoUnit_VAP(),
+ WanVideoUnit_UnifiedSequenceParallel(),
+ WanVideoUnit_TeaCache(),
+ WanVideoUnit_CfgMerger(),
+ WanVideoUnit_LongCatVideo(),
+ ]
+ self.post_units = [
+ WanVideoPostUnit_S2V(),
+ ]
+ self.model_fn = model_fn_wan_video
+
+ def load_lora(
+ self,
+ module: torch.nn.Module,
+ lora_config: Union[ModelConfig, str] = None,
+ alpha=1,
+ hotload=False,
+ state_dict=None,
+ ):
+ if state_dict is None:
+ if isinstance(lora_config, str):
+ lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
+ else:
+ lora_config.download_if_necessary()
+ lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
+ else:
+ lora = state_dict
+ if hotload:
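+ # Hot-loading keeps the base weights untouched: matching LoRA A/B matrices are appended to each
+ # AutoWrappedLinear and applied on the fly during its forward pass, with `alpha` folded into the
+ # A matrix. The non-hotload path delegates to GeneralLoRALoader, which applies the LoRA to the
+ # module weights directly.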
+ for name, submodule in module.named_modules():
+ if isinstance(submodule, AutoWrappedLinear):
+ lora_a_name = f'{name}.lora_A.default.weight'
+ lora_b_name = f'{name}.lora_B.default.weight'
+ if lora_a_name in lora and lora_b_name in lora:
+ submodule.lora_A_weights.append(lora[lora_a_name] * alpha)
+ submodule.lora_B_weights.append(lora[lora_b_name])
+ else:
+ loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
+ loader.load(module, lora, alpha=alpha)
+
+ def training_loss(self, **inputs):
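+ # Flow-matching training step: sample a timestep inside the configured boundary, noise the clean
+ # latents to that timestep, and regress the model prediction onto the scheduler's training
+ # target, weighted by the scheduler's per-timestep loss weight.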
+ max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
+ min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
+ timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
+ timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
+
+ inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
+ training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
+
+ noise_pred = self.model_fn(**inputs, timestep=timestep)
+
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
+ loss = loss * self.scheduler.training_weight(timestep)
+ return loss
+
+
+ def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
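+ # Two mutually exclusive budgets: `num_persistent_param_in_dit` caps the number of DiT parameters
+ # kept resident on the GPU, while `vram_limit` (auto-detected from available VRAM minus
+ # `vram_buffer` when not given) is passed through to the per-model wrappers. Supplying the former
+ # disables the latter.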
+ self.vram_management_enabled = True
+ if num_persistent_param_in_dit is not None:
+ vram_limit = None
+ else:
+ if vram_limit is None:
+ vram_limit = self.get_vram()
+ vram_limit = vram_limit - vram_buffer
+ if self.text_encoder is not None:
+ dtype = next(iter(self.text_encoder.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Embedding: AutoWrappedModule,
+ T5RelativeEmbedding: AutoWrappedModule,
+ T5LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ if self.dit is not None:
+ from ..models.longcat_video_dit import LayerNorm_FP32, RMSNorm_FP32
+ dtype = next(iter(self.dit.parameters())).dtype
+ device = "cpu" if vram_limit is not None else self.device
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.LayerNorm: WanAutoCastLayerNorm,
+ RMSNorm: AutoWrappedModule,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.Conv1d: AutoWrappedModule,
+ torch.nn.Embedding: AutoWrappedModule,
+ LayerNorm_FP32: AutoWrappedModule,
+ RMSNorm_FP32: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ max_num_param=num_persistent_param_in_dit,
+ overflow_module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ if self.dit2 is not None:
+ dtype = next(iter(self.dit2.parameters())).dtype
+ device = "cpu" if vram_limit is not None else self.device
+ enable_vram_management(
+ self.dit2,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.LayerNorm: WanAutoCastLayerNorm,
+ RMSNorm: AutoWrappedModule,
+ torch.nn.Conv2d: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ max_num_param=num_persistent_param_in_dit,
+ overflow_module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ if self.vae is not None:
+ dtype = next(iter(self.vae.parameters())).dtype
+ enable_vram_management(
+ self.vae,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ RMS_norm: AutoWrappedModule,
+ CausalConv3d: AutoWrappedModule,
+ Upsample: AutoWrappedModule,
+ torch.nn.SiLU: AutoWrappedModule,
+ torch.nn.Dropout: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.image_encoder is not None:
+ dtype = next(iter(self.image_encoder.parameters())).dtype
+ enable_vram_management(
+ self.image_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.motion_controller is not None:
+ dtype = next(iter(self.motion_controller.parameters())).dtype
+ enable_vram_management(
+ self.motion_controller,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.vace is not None:
+ device = "cpu" if vram_limit is not None else self.device
+ enable_vram_management(
+ self.vace,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ if self.audio_encoder is not None:
+ # TODO: need check
+ dtype = next(iter(self.audio_encoder.parameters())).dtype
+ enable_vram_management(
+ self.audio_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ torch.nn.Conv1d: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+
+
+ def initialize_usp(self):
+ import torch.distributed as dist
+ from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
+ dist.init_process_group(backend="nccl", init_method="env://")
+ init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
+ initialize_model_parallel(
+ sequence_parallel_degree=dist.get_world_size(),
+ ring_degree=1,
+ ulysses_degree=dist.get_world_size(),
+ )
+ torch.cuda.set_device(dist.get_rank())
+
+
+ def enable_usp(self):
+ from xfuser.core.distributed import get_sequence_parallel_world_size
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
+
+ for block in self.dit.blocks:
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+ self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
+ if self.dit2 is not None:
+ for block in self.dit2.blocks:
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+ self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
+ self.sp_size = get_sequence_parallel_world_size()
+ self.use_unified_sequence_parallel = True
+
+
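+ # Illustrative usage sketch for `from_pretrained` (the model ids / file patterns below are examples, not a fixed recipe):
+ #   pipe = WanVideoPipeline.from_pretrained(
+ #       torch_dtype=torch.bfloat16, device="cuda",
+ #       model_configs=[
+ #           ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
+ #           ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+ #           ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
+ #       ],
+ #   )
+ #   video = pipe(prompt="a cat running on the grass", seed=0)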
+ @staticmethod
+ def from_pretrained(
+ torch_dtype: torch.dtype = torch.bfloat16,
+ device: Union[str, torch.device] = "cuda",
+ model_configs: list[ModelConfig] = [],
+ tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
+ audio_processor_config: ModelConfig = None,
+ redirect_common_files: bool = True,
+ use_usp=False,
+ ):
+ # Redirect model path
+ if redirect_common_files:
+ redirect_dict = {
+ "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
+ "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
+ "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
+ }
+ for model_config in model_configs:
+ if model_config.origin_file_pattern is None or model_config.model_id is None:
+ continue
+ if model_config.origin_file_pattern in redirect_dict and model_config.model_id != redirect_dict[model_config.origin_file_pattern]:
+ print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
+ model_config.model_id = redirect_dict[model_config.origin_file_pattern]
+
+ # Initialize pipeline
+ pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
+ if use_usp: pipe.initialize_usp()
+
+ # Download and load models
+ model_manager = ModelManager()
+ for model_config in model_configs:
+ model_config.download_if_necessary(use_usp=use_usp)
+ model_manager.load_model(
+ model_config.path,
+ device=model_config.offload_device or device,
+ torch_dtype=model_config.offload_dtype or torch_dtype
+ )
+
+ # Load models
+ pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
+ dit = model_manager.fetch_model("wan_video_dit", index=2)
+ if isinstance(dit, list):
+ pipe.dit, pipe.dit2 = dit
+ else:
+ pipe.dit = dit
+ pipe.vae = model_manager.fetch_model("wan_video_vae")
+ pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
+ pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
+ vace = model_manager.fetch_model("wan_video_vace", index=2)
+ pipe.vap = model_manager.fetch_model("wan_video_vap")
+ if isinstance(vace, list):
+ pipe.vace, pipe.vace2 = vace
+ else:
+ pipe.vace = vace
+ pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
+ pipe.animate_adapter = model_manager.fetch_model("wan_video_animate_adapter")
+
+ # Size division factor
+ if pipe.vae is not None:
+ pipe.height_division_factor = pipe.vae.upsampling_factor * 2
+ pipe.width_division_factor = pipe.vae.upsampling_factor * 2
+
+ # Initialize tokenizer
+ tokenizer_config.download_if_necessary(use_usp=use_usp)
+ pipe.prompter.fetch_models(pipe.text_encoder)
+ pipe.prompter.fetch_tokenizer(tokenizer_config.path)
+
+ if audio_processor_config is not None:
+ audio_processor_config.download_if_necessary(use_usp=use_usp)
+ from transformers import Wav2Vec2Processor
+ pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path)
+ # Unified Sequence Parallel
+ if use_usp: pipe.enable_usp()
+ return pipe
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ # Prompt
+ prompt: str,
+ negative_prompt: Optional[str] = "",
+ # Image-to-video
+ input_image: Optional[Image.Image] = None,
+ # First-last-frame-to-video
+ end_image: Optional[Image.Image] = None,
+ # Video-to-video
+ input_video: Optional[list[Image.Image]] = None,
+ denoising_strength: Optional[float] = 1.0,
+ # Speech-to-video
+ input_audio: Optional[np.array] = None,
+ audio_embeds: Optional[torch.Tensor] = None,
+ audio_sample_rate: Optional[int] = 16000,
+ s2v_pose_video: Optional[list[Image.Image]] = None,
+ s2v_pose_latents: Optional[torch.Tensor] = None,
+ motion_video: Optional[list[Image.Image]] = None,
+ # ControlNet
+ control_video: Optional[list[Image.Image]] = None,
+ reference_image: Optional[Image.Image] = None,
+ # Camera control
+ camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None,
+ camera_control_speed: Optional[float] = 1/54,
+ camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
+ # VACE
+ vace_video: Optional[list[Image.Image]] = None,
+ vace_video_mask: Optional[Image.Image] = None,
+ vace_reference_image: Optional[Image.Image] = None,
+ vace_scale: Optional[float] = 1.0,
+ # Animate
+ animate_pose_video: Optional[list[Image.Image]] = None,
+ animate_face_video: Optional[list[Image.Image]] = None,
+ animate_inpaint_video: Optional[list[Image.Image]] = None,
+ animate_mask_video: Optional[list[Image.Image]] = None,
+ # VAP
+ vap_video: Optional[list[Image.Image]] = None,
+ vap_prompt: Optional[str] = " ",
+ negative_vap_prompt: Optional[str] = " ",
+ # Randomness
+ seed: Optional[int] = None,
+ rand_device: Optional[str] = "cpu",
+ # Shape
+ height: Optional[int] = 480,
+ width: Optional[int] = 832,
+ num_frames=81,
+ # Classifier-free guidance
+ cfg_scale: Optional[float] = 5.0,
+ cfg_merge: Optional[bool] = False,
+ # Boundary
+ switch_DiT_boundary: Optional[float] = 0.875,
+ # Scheduler
+ num_inference_steps: Optional[int] = 50,
+ sigma_shift: Optional[float] = 5.0,
+ # Speed control
+ motion_bucket_id: Optional[int] = None,
+ # LongCat-Video
+ longcat_video: Optional[list[Image.Image]] = None,
+ # VAE tiling
+ tiled: Optional[bool] = True,
+ tile_size: Optional[tuple[int, int]] = (30, 52),
+ tile_stride: Optional[tuple[int, int]] = (15, 26),
+ # Sliding window
+ sliding_window_size: Optional[int] = None,
+ sliding_window_stride: Optional[int] = None,
+ # Teacache
+ tea_cache_l1_thresh: Optional[float] = None,
+ tea_cache_model_id: Optional[str] = "",
+ # progress_bar
+ progress_bar_cmd=tqdm,
+ ):
+ # Scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+
+ # Inputs
+ inputs_posi = {
+ "prompt": prompt,
+ "vap_prompt": vap_prompt,
+ "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
+ }
+ inputs_nega = {
+ "negative_prompt": negative_prompt,
+ "negative_vap_prompt": negative_vap_prompt,
+ "tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
+ }
+ inputs_shared = {
+ "input_image": input_image,
+ "end_image": end_image,
+ "input_video": input_video, "denoising_strength": denoising_strength,
+ "control_video": control_video, "reference_image": reference_image,
+ "camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin,
+ "vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale,
+ "seed": seed, "rand_device": rand_device,
+ "height": height, "width": width, "num_frames": num_frames,
+ "cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
+ "sigma_shift": sigma_shift,
+ "motion_bucket_id": motion_bucket_id,
+ "longcat_video": longcat_video,
+ "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
+ "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
+ "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
+ "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
+ "vap_video": vap_video,
+ }
+ for unit in self.units:
+ inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+
+ # Denoise
+ self.load_models_to_device(self.in_iteration_models)
+ models = {name: getattr(self, name) for name in self.in_iteration_models}
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ # Switch DiT if necessary
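+ # For two-DiT checkpoints: once the timestep falls below `switch_DiT_boundary` (as a fraction of
+ # the training timesteps), swap in `dit2`/`vace2` for the remaining low-noise steps.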
+ if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and models["dit"] is not self.dit2:
+ self.load_models_to_device(self.in_iteration_models_2)
+ models["dit"] = self.dit2
+ models["vace"] = self.vace2
+
+ # Timestep
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+
+ # Inference
+ noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
+ if cfg_scale != 1.0:
+ if cfg_merge:
+ noise_pred_posi, noise_pred_nega = noise_pred_posi.chunk(2, dim=0)
+ else:
+ noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # Scheduler
+ inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
+ if "first_frame_latents" in inputs_shared:
+ inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]
+
+ # VACE (TODO: remove it)
+ if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None):
+ if vace_reference_image is not None and isinstance(vace_reference_image, list):
+ f = len(vace_reference_image)
+ else:
+ f = 1
+ inputs_shared["latents"] = inputs_shared["latents"][:, :, f:]
+ # post-denoising, pre-decoding processing logic
+ for unit in self.post_units:
+ inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
+ # Decode
+ self.load_models_to_device(['vae'])
+ video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ video = self.vae_output_to_video(video)
+ self.load_models_to_device([])
+
+ return video
+
+
+
+class WanVideoUnit_ShapeChecker(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=("height", "width", "num_frames"))
+
+ def process(self, pipe: WanVideoPipeline, height, width, num_frames):
+ height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
+ return {"height": height, "width": width, "num_frames": num_frames}
+
+
+
+class WanVideoUnit_NoiseInitializer(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"))
+
+ def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image):
+ length = (num_frames - 1) // 4 + 1
+ if vace_reference_image is not None:
+ f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1
+ length += f
+ shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
+ noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
+ if vace_reference_image is not None:
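+ # The last `f` noise frames are moved to the front so they line up with the reference latents
+ # that the VACE unit prepends to the latent sequence.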
+ noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2)
+ return {"noise": noise}
+
+
+
+class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
+ if input_video is None:
+ return {"latents": noise}
+ pipe.load_models_to_device(["vae"])
+ input_video = pipe.preprocess_video(input_video)
+ input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+ if vace_reference_image is not None:
+ if not isinstance(vace_reference_image, list):
+ vace_reference_image = [vace_reference_image]
+ vace_reference_image = pipe.preprocess_video(vace_reference_image)
+ vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
+ input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
+ if pipe.scheduler.training:
+ return {"latents": noise, "input_latents": input_latents}
+ else:
+ latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
+ return {"latents": latents}
+
+
+
+class WanVideoUnit_PromptEmbedder(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ seperate_cfg=True,
+ input_params_posi={"prompt": "prompt", "positive": "positive"},
+ input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
+ onload_model_names=("text_encoder",)
+ )
+
+ def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict:
+ pipe.load_models_to_device(self.onload_model_names)
+ prompt_emb = pipe.prompter.encode_prompt(prompt, positive=positive, device=pipe.device)
+ return {"context": prompt_emb}
+
+
+
+class WanVideoUnit_ImageEmbedder(PipelineUnit):
+ """
+ Deprecated
+ """
+ def __init__(self):
+ super().__init__(
+ input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
+ onload_model_names=("image_encoder", "vae")
+ )
+
+ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
+ if input_image is None or pipe.image_encoder is None:
+ return {}
+ pipe.load_models_to_device(self.onload_model_names)
+ image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
+ clip_context = pipe.image_encoder.encode_image([image])
+ msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
+ msk[:, 1:] = 0
+ if end_image is not None:
+ end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
+ vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
+ if pipe.dit.has_image_pos_emb:
+ clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
+ msk[:, -1:] = 1
+ else:
+ vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+ msk = msk.transpose(1, 2)[0]
+
+ y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+ y = torch.concat([msk, y])
+ y = y.unsqueeze(0)
+ clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+ return {"clip_feature": clip_context, "y": y}
+
+
+
+class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("input_image", "end_image", "height", "width"),
+ onload_model_names=("image_encoder",)
+ )
+
+ def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):
+ if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding:
+ return {}
+ pipe.load_models_to_device(self.onload_model_names)
+ image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
+ clip_context = pipe.image_encoder.encode_image([image])
+ if end_image is not None:
+ end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
+ if pipe.dit.has_image_pos_emb:
+ clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
+ clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
+ return {"clip_feature": clip_context}
+
+
+
+class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
+ if input_image is None or not pipe.dit.require_vae_embedding:
+ return {}
+ pipe.load_models_to_device(self.onload_model_names)
+ image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
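+ # Build the conditioning mask: 1 marks frames whose pixels are provided (the first frame, plus
+ # the last frame when `end_image` is given), 0 marks frames to be generated. The first frame is
+ # repeated 4x and the mask reshaped to match the VAE's 4x temporal compression, then concatenated
+ # with the encoded frames to form the conditioning input `y`.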
+ msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
+ msk[:, 1:] = 0
+ if end_image is not None:
+ end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
+ vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
+ msk[:, -1:] = 1
+ else:
+ vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+ msk = msk.transpose(1, 2)[0]
+
+ y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+ y = torch.concat([msk, y])
+ y = y.unsqueeze(0)
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+ return {"y": y}
+
+
+
+class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
+ """
+ Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
+ """
+ def __init__(self):
+ super().__init__(
+ input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride):
+ if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
+ return {}
+ pipe.load_models_to_device(self.onload_model_names)
+ image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1)
+ z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ latents[:, :, 0: 1] = z
+ return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z}
+
+
+
+class WanVideoUnit_FunControl(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents):
+ if control_video is None:
+ return {}
+ pipe.load_models_to_device(self.onload_model_names)
+ control_video = pipe.preprocess_video(control_video)
+ control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+ control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
+ y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1]
+ if clip_feature is None or y is None:
+ clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
+ y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
+ else:
+ y = y[:, -y_dim:]
+ y = torch.concat([control_latents, y], dim=1)
+ return {"clip_feature": clip_feature, "y": y}
+
+
+
+class WanVideoUnit_FunReference(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("reference_image", "height", "width", "reference_image"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: WanVideoPipeline, reference_image, height, width):
+ if reference_image is None:
+ return {}
+ pipe.load_models_to_device(["vae"])
+ reference_image = reference_image.resize((width, height))
+ reference_latents = pipe.preprocess_video([reference_image])
+ reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
+ if pipe.image_encoder is None:
+ return {"reference_latents": reference_latents}
+ clip_feature = pipe.preprocess_image(reference_image)
+ clip_feature = pipe.image_encoder.encode_image([clip_feature])
+ return {"reference_latents": reference_latents, "clip_feature": clip_feature}
+
+
+
+class WanVideoUnit_FunCameraControl(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride):
+ if camera_control_direction is None:
+ return {}
+ pipe.load_models_to_device(self.onload_model_names)
+ camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates(
+ camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)
+
+ control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
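+ # Pack the Plucker embedding to the latent frame rate: the first frame is repeated 4x, then every
+ # group of 4 frames is folded into the channel dimension, mirroring the VAE's 4x temporal
+ # compression.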
+ control_camera_latents = torch.concat(
+ [
+ torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
+ control_camera_video[:, :, 1:]
+ ], dim=2
+ ).transpose(1, 2)
+ b, f, c, h, w = control_camera_latents.shape
+ control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
+ control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
+ control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
+
+ input_image = input_image.resize((width, height))
+ input_latents = pipe.preprocess_video([input_image])
+ input_latents = pipe.vae.encode(input_latents, device=pipe.device)
+ y = torch.zeros_like(latents).to(pipe.device)
+ y[:, :, :1] = input_latents
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+
+ if y.shape[1] != pipe.dit.in_dim - latents.shape[1]:
+ image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
+ vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+ y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+ msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
+ msk[:, 1:] = 0
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+ msk = msk.transpose(1, 2)[0]
+ y = torch.cat([msk,y])
+ y = y.unsqueeze(0)
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+ return {"control_camera_latents_input": control_camera_latents_input, "y": y}
+
+
+
+class WanVideoUnit_SpeedControl(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=("motion_bucket_id",))
+
+ def process(self, pipe: WanVideoPipeline, motion_bucket_id):
+ if motion_bucket_id is None:
+ return {}
+ motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)
+ return {"motion_bucket_id": motion_bucket_id}
+
+
+
+class WanVideoUnit_VACE(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"),
+ onload_model_names=("vae",)
+ )
+
+ def process(
+ self,
+ pipe: WanVideoPipeline,
+ vace_video, vace_video_mask, vace_reference_image, vace_scale,
+ height, width, num_frames,
+ tiled, tile_size, tile_stride
+ ):
+ if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None:
+ pipe.load_models_to_device(["vae"])
+ if vace_video is None:
+ vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
+ else:
+ vace_video = pipe.preprocess_video(vace_video)
+
+ if vace_video_mask is None:
+ vace_video_mask = torch.ones_like(vace_video)
+ else:
+ vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1)
+
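+ # Split the control video by the mask (`inactive` keeps pixels where the mask is 0, `reactive`
+ # where it is 1), encode both and concatenate them along the channel dimension. The mask itself
+ # is rearranged into 8x8 spatial patches and interpolated down to the latent frame count.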
+ inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask
+ reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)
+ inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+ reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+ vace_video_latents = torch.concat((inactive, reactive), dim=1)
+
+ vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
+ vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
+
+ if vace_reference_image is None:
+ pass
+ else:
+ if not isinstance(vace_reference_image,list):
+ vace_reference_image = [vace_reference_image]
+
+ vace_reference_image = pipe.preprocess_video(vace_reference_image)
+
+ bs, c, f, h, w = vace_reference_image.shape
+ new_vace_ref_images = []
+ for j in range(f):
+ new_vace_ref_images.append(vace_reference_image[0, :, j:j+1])
+ vace_reference_image = new_vace_ref_images
+
+ vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+ vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
+ vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents]
+
+ vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2)
+ vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2)
+
+ vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
+ return {"vace_context": vace_context, "vace_scale": vace_scale}
+ else:
+ return {"vace_context": None, "vace_scale": vace_scale}
+
+class WanVideoUnit_VAP(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ take_over=True,
+ onload_model_names=("text_encoder", "vae", "image_encoder")
+ )
+
+ def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
+ if inputs_shared.get("vap_video") is None:
+ return inputs_shared, inputs_posi, inputs_nega
+ else:
+ # 1. encode vap prompt
+ pipe.load_models_to_device(["text_encoder"])
+ vap_prompt, negative_vap_prompt = inputs_posi.get("vap_prompt", ""), inputs_nega.get("negative_vap_prompt", "")
+ vap_prompt_emb = pipe.prompter.encode_prompt(vap_prompt, positive=inputs_posi.get('positive',None), device=pipe.device)
+ negative_vap_prompt_emb = pipe.prompter.encode_prompt(negative_vap_prompt, positive=inputs_nega.get('positive',None), device=pipe.device)
+ inputs_posi.update({"context_vap":vap_prompt_emb})
+ inputs_nega.update({"context_vap":negative_vap_prompt_emb})
+ # 2. prepare vap image clip embedding
+ pipe.load_models_to_device(["vae", "image_encoder"])
+ vap_video, end_image = inputs_shared.get("vap_video"), inputs_shared.get("end_image")
+
+ num_frames, height, width, mot_num = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("mot_num", 1)
+
+ image_vap = pipe.preprocess_image(vap_video[0].resize((width, height))).to(pipe.device)
+
+ vap_clip_context = pipe.image_encoder.encode_image([image_vap])
+ if end_image is not None:
+ vap_end_image = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)
+ if pipe.dit.has_image_pos_emb:
+ vap_clip_context = torch.concat([vap_clip_context, pipe.image_encoder.encode_image([vap_end_image])], dim=1)
+ vap_clip_context = vap_clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
+ inputs_shared.update({"vap_clip_feature":vap_clip_context})
+
+ # 3. prepare vap latents
+ msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
+ msk[:, 1:] = 0
+ if end_image is not None:
+ msk[:, -1:] = 1
+ last_image_vap = pipe.preprocess_image(vap_video[-1].resize((width, height))).to(pipe.device)
+ vae_input = torch.concat([image_vap.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image_vap.device), last_image_vap.transpose(0,1)],dim=1)
+ else:
+ vae_input = torch.concat([image_vap.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_vap.device)], dim=1)
+
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+ msk = msk.transpose(1, 2)[0]
+
+ tiled, tile_size, tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
+
+ y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+ y = torch.concat([msk, y])
+ y = y.unsqueeze(0)
+ y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+
+ vap_video = pipe.preprocess_video(vap_video)
+ vap_latent = pipe.vae.encode(vap_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+
+ vap_latent = torch.concat([vap_latent,y], dim=1).to(dtype=pipe.torch_dtype, device=pipe.device)
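+ # The VAP hidden state stacks the VAP video latents with the I2V-style conditioning (mask + first/last-frame latents) along the channel dimension.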
+ inputs_shared.update({"vap_hidden_state":vap_latent})
+ pipe.load_models_to_device([])
+
+ return inputs_shared, inputs_posi, inputs_nega
+
+
+
+class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=())
+
+ def process(self, pipe: WanVideoPipeline):
+ if hasattr(pipe, "use_unified_sequence_parallel"):
+ if pipe.use_unified_sequence_parallel:
+ return {"use_unified_sequence_parallel": True}
+ return {}
+
+
+
+class WanVideoUnit_TeaCache(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ seperate_cfg=True,
+ input_params_posi={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"},
+ input_params_nega={"num_inference_steps": "num_inference_steps", "tea_cache_l1_thresh": "tea_cache_l1_thresh", "tea_cache_model_id": "tea_cache_model_id"},
+ )
+
+ def process(self, pipe: WanVideoPipeline, num_inference_steps, tea_cache_l1_thresh, tea_cache_model_id):
+ if tea_cache_l1_thresh is None:
+ return {}
+ return {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)}
+
+
+
+class WanVideoUnit_CfgMerger(PipelineUnit):
+ def __init__(self):
+ super().__init__(take_over=True)
+ self.concat_tensor_names = ["context", "clip_feature", "y", "reference_latents"]
+
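+ # When cfg_merge is enabled, the positive and negative conditioning tensors are concatenated along the
+ # batch dimension so classifier-free guidance can be computed in a single forward pass.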
+ def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
+ if not inputs_shared["cfg_merge"]:
+ return inputs_shared, inputs_posi, inputs_nega
+ for name in self.concat_tensor_names:
+ tensor_posi = inputs_posi.get(name)
+ tensor_nega = inputs_nega.get(name)
+ tensor_shared = inputs_shared.get(name)
+ if tensor_posi is not None and tensor_nega is not None:
+ inputs_shared[name] = torch.concat((tensor_posi, tensor_nega), dim=0)
+ elif tensor_shared is not None:
+ inputs_shared[name] = torch.concat((tensor_shared, tensor_shared), dim=0)
+ inputs_posi.clear()
+ inputs_nega.clear()
+ return inputs_shared, inputs_posi, inputs_nega
+
+
+class WanVideoUnit_S2V(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ take_over=True,
+ onload_model_names=("audio_encoder", "vae",)
+ )
+
+ def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False):
+ if audio_embeds is not None:
+ return {"audio_embeds": audio_embeds}
+ pipe.load_models_to_device(["audio_encoder"])
+ audio_embeds = pipe.audio_encoder.get_audio_feats_per_inference(input_audio, audio_sample_rate, pipe.audio_processor, fps=fps, batch_frames=num_frames-1, dtype=pipe.torch_dtype, device=pipe.device)
+ if return_all:
+ return audio_embeds
+ else:
+ return {"audio_embeds": audio_embeds[0]}
+
+ def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride, motion_video=None):
+ pipe.load_models_to_device(["vae"])
+ motion_frames = 73
+ kwargs = {}
+ if motion_video is not None and len(motion_video) > 0:
+ assert len(motion_video) == motion_frames, f"motion video must have {motion_frames} frames, but got {len(motion_video)}"
+ motion_latents = pipe.preprocess_video(motion_video)
+ kwargs["drop_motion_frames"] = False
+ else:
+ motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
+ kwargs["drop_motion_frames"] = True
+ motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+ kwargs.update({"motion_latents": motion_latents})
+ return kwargs
+
+ def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=None, num_repeats=1, return_all=False):
+ if s2v_pose_latents is not None:
+ return {"s2v_pose_latents": s2v_pose_latents}
+ if s2v_pose_video is None:
+ return {"s2v_pose_latents": None}
+ pipe.load_models_to_device(["vae"])
+ infer_frames = num_frames - 1
+ input_video = pipe.preprocess_video(s2v_pose_video)[:, :, :infer_frames * num_repeats]
+ # pad if not enough frames
+ padding_frames = infer_frames * num_repeats - input_video.shape[2]
+ input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
+ input_videos = input_video.chunk(num_repeats, dim=2)
+ pose_conds = []
+ for r in range(num_repeats):
+ cond = input_videos[r]
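+ # Prepend a copy of the first pose frame so the clip has num_frames frames before VAE encoding;
+ # the extra leading latent is dropped afterwards (cond_latents[:, :, 1:]).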
+ cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], dim=2)
+ cond_latents = pipe.vae.encode(cond, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+ pose_conds.append(cond_latents[:,:,1:])
+ if return_all:
+ return pose_conds
+ else:
+ return {"s2v_pose_latents": pose_conds[0]}
+
+ def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
+ if (inputs_shared.get("input_audio") is None and inputs_shared.get("audio_embeds") is None) or pipe.audio_encoder is None or pipe.audio_processor is None:
+ return inputs_shared, inputs_posi, inputs_nega
+ num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
+ input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio", None), inputs_shared.pop("audio_embeds", None), inputs_shared.get("audio_sample_rate", 16000)
+ s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video", None), inputs_shared.pop("s2v_pose_latents", None), inputs_shared.pop("motion_video", None)
+
+ audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds)
+ inputs_posi.update(audio_input_positive)
+ inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]})
+
+ inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride, motion_video))
+ inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride, s2v_pose_latents=s2v_pose_latents))
+ return inputs_shared, inputs_posi, inputs_nega
+
+ @staticmethod
+ def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sample_rate=16000, s2v_pose_video=None, num_frames=81, height=448, width=832, fps=16, tiled=True, tile_size=(30, 52), tile_stride=(15, 26)):
+ assert pipe.audio_encoder is not None and pipe.audio_processor is not None, "Please load audio encoder and audio processor first."
+ shapes = WanVideoUnit_ShapeChecker().process(pipe, height, width, num_frames)
+ height, width, num_frames = shapes["height"], shapes["width"], shapes["num_frames"]
+ unit = WanVideoUnit_S2V()
+ audio_embeds = unit.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, return_all=True)
+ pose_latents = unit.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, num_repeats=len(audio_embeds), return_all=True, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ pose_latents = None if s2v_pose_video is None else pose_latents
+ return audio_embeds, pose_latents, len(audio_embeds)
+
+
+class WanVideoPostUnit_S2V(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames"))
+
+ def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames):
+ if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames:
+ return {}
+ latents = torch.cat([motion_latents, latents[:,:,1:]], dim=2)
+ return {"latents": latents}
+
+
+class WanVideoPostUnit_AnimateVideoSplit(PipelineUnit):
+ def __init__(self):
+ super().__init__(input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"))
+
+ def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video):
+ if input_video is None:
+ return {}
+ if animate_pose_video is not None:
+ animate_pose_video = animate_pose_video[:len(input_video) - 4]
+ if animate_face_video is not None:
+ animate_face_video = animate_face_video[:len(input_video) - 4]
+ if animate_inpaint_video is not None:
+ animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4]
+ if animate_mask_video is not None:
+ animate_mask_video = animate_mask_video[:len(input_video) - 4]
+ return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video}
+
+
+class WanVideoPostUnit_AnimatePoseLatents(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride):
+ if animate_pose_video is None:
+ return {}
+ pipe.load_models_to_device(self.onload_model_names)
+ animate_pose_video = pipe.preprocess_video(animate_pose_video)
+ pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+ return {"pose_latents": pose_latents}
+
+
+class WanVideoPostUnit_AnimateFacePixelValues(PipelineUnit):
+ def __init__(self):
+ super().__init__(take_over=True)
+
+ def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
+ if inputs_shared.get("animate_face_video", None) is None:
+ return inputs_shared, inputs_posi, inputs_nega
+ inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"])
+ inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1
+ return inputs_shared, inputs_posi, inputs_nega
+
+
+class WanVideoPostUnit_AnimateInpaint(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"),
+ onload_model_names=("vae",)
+ )
+
+ def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
+ if mask_pixel_values is None:
+ msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
+ else:
+ msk = mask_pixel_values.clone()
+ msk[:, :mask_len] = 1
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+ msk = msk.transpose(1, 2)[0]
+ return msk
+
+ def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride):
+ if animate_inpaint_video is None or animate_mask_video is None:
+ return {}
+ pipe.load_models_to_device(self.onload_model_names)
+
+ bg_pixel_values = pipe.preprocess_video(animate_inpaint_video)
+ y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device)
+ _, lat_t, lat_h, lat_w = y_reft.shape
+
+ ref_pixel_values = pipe.preprocess_video([input_image])
+ ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+ mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device)
+ y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device)
+
+ mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0)
+ mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w")
+ mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest')
+ mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0]
+ msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device)
+
+ y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device)
+ y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0)
+ return {"y": y}
+
+
+class WanVideoUnit_LongCatVideo(PipelineUnit):
+ def __init__(self):
+ super().__init__(
+ input_params=("longcat_video",),
+ onload_model_names=("vae",)
+ )
+
+ def process(self, pipe: WanVideoPipeline, longcat_video):
+ if longcat_video is None:
+ return {}
+ pipe.load_models_to_device(self.onload_model_names)
+ longcat_video = pipe.preprocess_video(longcat_video)
+ longcat_latents = pipe.vae.encode(longcat_video, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
+ return {"longcat_latents": longcat_latents}
+
+
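+ # TeaCache: skips transformer forward passes when the change in the modulated timestep embedding
+ # (an accumulated relative L1 distance, rescaled by model-specific polynomial coefficients) stays
+ # below rel_l1_thresh; skipped steps reuse the residual cached from the last full forward pass.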
+class TeaCache:
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
+ self.num_inference_steps = num_inference_steps
+ self.step = 0
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.rel_l1_thresh = rel_l1_thresh
+ self.previous_residual = None
+ self.previous_hidden_states = None
+
+ self.coefficients_dict = {
+ "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
+ "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
+ "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
+ "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
+ }
+ if model_id not in self.coefficients_dict:
+ supported_model_ids = ", ".join(self.coefficients_dict)
+ raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
+ self.coefficients = self.coefficients_dict[model_id]
+
+ def check(self, dit: WanModel, x, t_mod):
+ modulated_inp = t_mod.clone()
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ else:
+ coefficients = self.coefficients
+ rescale_func = np.poly1d(coefficients)
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+ should_calc = False
+ else:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = modulated_inp
+ self.step += 1
+ if self.step == self.num_inference_steps:
+ self.step = 0
+ if should_calc:
+ self.previous_hidden_states = x.clone()
+ return not should_calc
+
+ def store(self, hidden_states):
+ self.previous_residual = hidden_states - self.previous_hidden_states
+ self.previous_hidden_states = None
+
+ def update(self, hidden_states):
+ hidden_states = hidden_states + self.previous_residual
+ return hidden_states
+
+
+
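+ # TemporalTiler_BCTHW: runs the model on overlapping temporal windows and blends the outputs with
+ # linear ramp weights near window borders to avoid visible seams between windows.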
+class TemporalTiler_BCTHW:
+ def __init__(self):
+ pass
+
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
+ x = torch.ones((length,))
+ if border_width == 0:
+ return x
+
+ shift = 0.5
+ if not left_bound:
+ x[:border_width] = (torch.arange(border_width) + shift) / border_width
+ if not right_bound:
+ x[-border_width:] = torch.flip((torch.arange(border_width) + shift) / border_width, dims=(0,))
+ return x
+
+ def build_mask(self, data, is_bound, border_width):
+ _, _, T, _, _ = data.shape
+ t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
+ mask = repeat(t, "T -> 1 1 T 1 1")
+ return mask
+
+ def run(self, model_fn, sliding_window_size, sliding_window_stride, computation_device, computation_dtype, model_kwargs, tensor_names, batch_size=None):
+ tensor_names = [tensor_name for tensor_name in tensor_names if model_kwargs.get(tensor_name) is not None]
+ tensor_dict = {tensor_name: model_kwargs[tensor_name] for tensor_name in tensor_names}
+ B, C, T, H, W = tensor_dict[tensor_names[0]].shape
+ if batch_size is not None:
+ B *= batch_size
+ data_device, data_dtype = tensor_dict[tensor_names[0]].device, tensor_dict[tensor_names[0]].dtype
+ value = torch.zeros((B, C, T, H, W), device=data_device, dtype=data_dtype)
+ weight = torch.zeros((1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
+ for t in range(0, T, sliding_window_stride):
+ if t - sliding_window_stride >= 0 and t - sliding_window_stride + sliding_window_size >= T:
+ continue
+ t_ = min(t + sliding_window_size, T)
+ model_kwargs.update({
+ tensor_name: tensor_dict[tensor_name][:, :, t: t_, :, :].to(device=computation_device, dtype=computation_dtype) \
+ for tensor_name in tensor_names
+ })
+ model_output = model_fn(**model_kwargs).to(device=data_device, dtype=data_dtype)
+ mask = self.build_mask(
+ model_output,
+ is_bound=(t == 0, t_ == T),
+ border_width=(sliding_window_size - sliding_window_stride,)
+ ).to(device=data_device, dtype=data_dtype)
+ value[:, :, t: t_, :, :] += model_output * mask
+ weight[:, :, t: t_, :, :] += mask
+ value /= weight
+ model_kwargs.update(tensor_dict)
+ return value
+
+
+
+def model_fn_wan_video(
+ dit: WanModel,
+ motion_controller: WanMotionControllerModel = None,
+ vace: VaceWanModel = None,
+ vap: MotWanModel = None,
+ animate_adapter: WanAnimateAdapter = None,
+ latents: torch.Tensor = None,
+ timestep: torch.Tensor = None,
+ context: torch.Tensor = None,
+ clip_feature: Optional[torch.Tensor] = None,
+ y: Optional[torch.Tensor] = None,
+ reference_latents = None,
+ vace_context = None,
+ vace_scale = 1.0,
+ audio_embeds: Optional[torch.Tensor] = None,
+ motion_latents: Optional[torch.Tensor] = None,
+ s2v_pose_latents: Optional[torch.Tensor] = None,
+ vap_hidden_state = None,
+ vap_clip_feature = None,
+ context_vap = None,
+ drop_motion_frames: bool = True,
+ tea_cache: TeaCache = None,
+ use_unified_sequence_parallel: bool = False,
+ motion_bucket_id: Optional[torch.Tensor] = None,
+ pose_latents=None,
+ face_pixel_values=None,
+ longcat_latents=None,
+ sliding_window_size: Optional[int] = None,
+ sliding_window_stride: Optional[int] = None,
+ cfg_merge: bool = False,
+ use_gradient_checkpointing: bool = False,
+ use_gradient_checkpointing_offload: bool = False,
+ control_camera_latents_input = None,
+ fuse_vae_embedding_in_latents: bool = False,
+ **kwargs,
+):
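+ # The branches below dispatch to the temporal-tiling wrapper, the LongCat-Video and Wan2.2-S2V model functions,
+ # or fall through to the standard Wan DiT forward.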
+ if sliding_window_size is not None and sliding_window_stride is not None:
+ model_kwargs = dict(
+ dit=dit,
+ motion_controller=motion_controller,
+ vace=vace,
+ latents=latents,
+ timestep=timestep,
+ context=context,
+ clip_feature=clip_feature,
+ y=y,
+ reference_latents=reference_latents,
+ vace_context=vace_context,
+ vace_scale=vace_scale,
+ tea_cache=tea_cache,
+ use_unified_sequence_parallel=use_unified_sequence_parallel,
+ motion_bucket_id=motion_bucket_id,
+ )
+ return TemporalTiler_BCTHW().run(
+ model_fn_wan_video,
+ sliding_window_size, sliding_window_stride,
+ latents.device, latents.dtype,
+ model_kwargs=model_kwargs,
+ tensor_names=["latents", "y"],
+ batch_size=2 if cfg_merge else 1
+ )
+ # LongCat-Video
+ if isinstance(dit, LongCatVideoTransformer3DModel):
+ return model_fn_longcat_video(
+ dit=dit,
+ latents=latents,
+ timestep=timestep,
+ context=context,
+ longcat_latents=longcat_latents,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ )
+
+ # wan2.2 s2v
+ if audio_embeds is not None:
+ return model_fn_wans2v(
+ dit=dit,
+ latents=latents,
+ timestep=timestep,
+ context=context,
+ audio_embeds=audio_embeds,
+ motion_latents=motion_latents,
+ s2v_pose_latents=s2v_pose_latents,
+ drop_motion_frames=drop_motion_frames,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_unified_sequence_parallel=use_unified_sequence_parallel,
+ )
+
+ if use_unified_sequence_parallel:
+ import torch.distributed as dist
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size,
+ get_sp_group)
+
+ # Timestep
+ if dit.seperated_timestep and fuse_vae_embedding_in_latents:
+ timestep = torch.concat([
+ torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device),
+ torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
+ ]).flatten()
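+ # With fuse_vae_embedding_in_latents, the conditioning latent frame gets timestep 0 (treated as clean)
+ # while the remaining latent frames use the sampled timestep, one value per spatial token.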
+ t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
+ if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
+ t_chunks = torch.chunk(t, get_sequence_parallel_world_size(), dim=1)
+ t_chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, t_chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in t_chunks]
+ t = t_chunks[get_sequence_parallel_rank()]
+ t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
+ else:
+ t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
+ t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
+
+ # Motion Controller
+ if motion_bucket_id is not None and motion_controller is not None:
+ t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
+ context = dit.text_embedding(context)
+
+ x = latents
+ # Merged cfg
+ if x.shape[0] != context.shape[0]:
+ x = torch.concat([x] * context.shape[0], dim=0)
+ if timestep.shape[0] != context.shape[0]:
+ timestep = torch.concat([timestep] * context.shape[0], dim=0)
+
+ # Image Embedding
+ if y is not None and dit.require_vae_embedding:
+ x = torch.cat([x, y], dim=1)
+ if clip_feature is not None and dit.require_clip_embedding:
+ clip_embedding = dit.img_emb(clip_feature)
+ context = torch.cat([clip_embedding, context], dim=1)
+
+ # Camera control
+ x = dit.patchify(x, control_camera_latents_input)
+
+ # Animate
+ if pose_latents is not None and face_pixel_values is not None:
+ x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values)
+
+ # Patchify
+ f, h, w = x.shape[2:]
+ x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
+
+ # Reference image
+ if reference_latents is not None:
+ if len(reference_latents.shape) == 5:
+ reference_latents = reference_latents[:, :, 0]
+ reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
+ x = torch.concat([reference_latents, x], dim=1)
+ f += 1
+
+ freqs = torch.cat([
+ dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
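+ # 3D RoPE: per-axis frequency tables for frame, height and width are broadcast and concatenated into one table with an entry per token.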
+
+ # VAP
+ if vap is not None:
+ # hidden state
+ x_vap = vap_hidden_state
+ x_vap = vap.patchify(x_vap)
+ x_vap = rearrange(x_vap, 'b c f h w -> b (f h w) c').contiguous()
+ # Timestep
+ clean_timestep = torch.ones(timestep.shape, device=timestep.device).to(timestep.dtype)
+ t = vap.time_embedding(sinusoidal_embedding_1d(vap.freq_dim, clean_timestep))
+ t_mod_vap = vap.time_projection(t).unflatten(1, (6, vap.dim))
+
+ # rope
+ freqs_vap = vap.compute_freqs_mot(f,h,w).to(x.device)
+
+ # context
+ vap_clip_embedding = vap.img_emb(vap_clip_feature)
+ context_vap = vap.text_embedding(context_vap)
+ context_vap = torch.cat([vap_clip_embedding, context_vap], dim=1)
+
+ # TeaCache
+ if tea_cache is not None:
+ tea_cache_update = tea_cache.check(dit, x, t_mod)
+ else:
+ tea_cache_update = False
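+ # tea_cache_update is True when the cached residual can be reused, so the transformer blocks below are skipped.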
+
+ if vace_context is not None:
+ vace_hints = vace(
+ x, vace_context, context, t_mod, freqs,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
+ )
+
+ # blocks
+ if use_unified_sequence_parallel:
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
+ pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
+ chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
+ x = chunks[get_sequence_parallel_rank()]
+ if tea_cache_update:
+ x = tea_cache.update(x)
+ else:
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+
+ def create_custom_forward_vap(block, vap):
+ def custom_forward(*inputs):
+ return vap(block, *inputs)
+ return custom_forward
+
+ for block_id, block in enumerate(dit.blocks):
+ # Block
+ if vap is not None and block_id in vap.mot_layers_mapping:
+ if use_gradient_checkpointing_offload:
+ with torch.autograd.graph.save_on_cpu():
+ x, x_vap = torch.utils.checkpoint.checkpoint(
+ create_custom_forward_vap(block, vap),
+ x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
+ use_reentrant=False,
+ )
+ elif use_gradient_checkpointing:
+ x, x_vap = torch.utils.checkpoint.checkpoint(
+ create_custom_forward_vap(block, vap),
+ x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id,
+ use_reentrant=False,
+ )
+ else:
+ x, x_vap = vap(block, x, context, t_mod, freqs, x_vap, context_vap, t_mod_vap, freqs_vap, block_id)
+ else:
+ if use_gradient_checkpointing_offload:
+ with torch.autograd.graph.save_on_cpu():
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x, context, t_mod, freqs,
+ use_reentrant=False,
+ )
+ elif use_gradient_checkpointing:
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x, context, t_mod, freqs,
+ use_reentrant=False,
+ )
+ else:
+ x = block(x, context, t_mod, freqs)
+
+ # VACE
+ if vace_context is not None and block_id in vace.vace_layers_mapping:
+ current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
+ if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
+ current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+ current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
+ x = x + current_vace_hint * vace_scale
+
+ # Animate
+ if pose_latents is not None and face_pixel_values is not None:
+ x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
+ if tea_cache is not None:
+ tea_cache.store(x)
+
+ x = dit.head(x, t)
+ if use_unified_sequence_parallel:
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ x = get_sp_group().all_gather(x, dim=1)
+ x = x[:, :-pad_shape] if pad_shape > 0 else x
+ # Remove reference latents
+ if reference_latents is not None:
+ x = x[:, reference_latents.shape[1]:]
+ f -= 1
+ x = dit.unpatchify(x, (f, h, w))
+ return x
+
+
+def model_fn_longcat_video(
+ dit: LongCatVideoTransformer3DModel,
+ latents: torch.Tensor = None,
+ timestep: torch.Tensor = None,
+ context: torch.Tensor = None,
+ longcat_latents: torch.Tensor = None,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+):
+ if longcat_latents is not None:
+ latents[:, :, :longcat_latents.shape[2]] = longcat_latents
+ num_cond_latents = longcat_latents.shape[2]
+ else:
+ num_cond_latents = 0
+ context = context.unsqueeze(0)
+ encoder_attention_mask = torch.any(context != 0, dim=-1)[:, 0].to(torch.int64)
+ output = dit(
+ latents,
+ timestep,
+ context,
+ encoder_attention_mask,
+ num_cond_latents=num_cond_latents,
+ use_gradient_checkpointing=use_gradient_checkpointing,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ )
+ output = -output  # flip the prediction sign to match the convention used by the rest of the pipeline
+ output = output.to(latents.dtype)
+ return output
+
+
+def model_fn_wans2v(
+ dit,
+ latents,
+ timestep,
+ context,
+ audio_embeds,
+ motion_latents,
+ s2v_pose_latents,
+ drop_motion_frames=True,
+ use_gradient_checkpointing_offload=False,
+ use_gradient_checkpointing=False,
+ use_unified_sequence_parallel=False,
+):
+ if use_unified_sequence_parallel:
+ import torch.distributed as dist
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size,
+ get_sp_group)
+ origin_ref_latents = latents[:, :, 0:1]  # the first latent frame carries the reference image
+ x = latents[:, :, 1:]  # the remaining latent frames are the ones being denoised
+
+ # context embedding
+ context = dit.text_embedding(context)
+
+ # audio encode
+ audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_embeds)
+
+ # x and s2v_pose_latents
+ s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents
+ x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents))
+ seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel
+
+ # reference image
+ ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
+ grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw))
+ x = torch.cat([x, ref_latents], dim=1)
+ # mask
+ mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
+ # freqs
+ pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None)
+ # motion
+ x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=2)
+
+ x = x + dit.trainable_cond_mask(mask).to(x.dtype)
+
+ # tmod
+ timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
+ t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
+ t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2)
+
+ if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
+ world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank()
+ assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}"
+ x = torch.chunk(x, world_size, dim=1)[sp_rank]
+ seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy())
+ seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)]
+ seq_len_x = seq_len_x_list[sp_rank]
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+
+ for block_id, block in enumerate(dit.blocks):
+ if use_gradient_checkpointing_offload:
+ with torch.autograd.graph.save_on_cpu():
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x, context, t_mod, seq_len_x, pre_compute_freqs[0],
+ use_reentrant=False,
+ )
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
+ x,
+ use_reentrant=False,
+ )
+ elif use_gradient_checkpointing:
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x, context, t_mod, seq_len_x, pre_compute_freqs[0],
+ use_reentrant=False,
+ )
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
+ x,
+ use_reentrant=False,
+ )
+ else:
+ x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
+ x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel)
+
+ if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
+ x = get_sp_group().all_gather(x, dim=1)
+
+ x = x[:, :seq_len_x_global]
+ x = dit.head(x, t[:-1])
+ x = dit.unpatchify(x, (f, h, w))
+ # make compatible with wan video
+ x = torch.cat([origin_ref_latents, x], dim=2)
+ return x
diff --git a/diffsynth/processors/FastBlend.py b/diffsynth/processors/FastBlend.py
new file mode 100644
index 0000000000000000000000000000000000000000..fed33f4fdd215c8c9dc46f3b07d9453a12cc6b98
--- /dev/null
+++ b/diffsynth/processors/FastBlend.py
@@ -0,0 +1,142 @@
+from PIL import Image
+import cupy as cp
+import numpy as np
+from tqdm import tqdm
+from ..extensions.FastBlend.patch_match import PyramidPatchMatcher
+from ..extensions.FastBlend.runners.fast import TableManager
+from .base import VideoProcessor
+
+
+class FastBlendSmoother(VideoProcessor):
+ def __init__(
+ self,
+ inference_mode="fast", batch_size=8, window_size=60,
+ minimum_patch_size=5, threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, initialize="identity", tracking_window_size=0
+ ):
+ self.inference_mode = inference_mode
+ self.batch_size = batch_size
+ self.window_size = window_size
+ self.ebsynth_config = {
+ "minimum_patch_size": minimum_patch_size,
+ "threads_per_block": threads_per_block,
+ "num_iter": num_iter,
+ "gpu_id": gpu_id,
+ "guide_weight": guide_weight,
+ "initialize": initialize,
+ "tracking_window_size": tracking_window_size
+ }
+
+ @staticmethod
+ def from_model_manager(model_manager, **kwargs):
+ # TODO: fetch GPU ID from model_manager
+ return FastBlendSmoother(**kwargs)
+
+ def inference_fast(self, frames_guide, frames_style):
+ table_manager = TableManager()
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ **self.ebsynth_config
+ )
+ # left part
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, self.batch_size, desc="Fast Mode Step 1/4")
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 2/4")
+ # right part
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, self.batch_size, desc="Fast Mode Step 3/4")
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 4/4")[::-1]
+ # merge
+ frames = []
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
+ weight_m = -1  # the style frame itself is counted in both the left and right window sums, so offset it here
+ weight = weight_l + weight_m + weight_r
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
+ frames.append(frame)
+ frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+ def inference_balanced(self, frames_guide, frames_style):
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ **self.ebsynth_config
+ )
+ output_frames = []
+ # tasks
+ n = len(frames_style)
+ tasks = []
+ for target in range(n):
+ for source in range(target - self.window_size, target + self.window_size + 1):
+ if source >= 0 and source < n and source != target:
+ tasks.append((source, target))
+ # run
+ frames = [(None, 1) for i in range(n)]
+ for batch_id in tqdm(range(0, len(tasks), self.batch_size), desc="Balanced Mode"):
+ tasks_batch = tasks[batch_id: min(batch_id+self.batch_size, len(tasks))]
+ source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
+ target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
+ source_style = np.stack([frames_style[source] for source, target in tasks_batch])
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ for (source, target), result in zip(tasks_batch, target_style):
+ frame, weight = frames[target]
+ if frame is None:
+ frame = frames_style[target]
+ frames[target] = (
+ frame * (weight / (weight + 1)) + result / (weight + 1),
+ weight + 1
+ )
+ if weight + 1 == min(n, target + self.window_size + 1) - max(0, target - self.window_size):
+ frame = frame.clip(0, 255).astype("uint8")
+ output_frames.append(Image.fromarray(frame))
+ frames[target] = (None, 1)
+ return output_frames
+
+ def inference_accurate(self, frames_guide, frames_style):
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ use_mean_target_style=True,
+ **self.ebsynth_config
+ )
+ output_frames = []
+ # run
+ n = len(frames_style)
+ for target in tqdm(range(n), desc="Accurate Mode"):
+ l, r = max(target - self.window_size, 0), min(target + self.window_size + 1, n)
+ remapped_frames = []
+ for i in range(l, r, self.batch_size):
+ j = min(i + self.batch_size, r)
+ source_guide = np.stack([frames_guide[source] for source in range(i, j)])
+ target_guide = np.stack([frames_guide[target]] * (j - i))
+ source_style = np.stack([frames_style[source] for source in range(i, j)])
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ remapped_frames.append(target_style)
+ frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
+ frame = frame.clip(0, 255).astype("uint8")
+ output_frames.append(Image.fromarray(frame))
+ return output_frames
+
+ def release_vram(self):
+ mempool = cp.get_default_memory_pool()
+ pinned_mempool = cp.get_default_pinned_memory_pool()
+ mempool.free_all_blocks()
+ pinned_mempool.free_all_blocks()
+
+ def __call__(self, rendered_frames, original_frames=None, **kwargs):
+ rendered_frames = [np.array(frame) for frame in rendered_frames]
+ original_frames = [np.array(frame) for frame in original_frames]
+ if self.inference_mode == "fast":
+ output_frames = self.inference_fast(original_frames, rendered_frames)
+ elif self.inference_mode == "balanced":
+ output_frames = self.inference_balanced(original_frames, rendered_frames)
+ elif self.inference_mode == "accurate":
+ output_frames = self.inference_accurate(original_frames, rendered_frames)
+ else:
+ raise ValueError("inference_mode must be fast, balanced or accurate")
+ self.release_vram()
+ return output_frames
diff --git a/diffsynth/processors/PILEditor.py b/diffsynth/processors/PILEditor.py
new file mode 100644
index 0000000000000000000000000000000000000000..01011d8724f61283550d503c5c20ae6fd0375ec7
--- /dev/null
+++ b/diffsynth/processors/PILEditor.py
@@ -0,0 +1,28 @@
+from PIL import ImageEnhance
+from .base import VideoProcessor
+
+
+class ContrastEditor(VideoProcessor):
+ def __init__(self, rate=1.5):
+ self.rate = rate
+
+ @staticmethod
+ def from_model_manager(model_manager, **kwargs):
+ return ContrastEditor(**kwargs)
+
+ def __call__(self, rendered_frames, **kwargs):
+ rendered_frames = [ImageEnhance.Contrast(i).enhance(self.rate) for i in rendered_frames]
+ return rendered_frames
+
+
+class SharpnessEditor(VideoProcessor):
+ def __init__(self, rate=1.5):
+ self.rate = rate
+
+ @staticmethod
+ def from_model_manager(model_manager, **kwargs):
+ return SharpnessEditor(**kwargs)
+
+ def __call__(self, rendered_frames, **kwargs):
+ rendered_frames = [ImageEnhance.Sharpness(i).enhance(self.rate) for i in rendered_frames]
+ return rendered_frames
diff --git a/diffsynth/processors/RIFE.py b/diffsynth/processors/RIFE.py
new file mode 100644
index 0000000000000000000000000000000000000000..4186eb31496e9a1bf38df06eb64921226f07ee09
--- /dev/null
+++ b/diffsynth/processors/RIFE.py
@@ -0,0 +1,77 @@
+import torch
+import numpy as np
+from PIL import Image
+from .base import VideoProcessor
+
+
+class RIFESmoother(VideoProcessor):
+ def __init__(self, model, device="cuda", scale=1.0, batch_size=4, interpolate=True):
+ self.model = model
+ self.device = device
+
+ # IFNet does not support float16; run it in float32.
+ self.torch_dtype = torch.float32
+
+ # Other parameters
+ self.scale = scale
+ self.batch_size = batch_size
+ self.interpolate = interpolate
+
+ @staticmethod
+ def from_model_manager(model_manager, **kwargs):
+ return RIFESmoother(model_manager.RIFE, device=model_manager.device, **kwargs)
+
+ def process_image(self, image):
+ width, height = image.size
+ if width % 32 != 0 or height % 32 != 0:
+ width = (width + 31) // 32 * 32
+ height = (height + 31) // 32 * 32
+ image = image.resize((width, height))
+ image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
+ return image
+
+ def process_images(self, images):
+ images = [self.process_image(image) for image in images]
+ images = torch.stack(images)
+ return images
+
+ def decode_images(self, images):
+ images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
+ images = [Image.fromarray(image) for image in images]
+ return images
+
+ def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
+ output_tensor = []
+ for batch_id in range(0, input_tensor.shape[0], batch_size):
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
+ output_tensor.append(merged[2].cpu())
+ output_tensor = torch.concat(output_tensor, dim=0)
+ return output_tensor
+
+ @torch.no_grad()
+ def __call__(self, rendered_frames, **kwargs):
+ # Preprocess
+ processed_images = self.process_images(rendered_frames)
+
+ # Input
+ input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
+
+ # Interpolate
+ output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
+
+ if self.interpolate:
+ # Blend
+ input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
+ output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
+ processed_images[1:-1] = output_tensor
+ else:
+ processed_images[1:-1] = (processed_images[1:-1] + output_tensor) / 2
+
+ # To images
+ output_images = self.decode_images(processed_images)
+ if output_images[0].size != rendered_frames[0].size:
+ output_images = [image.resize(rendered_frames[0].size) for image in output_images]
+ return output_images
diff --git a/diffsynth/processors/__init__.py b/diffsynth/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/diffsynth/processors/base.py b/diffsynth/processors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..278a9c1b74044987cc116de35292a96de8b13737
--- /dev/null
+++ b/diffsynth/processors/base.py
@@ -0,0 +1,6 @@
+class VideoProcessor:
+ def __init__(self):
+ pass
+
+ def __call__(self):
+ raise NotImplementedError
diff --git a/diffsynth/processors/sequencial_processor.py b/diffsynth/processors/sequencial_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5bc9454f0b9d74f10bb4a6bff92db77f26325c
--- /dev/null
+++ b/diffsynth/processors/sequencial_processor.py
@@ -0,0 +1,41 @@
+from .base import VideoProcessor
+
+
+class AutoVideoProcessor(VideoProcessor):
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def from_model_manager(model_manager, processor_type, **kwargs):
+ if processor_type == "FastBlend":
+ from .FastBlend import FastBlendSmoother
+ return FastBlendSmoother.from_model_manager(model_manager, **kwargs)
+ elif processor_type == "Contrast":
+ from .PILEditor import ContrastEditor
+ return ContrastEditor.from_model_manager(model_manager, **kwargs)
+ elif processor_type == "Sharpness":
+ from .PILEditor import SharpnessEditor
+ return SharpnessEditor.from_model_manager(model_manager, **kwargs)
+ elif processor_type == "RIFE":
+ from .RIFE import RIFESmoother
+ return RIFESmoother.from_model_manager(model_manager, **kwargs)
+ else:
+ raise ValueError(f"invalid processor_type: {processor_type}")
+
+
+class SequencialProcessor(VideoProcessor):
+ def __init__(self, processors=None):
+ self.processors = processors if processors is not None else []
+
+ @staticmethod
+ def from_model_manager(model_manager, configs):
+ processors = [
+ AutoVideoProcessor.from_model_manager(model_manager, config["processor_type"], **config["config"])
+ for config in configs
+ ]
+ return SequencialProcessor(processors)
+
+ def __call__(self, rendered_frames, **kwargs):
+ for processor in self.processors:
+ rendered_frames = processor(rendered_frames, **kwargs)
+ return rendered_frames
diff --git a/diffsynth/prompters/__init__.py b/diffsynth/prompters/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f27c6f153b076de484c2b650e8bf16d7142d1099
--- /dev/null
+++ b/diffsynth/prompters/__init__.py
@@ -0,0 +1,12 @@
+from .prompt_refiners import Translator, BeautifulPrompt, QwenPrompt
+from .sd_prompter import SDPrompter
+from .sdxl_prompter import SDXLPrompter
+from .sd3_prompter import SD3Prompter
+from .hunyuan_dit_prompter import HunyuanDiTPrompter
+from .kolors_prompter import KolorsPrompter
+from .flux_prompter import FluxPrompter
+from .omost import OmostPromter
+from .cog_prompter import CogPrompter
+from .hunyuan_video_prompter import HunyuanVideoPrompter
+from .stepvideo_prompter import StepVideoPrompter
+from .wan_prompter import WanPrompter
diff --git a/diffsynth/prompters/base_prompter.py b/diffsynth/prompters/base_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..136abd18fabdb04e618f59801420c9ce5fb94634
--- /dev/null
+++ b/diffsynth/prompters/base_prompter.py
@@ -0,0 +1,70 @@
+from ..models.model_manager import ModelManager
+import torch
+
+
+
+def tokenize_long_prompt(tokenizer, prompt, max_length=None):
+ # Get model_max_length from the tokenizer.
+ length = tokenizer.model_max_length if max_length is None else max_length
+
+ # To avoid the truncation warning, temporarily set tokenizer.model_max_length to a very large value.
+ tokenizer.model_max_length = 99999999
+
+ # Tokenize it!
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+ # Determine the real length.
+ max_length = (input_ids.shape[1] + length - 1) // length * length
+
+ # Restore tokenizer.model_max_length
+ tokenizer.model_max_length = length
+
+ # Tokenize it again with fixed length.
+ input_ids = tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True
+ ).input_ids
+
+ # Reshape input_ids to fit the text encoder.
+ num_sentence = input_ids.shape[1] // length
+ input_ids = input_ids.reshape((num_sentence, length))
+
+ return input_ids
+
+
+
+class BasePrompter:
+ def __init__(self):
+ self.refiners = []
+ self.extenders = []
+
+
+ def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
+ for refiner_class in refiner_classes:
+ refiner = refiner_class.from_model_manager(model_manager)
+ self.refiners.append(refiner)
+
+ def load_prompt_extenders(self, model_manager: ModelManager, extender_classes=[]):
+ for extender_class in extender_classes:
+ extender = extender_class.from_model_manager(model_manager)
+ self.extenders.append(extender)
+
+
+ @torch.no_grad()
+ def process_prompt(self, prompt, positive=True):
+ if isinstance(prompt, list):
+ prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
+ else:
+ for refiner in self.refiners:
+ prompt = refiner(prompt, positive=positive)
+ return prompt
+
+ @torch.no_grad()
+ def extend_prompt(self, prompt:str, positive=True):
+ extended_prompt = dict(prompt=prompt)
+ for extender in self.extenders:
+ extended_prompt = extender(extended_prompt)
+ return extended_prompt
\ No newline at end of file
diff --git a/diffsynth/prompters/cog_prompter.py b/diffsynth/prompters/cog_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1ab84a69c32e681e087ba7ed0642a6177fe1f7a
--- /dev/null
+++ b/diffsynth/prompters/cog_prompter.py
@@ -0,0 +1,46 @@
+from .base_prompter import BasePrompter
+from ..models.flux_text_encoder import FluxTextEncoder2
+from transformers import T5TokenizerFast
+import os
+
+
+class CogPrompter(BasePrompter):
+ def __init__(
+ self,
+ tokenizer_path=None
+ ):
+ if tokenizer_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_path = os.path.join(base_path, "tokenizer_configs/cog/tokenizer")
+ super().__init__()
+ self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_path)
+ self.text_encoder: FluxTextEncoder2 = None
+
+
+ def fetch_models(self, text_encoder: FluxTextEncoder2 = None):
+ self.text_encoder = text_encoder
+
+
+ def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
+ input_ids = tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ ).input_ids.to(device)
+ prompt_emb = text_encoder(input_ids)
+ prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
+
+ return prompt_emb
+
+
+ def encode_prompt(
+ self,
+ prompt,
+ positive=True,
+ device="cuda"
+ ):
+ prompt = self.process_prompt(prompt, positive=positive)
+ prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder, self.tokenizer, 226, device)
+ return prompt_emb
diff --git a/diffsynth/prompters/flux_prompter.py b/diffsynth/prompters/flux_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3a06ff8df29345f505873cf1b79c963229f3efb
--- /dev/null
+++ b/diffsynth/prompters/flux_prompter.py
@@ -0,0 +1,74 @@
+from .base_prompter import BasePrompter
+from ..models.flux_text_encoder import FluxTextEncoder2
+from ..models.sd3_text_encoder import SD3TextEncoder1
+from transformers import CLIPTokenizer, T5TokenizerFast
+import os, torch
+
+
+class FluxPrompter(BasePrompter):
+ def __init__(
+ self,
+ tokenizer_1_path=None,
+ tokenizer_2_path=None
+ ):
+ if tokenizer_1_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_1")
+ if tokenizer_2_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_2")
+ super().__init__()
+ self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
+ self.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_path)
+ self.text_encoder_1: SD3TextEncoder1 = None
+ self.text_encoder_2: FluxTextEncoder2 = None
+
+
+ def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None):
+ self.text_encoder_1 = text_encoder_1
+ self.text_encoder_2 = text_encoder_2
+
+
+ def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
+ input_ids = tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True
+ ).input_ids.to(device)
+ pooled_prompt_emb, _ = text_encoder(input_ids)
+ return pooled_prompt_emb
+
+
+ def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
+ input_ids = tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ ).input_ids.to(device)
+ prompt_emb = text_encoder(input_ids)
+ return prompt_emb
+
+
+ def encode_prompt(
+ self,
+ prompt,
+ positive=True,
+ device="cuda",
+ t5_sequence_length=512,
+ ):
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ # CLIP
+ pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
+
+ # T5
+ prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
+
+ # text_ids
+ text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
+
+ return prompt_emb, pooled_prompt_emb, text_ids
diff --git a/diffsynth/prompters/hunyuan_dit_prompter.py b/diffsynth/prompters/hunyuan_dit_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..52a22ed72ab77ef668183119fff67db3141ee561
--- /dev/null
+++ b/diffsynth/prompters/hunyuan_dit_prompter.py
@@ -0,0 +1,69 @@
+from .base_prompter import BasePrompter
+from ..models.model_manager import ModelManager
+from ..models import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
+from transformers import BertTokenizer, AutoTokenizer
+import warnings, os
+
+
+class HunyuanDiTPrompter(BasePrompter):
+ def __init__(
+ self,
+ tokenizer_path=None,
+ tokenizer_t5_path=None
+ ):
+ if tokenizer_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
+ if tokenizer_t5_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_t5_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer_t5")
+ super().__init__()
+ self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path)
+ self.text_encoder: HunyuanDiTCLIPTextEncoder = None
+ self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
+
+
+ def fetch_models(self, text_encoder: HunyuanDiTCLIPTextEncoder = None, text_encoder_t5: HunyuanDiTT5TextEncoder = None):
+ self.text_encoder = text_encoder
+ self.text_encoder_t5 = text_encoder_t5
+
+
+    def encode_prompt_using_single_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ attention_mask = text_inputs.attention_mask.to(device)
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ clip_skip=clip_skip
+ )
+ return prompt_embeds, attention_mask
+
+
+ def encode_prompt(
+ self,
+ prompt,
+ clip_skip=1,
+ clip_skip_2=1,
+ positive=True,
+ device="cuda"
+ ):
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ # CLIP
+        prompt_emb, attention_mask = self.encode_prompt_using_single_model(prompt, self.text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
+
+ # T5
+        prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_single_model(prompt, self.text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
+
+ return prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5
diff --git a/diffsynth/prompters/hunyuan_video_prompter.py b/diffsynth/prompters/hunyuan_video_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b97356cacd4b9ccd9d0912b5694e1c1b4868ae9
--- /dev/null
+++ b/diffsynth/prompters/hunyuan_video_prompter.py
@@ -0,0 +1,275 @@
+from .base_prompter import BasePrompter
+from ..models.sd3_text_encoder import SD3TextEncoder1
+from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder
+from transformers import CLIPTokenizer, LlamaTokenizerFast, CLIPImageProcessor
+import os, torch
+from typing import Union
+
+PROMPT_TEMPLATE_ENCODE = (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
+ "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
+
+PROMPT_TEMPLATE_ENCODE_VIDEO = (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ "4. background environment, light, style and atmosphere."
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
+
+PROMPT_TEMPLATE_ENCODE_I2V = (
+ "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the image by detailing the color, shape, size, texture, "
+ "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+)
+
+PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
+ "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ "4. background environment, light, style and atmosphere."
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+)
+
+PROMPT_TEMPLATE = {
+ "dit-llm-encode": {
+ "template": PROMPT_TEMPLATE_ENCODE,
+ "crop_start": 36,
+ },
+ "dit-llm-encode-video": {
+ "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
+ "crop_start": 95,
+ },
+ "dit-llm-encode-i2v": {
+ "template": PROMPT_TEMPLATE_ENCODE_I2V,
+ "crop_start": 36,
+ "image_emb_start": 5,
+ "image_emb_end": 581,
+ "image_emb_len": 576,
+ "double_return_token_id": 271
+ },
+ "dit-llm-encode-video-i2v": {
+ "template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
+ "crop_start": 103,
+ "image_emb_start": 5,
+ "image_emb_end": 581,
+ "image_emb_len": 576,
+ "double_return_token_id": 271
+ },
+}
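+# Notes on the template metadata above:
+# - "crop_start": number of tokens occupied by the system/template prefix; the LLM hidden states and
+#   attention mask are cropped at this offset so only the user prompt (and, for I2V, the image tokens) remain.
+# - "image_emb_start"/"image_emb_end"/"image_emb_len": span of the image tokens injected by the multimodal
+#   encoder (576 tokens, presumably a 24x24 patch grid from the vision tower).
+# - "double_return_token_id": id of the double-newline token used to locate the assistant turn so it can be
+#   cropped out of the returned embeddings.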
+
+NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
+
+
+class HunyuanVideoPrompter(BasePrompter):
+
+ def __init__(
+ self,
+ tokenizer_1_path=None,
+ tokenizer_2_path=None,
+ ):
+ if tokenizer_1_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_1_path = os.path.join(
+ base_path, "tokenizer_configs/hunyuan_video/tokenizer_1")
+ if tokenizer_2_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_2_path = os.path.join(
+ base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
+ super().__init__()
+ self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
+ self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
+ self.text_encoder_1: SD3TextEncoder1 = None
+ self.text_encoder_2: HunyuanVideoLLMEncoder = None
+
+ self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
+ self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
+
+ def fetch_models(self,
+ text_encoder_1: SD3TextEncoder1 = None,
+ text_encoder_2: Union[HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder] = None):
+ self.text_encoder_1 = text_encoder_1
+ self.text_encoder_2 = text_encoder_2
+ if isinstance(text_encoder_2, HunyuanVideoMLLMEncoder):
+ # processor
+ # TODO: may need to replace processor with local implementation
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
+ self.processor = CLIPImageProcessor.from_pretrained(tokenizer_2_path)
+ # template
+ self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
+ self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']
+
+ def apply_text_to_template(self, text, template):
+ assert isinstance(template, str)
+ if isinstance(text, list):
+            return [self.apply_text_to_template(text_, template) for text_ in text]
+ elif isinstance(text, str):
+ # Will send string to tokenizer. Used for llm
+ return template.format(text)
+ else:
+ raise TypeError(f"Unsupported prompt type: {type(text)}")
+
+ def encode_prompt_using_clip(self, prompt, max_length, device):
+ tokenized_result = self.tokenizer_1(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True
+ )
+ input_ids = tokenized_result.input_ids.to(device)
+ attention_mask = tokenized_result.attention_mask.to(device)
+ return self.text_encoder_1(input_ids=input_ids, extra_mask=attention_mask)[0]
+
+ def encode_prompt_using_llm(self,
+ prompt,
+ max_length,
+ device,
+ crop_start,
+ hidden_state_skip_layer=2,
+ use_attention_mask=True):
+ max_length += crop_start
+ inputs = self.tokenizer_2(prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True)
+ input_ids = inputs.input_ids.to(device)
+ attention_mask = inputs.attention_mask.to(device)
+ last_hidden_state = self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer)
+
+ # crop out
+ if crop_start > 0:
+ last_hidden_state = last_hidden_state[:, crop_start:]
+ attention_mask = (attention_mask[:, crop_start:] if use_attention_mask else None)
+
+ return last_hidden_state, attention_mask
+
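+    # I2V path: the multimodal LLM encodes the templated prompt together with the reference image(s).
+    # Its output is split back into text hidden states (template and assistant segments cropped away) and
+    # image hidden states (subsampled every `image_embed_interleave` tokens), then concatenated image-first.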
+ def encode_prompt_using_mllm(self,
+ prompt,
+ images,
+ max_length,
+ device,
+ crop_start,
+ hidden_state_skip_layer=2,
+ use_attention_mask=True,
+ image_embed_interleave=4):
+ image_outputs = self.processor(images, return_tensors="pt")["pixel_values"].to(device)
+ max_length += crop_start
+ inputs = self.tokenizer_2(prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True)
+ input_ids = inputs.input_ids.to(device)
+ attention_mask = inputs.attention_mask.to(device)
+ last_hidden_state = self.text_encoder_2(input_ids=input_ids,
+ attention_mask=attention_mask,
+ hidden_state_skip_layer=hidden_state_skip_layer,
+ pixel_values=image_outputs)
+
+ text_crop_start = (crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576))
+ image_crop_start = self.prompt_template_video.get("image_emb_start", 5)
+ image_crop_end = self.prompt_template_video.get("image_emb_end", 581)
+ batch_indices, last_double_return_token_indices = torch.where(
+ input_ids == self.prompt_template_video.get("double_return_token_id", 271))
+ if last_double_return_token_indices.shape[0] == 3:
+ # in case the prompt is too long
+ last_double_return_token_indices = torch.cat((
+ last_double_return_token_indices,
+ torch.tensor([input_ids.shape[-1]]),
+ ))
+ batch_indices = torch.cat((batch_indices, torch.tensor([0])))
+ last_double_return_token_indices = (last_double_return_token_indices.reshape(input_ids.shape[0], -1)[:, -1])
+ batch_indices = batch_indices.reshape(input_ids.shape[0], -1)[:, -1]
+ assistant_crop_start = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) - 4)
+ assistant_crop_end = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576))
+ attention_mask_assistant_crop_start = (last_double_return_token_indices - 4)
+ attention_mask_assistant_crop_end = last_double_return_token_indices
+ text_last_hidden_state = []
+ text_attention_mask = []
+ image_last_hidden_state = []
+ image_attention_mask = []
+ for i in range(input_ids.shape[0]):
+ text_last_hidden_state.append(
+ torch.cat([
+ last_hidden_state[i, text_crop_start:assistant_crop_start[i].item()],
+ last_hidden_state[i, assistant_crop_end[i].item():],
+ ]))
+ text_attention_mask.append(
+ torch.cat([
+ attention_mask[
+ i,
+ crop_start:attention_mask_assistant_crop_start[i].item(),
+ ],
+ attention_mask[i, attention_mask_assistant_crop_end[i].item():],
+ ]) if use_attention_mask else None)
+ image_last_hidden_state.append(last_hidden_state[i, image_crop_start:image_crop_end])
+ image_attention_mask.append(
+ torch.ones(image_last_hidden_state[-1].shape[0]).to(last_hidden_state.device).
+ to(attention_mask.dtype) if use_attention_mask else None)
+
+ text_last_hidden_state = torch.stack(text_last_hidden_state)
+ text_attention_mask = torch.stack(text_attention_mask)
+ image_last_hidden_state = torch.stack(image_last_hidden_state)
+ image_attention_mask = torch.stack(image_attention_mask)
+
+ image_last_hidden_state = image_last_hidden_state[:, ::image_embed_interleave, :]
+ image_attention_mask = image_attention_mask[:, ::image_embed_interleave]
+
+ assert (text_last_hidden_state.shape[0] == text_attention_mask.shape[0] and
+ image_last_hidden_state.shape[0] == image_attention_mask.shape[0])
+
+ last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
+ attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
+
+ return last_hidden_state, attention_mask
+
+ def encode_prompt(self,
+ prompt,
+ images=None,
+ positive=True,
+ device="cuda",
+ clip_sequence_length=77,
+ llm_sequence_length=256,
+ data_type='video',
+ use_template=True,
+ hidden_state_skip_layer=2,
+ use_attention_mask=True,
+ image_embed_interleave=4):
+
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ # apply template
+ if use_template:
+ template = self.prompt_template_video if data_type == 'video' else self.prompt_template
+            prompt_formatted = self.apply_text_to_template(prompt, template['template'])
+        else:
+            prompt_formatted = prompt
+ # Text encoder
+ if data_type == 'video':
+ crop_start = self.prompt_template_video.get("crop_start", 0)
+ else:
+ crop_start = self.prompt_template.get("crop_start", 0)
+
+ # CLIP
+ pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
+
+ # LLM
+ if images is None:
+            prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formatted, llm_sequence_length, device, crop_start,
+ hidden_state_skip_layer, use_attention_mask)
+ else:
+            prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formatted, images, llm_sequence_length, device,
+ crop_start, hidden_state_skip_layer, use_attention_mask,
+ image_embed_interleave)
+
+ return prompt_emb, pooled_prompt_emb, attention_mask
diff --git a/diffsynth/prompters/kolors_prompter.py b/diffsynth/prompters/kolors_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3d5d58a9dbb816ea8c8e0e3b4f0433bd11d3306
--- /dev/null
+++ b/diffsynth/prompters/kolors_prompter.py
@@ -0,0 +1,354 @@
+from .base_prompter import BasePrompter
+from ..models.model_manager import ModelManager
+import json, os, re
+from typing import List, Optional, Union, Dict
+from sentencepiece import SentencePieceProcessor
+from transformers import PreTrainedTokenizer
+from transformers.utils import PaddingStrategy
+from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
+from ..models.kolors_text_encoder import ChatGLMModel
+
+
+class SPTokenizer:
+ def __init__(self, model_path: str):
+ # reload tokenizer
+ assert os.path.isfile(model_path), model_path
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
+
+ # BOS / EOS token IDs
+ self.n_words: int = self.sp_model.vocab_size()
+ self.bos_id: int = self.sp_model.bos_id()
+ self.eos_id: int = self.sp_model.eos_id()
+ self.pad_id: int = self.sp_model.unk_id()
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+ role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
+ special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
+ self.special_tokens = {}
+ self.index_special_tokens = {}
+ for token in special_tokens:
+ self.special_tokens[token] = self.n_words
+ self.index_special_tokens[self.n_words] = token
+ self.n_words += 1
+ self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
+
+ def tokenize(self, s: str, encode_special_tokens=False):
+ if encode_special_tokens:
+ last_index = 0
+ t = []
+ for match in re.finditer(self.role_special_token_expression, s):
+ if last_index < match.start():
+ t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
+ t.append(s[match.start():match.end()])
+ last_index = match.end()
+ if last_index < len(s):
+ t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
+ return t
+ else:
+ return self.sp_model.EncodeAsPieces(s)
+
+ def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
+ assert type(s) is str
+ t = self.sp_model.encode(s)
+ if bos:
+ t = [self.bos_id] + t
+ if eos:
+ t = t + [self.eos_id]
+ return t
+
+ def decode(self, t: List[int]) -> str:
+ text, buffer = "", []
+ for token in t:
+ if token in self.index_special_tokens:
+ if buffer:
+ text += self.sp_model.decode(buffer)
+ buffer = []
+ text += self.index_special_tokens[token]
+ else:
+ buffer.append(token)
+ if buffer:
+ text += self.sp_model.decode(buffer)
+ return text
+
+ def decode_tokens(self, tokens: List[str]) -> str:
+ text = self.sp_model.DecodePieces(tokens)
+ return text
+
+ def convert_token_to_id(self, token):
+ """ Converts a token (str) in an id using the vocab. """
+ if token in self.special_tokens:
+ return self.special_tokens[token]
+ return self.sp_model.PieceToId(token)
+
+ def convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ if index in self.index_special_tokens:
+ return self.index_special_tokens[index]
+ if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
+ return ""
+ return self.sp_model.IdToPiece(index)
+
+
+
+class ChatGLMTokenizer(PreTrainedTokenizer):
+ vocab_files_names = {"vocab_file": "tokenizer.model"}
+
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
+
+ def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
+ **kwargs):
+ self.name = "GLMTokenizer"
+
+ self.vocab_file = vocab_file
+ self.tokenizer = SPTokenizer(vocab_file)
+ self.special_tokens = {
+ "": self.tokenizer.bos_id,
+ "": self.tokenizer.eos_id,
+ "": self.tokenizer.pad_id
+ }
+ self.encode_special_tokens = encode_special_tokens
+ super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ encode_special_tokens=encode_special_tokens,
+ **kwargs)
+
+ def get_command(self, token):
+ if token in self.special_tokens:
+ return self.special_tokens[token]
+ assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
+ return self.tokenizer.special_tokens[token]
+
+    @property
+    def unk_token(self) -> str:
+        return "<unk>"
+
+    @property
+    def pad_token(self) -> str:
+        return "<unk>"
+
+    @property
+    def pad_token_id(self):
+        return self.get_command("<pad>")
+
+    @property
+    def eos_token(self) -> str:
+        return "</s>"
+
+    @property
+    def eos_token_id(self):
+        return self.get_command("<eos>")
+
+ @property
+ def vocab_size(self):
+ return self.tokenizer.n_words
+
+ def get_vocab(self):
+ """ Returns vocab as a dict """
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text, **kwargs):
+ return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
+
+ def _convert_token_to_id(self, token):
+ """ Converts a token (str) in an id using the vocab. """
+ return self.tokenizer.convert_token_to_id(token)
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.tokenizer.convert_id_to_token(index)
+
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
+ return self.tokenizer.decode_tokens(tokens)
+
+ def save_vocabulary(self, save_directory, filename_prefix=None):
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+ filename_prefix (`str`, *optional*):
+ An optional prefix to add to the named of the saved files.
+
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, self.vocab_files_names["vocab_file"]
+ )
+ else:
+ vocab_file = save_directory
+
+ with open(self.vocab_file, 'rb') as fin:
+ proto_str = fin.read()
+
+ with open(vocab_file, "wb") as writer:
+ writer.write(proto_str)
+
+ return (vocab_file,)
+
+ def get_prefix_tokens(self):
+ prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
+ return prefix_tokens
+
+ def build_single_message(self, role, metadata, message):
+ assert role in ["system", "user", "assistant", "observation"], role
+ role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
+ message_tokens = self.tokenizer.encode(message)
+ tokens = role_tokens + message_tokens
+ return tokens
+
+ def build_chat_input(self, query, history=None, role="user"):
+ if history is None:
+ history = []
+ input_ids = []
+ for item in history:
+ content = item["content"]
+ if item["role"] == "system" and "tools" in item:
+ content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
+ input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
+ input_ids.extend(self.build_single_message(role, "", query))
+ input_ids.extend([self.get_command("<|assistant|>")])
+ return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A BERT sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ prefix_tokens = self.get_prefix_tokens()
+ token_ids_0 = prefix_tokens + token_ids_0
+ if token_ids_1 is not None:
+            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
+ return token_ids_0
+
+ def _pad(
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ padding_side: Optional[str] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta).
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ assert self.padding_side == "left"
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+ seq_length = len(required_input)
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if "attention_mask" not in encoded_inputs:
+ encoded_inputs["attention_mask"] = [1] * seq_length
+
+ if "position_ids" not in encoded_inputs:
+ encoded_inputs["position_ids"] = list(range(seq_length))
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+
+ if "attention_mask" in encoded_inputs:
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+ if "position_ids" in encoded_inputs:
+ encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+
+ return encoded_inputs
+
+
+
+class KolorsPrompter(BasePrompter):
+ def __init__(
+ self,
+ tokenizer_path=None
+ ):
+ if tokenizer_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_path = os.path.join(base_path, "tokenizer_configs/kolors/tokenizer")
+ super().__init__()
+ self.tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
+ self.text_encoder: ChatGLMModel = None
+
+
+ def fetch_models(self, text_encoder: ChatGLMModel = None):
+ self.text_encoder = text_encoder
+
+
+ def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ ).to(device)
+ output = text_encoder(
+            input_ids=text_inputs['input_ids'],
+ attention_mask=text_inputs['attention_mask'],
+ position_ids=text_inputs['position_ids'],
+ output_hidden_states=True
+ )
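+        # ChatGLM returns hidden states as (seq_len, batch, hidden); permute to (batch, seq_len, hidden).
+        # The pooled embedding is the final layer's hidden state at the last token position.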
+ prompt_emb = output.hidden_states[-clip_skip].permute(1, 0, 2).clone()
+ pooled_prompt_emb = output.hidden_states[-1][-1, :, :].clone()
+ return prompt_emb, pooled_prompt_emb
+
+
+ def encode_prompt(
+ self,
+ prompt,
+ clip_skip=1,
+ clip_skip_2=2,
+ positive=True,
+ device="cuda"
+ ):
+ prompt = self.process_prompt(prompt, positive=positive)
+ prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(prompt, self.text_encoder, self.tokenizer, 256, clip_skip_2, device)
+
+ return pooled_prompt_emb, prompt_emb
diff --git a/diffsynth/prompters/omnigen_prompter.py b/diffsynth/prompters/omnigen_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..616efabebb7d327ecf968165dd12341ab8f83894
--- /dev/null
+++ b/diffsynth/prompters/omnigen_prompter.py
@@ -0,0 +1,356 @@
+import os
+import re
+from typing import Dict, List
+
+import torch
+from PIL import Image
+from torchvision import transforms
+from transformers import AutoTokenizer
+from huggingface_hub import snapshot_download
+import numpy as np
+
+
+
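+# Resize helper for OmniGen reference images: shrinks the image so the longer side fits within
+# max_image_size, enlarges tiny images so the shorter side is at least 16 px, then center-crops both
+# dimensions to multiples of 16 (presumably the latent/patch granularity expected by the pipeline).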
+def crop_arr(pil_image, max_image_size):
+ while min(*pil_image.size) >= 2 * max_image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ if max(*pil_image.size) > max_image_size:
+ scale = max_image_size / max(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ if min(*pil_image.size) < 16:
+ scale = 16 / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y1 = (arr.shape[0] % 16) // 2
+ crop_y2 = arr.shape[0] % 16 - crop_y1
+
+ crop_x1 = (arr.shape[1] % 16) // 2
+ crop_x2 = arr.shape[1] % 16 - crop_x1
+
+ arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
+ return Image.fromarray(arr)
+
+
+
+class OmniGenPrompter:
+ def __init__(self,
+ text_tokenizer,
+ max_image_size: int=1024):
+ self.text_tokenizer = text_tokenizer
+ self.max_image_size = max_image_size
+
+ self.image_transform = transforms.Compose([
+ transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+
+ self.collator = OmniGenCollator()
+ self.separate_collator = OmniGenSeparateCollator()
+
+ @classmethod
+ def from_pretrained(cls, model_name):
+ if not os.path.exists(model_name):
+ cache_folder = os.getenv('HF_HUB_CACHE')
+ model_name = snapshot_download(repo_id=model_name,
+ cache_dir=cache_folder,
+ allow_patterns="*.json")
+ text_tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ return cls(text_tokenizer)
+
+
+ def process_image(self, image):
+ return self.image_transform(image)
+
+ def process_multi_modal_prompt(self, text, input_images):
+ text = self.add_prefix_instruction(text)
+ if input_images is None or len(input_images) == 0:
+ model_inputs = self.text_tokenizer(text)
+ return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
+
+ pattern = r"<\|image_\d+\|>"
+ prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
+
+ for i in range(1, len(prompt_chunks)):
+ if prompt_chunks[i][0] == 1:
+ prompt_chunks[i] = prompt_chunks[i][1:]
+
+ image_tags = re.findall(pattern, text)
+ image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
+
+ unique_image_ids = sorted(list(set(image_ids)))
+ assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
+ # total images must be the same as the number of image tags
+ assert len(unique_image_ids) == len(input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
+
+ input_images = [input_images[x-1] for x in image_ids]
+
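+        # Interleave the text chunks with placeholder token ids: each image tag is expanded into
+        # (H // 16) * (W // 16) zero ids, and the [start, end) span of every placeholder block is recorded
+        # in img_inx so the image embeddings can be inserted at those positions downstream.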
+ all_input_ids = []
+ img_inx = []
+ idx = 0
+ for i in range(len(prompt_chunks)):
+ all_input_ids.extend(prompt_chunks[i])
+ if i != len(prompt_chunks) -1:
+ start_inx = len(all_input_ids)
+ size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
+ img_inx.append([start_inx, start_inx+size])
+ all_input_ids.extend([0]*size)
+
+ return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
+
+
+ def add_prefix_instruction(self, prompt):
+ user_prompt = '<|user|>\n'
+ generation_prompt = 'Generate an image according to the following instructions\n'
+ assistant_prompt = '<|assistant|>\n<|diffusion|>'
+ prompt_suffix = "<|end|>\n"
+ prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
+ return prompt
+
+
+ def __call__(self,
+ instructions: List[str],
+ input_images: List[List[str]] = None,
+ height: int = 1024,
+ width: int = 1024,
+ negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
+ use_img_cfg: bool = True,
+ separate_cfg_input: bool = False,
+ use_input_image_size_as_output: bool=False,
+ ) -> Dict:
+
+ if input_images is None:
+ use_img_cfg = False
+ if isinstance(instructions, str):
+ instructions = [instructions]
+ input_images = [input_images]
+
+ input_data = []
+ for i in range(len(instructions)):
+ cur_instruction = instructions[i]
+ cur_input_images = None if input_images is None else input_images[i]
+ if cur_input_images is not None and len(cur_input_images) > 0:
+ cur_input_images = [self.process_image(x) for x in cur_input_images]
+ else:
+ cur_input_images = None
+ assert "
<|image_1|>" not in cur_instruction
+
+ mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
+
+
+ neg_mllm_input, img_cfg_mllm_input = None, None
+ neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
+ if use_img_cfg:
+ if cur_input_images is not None and len(cur_input_images) >= 1:
+ img_cfg_prompt = [f"
<|image_{i+1}|>" for i in range(len(cur_input_images))]
+ img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
+ else:
+ img_cfg_mllm_input = neg_mllm_input
+
+ if use_input_image_size_as_output:
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [mllm_input['pixel_values'][0].size(-2), mllm_input['pixel_values'][0].size(-1)]))
+ else:
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
+
+ if separate_cfg_input:
+ return self.separate_collator(input_data)
+ return self.collator(input_data)
+
+
+
+
+class OmniGenCollator:
+ def __init__(self, pad_token_id=2, hidden_size=3072):
+ self.pad_token_id = pad_token_id
+ self.hidden_size = hidden_size
+
+ def create_position(self, attention_mask, num_tokens_for_output_images):
+ position_ids = []
+ text_length = attention_mask.size(-1)
+ img_length = max(num_tokens_for_output_images)
+ for mask in attention_mask:
+ temp_l = torch.sum(mask)
+ temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
+ position_ids.append(temp_position)
+ return torch.LongTensor(position_ids)
+
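+    # Attention-mask layout per sample (sequence = [padding | text + time token | output-image tokens]):
+    # causal attention over the text tokens plus the single time-embedding token, full bidirectional
+    # attention among the output-image tokens (which also attend to all text), padded positions hidden
+    # from the real tokens, and unused image slots beyond the true image length zeroed out.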
+ def create_mask(self, attention_mask, num_tokens_for_output_images):
+ extended_mask = []
+ padding_images = []
+ text_length = attention_mask.size(-1)
+ img_length = max(num_tokens_for_output_images)
+ seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
+ inx = 0
+ for mask in attention_mask:
+ temp_l = torch.sum(mask)
+ pad_l = text_length - temp_l
+
+ temp_mask = torch.tril(torch.ones(size=(temp_l+1, temp_l+1)))
+
+ image_mask = torch.zeros(size=(temp_l+1, img_length))
+ temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
+
+ image_mask = torch.ones(size=(img_length, temp_l+img_length+1))
+ temp_mask = torch.cat([temp_mask, image_mask], dim=0)
+
+ if pad_l > 0:
+ pad_mask = torch.zeros(size=(temp_l+1+img_length, pad_l))
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
+
+ pad_mask = torch.ones(size=(pad_l, seq_len))
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
+
+ true_img_length = num_tokens_for_output_images[inx]
+ pad_img_length = img_length - true_img_length
+ if pad_img_length > 0:
+ temp_mask[:, -pad_img_length:] = 0
+ temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
+ else:
+ temp_padding_imgs = None
+
+ extended_mask.append(temp_mask.unsqueeze(0))
+ padding_images.append(temp_padding_imgs)
+ inx += 1
+ return torch.cat(extended_mask, dim=0), padding_images
+
+ def adjust_attention_for_input_images(self, attention_mask, image_sizes):
+ for b_inx in image_sizes.keys():
+ for start_inx, end_inx in image_sizes[b_inx]:
+ attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
+
+ return attention_mask
+
+ def pad_input_ids(self, input_ids, image_sizes):
+ max_l = max([len(x) for x in input_ids])
+ padded_ids = []
+ attention_mask = []
+ new_image_sizes = []
+
+ for i in range(len(input_ids)):
+ temp_ids = input_ids[i]
+ temp_l = len(temp_ids)
+ pad_l = max_l - temp_l
+ if pad_l == 0:
+ attention_mask.append([1]*max_l)
+ padded_ids.append(temp_ids)
+ else:
+ attention_mask.append([0]*pad_l+[1]*temp_l)
+ padded_ids.append([self.pad_token_id]*pad_l+temp_ids)
+
+ if i in image_sizes:
+ new_inx = []
+ for old_inx in image_sizes[i]:
+ new_inx.append([x+pad_l for x in old_inx])
+ image_sizes[i] = new_inx
+
+ return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
+
+
+ def process_mllm_input(self, mllm_inputs, target_img_size):
+ num_tokens_for_output_images = []
+ for img_size in target_img_size:
+ num_tokens_for_output_images.append(img_size[0]*img_size[1]//16//16)
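+        # Each requested output image occupies (height // 16) * (width // 16) tokens in the sequence.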
+
+ pixel_values, image_sizes = [], {}
+ b_inx = 0
+ for x in mllm_inputs:
+ if x['pixel_values'] is not None:
+ pixel_values.extend(x['pixel_values'])
+ for size in x['image_sizes']:
+ if b_inx not in image_sizes:
+ image_sizes[b_inx] = [size]
+ else:
+ image_sizes[b_inx].append(size)
+ b_inx += 1
+ pixel_values = [x.unsqueeze(0) for x in pixel_values]
+
+
+ input_ids = [x['input_ids'] for x in mllm_inputs]
+ padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
+ position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
+ attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
+ attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
+
+ return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
+
+
+ def __call__(self, features):
+ mllm_inputs = [f[0] for f in features]
+ cfg_mllm_inputs = [f[1] for f in features]
+ img_cfg_mllm_input = [f[2] for f in features]
+ target_img_size = [f[3] for f in features]
+
+
+ if img_cfg_mllm_input[0] is not None:
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
+ target_img_size = target_img_size + target_img_size + target_img_size
+ else:
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs
+ target_img_size = target_img_size + target_img_size
+
+
+ all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
+
+ data = {"input_ids": all_padded_input_ids,
+ "attention_mask": all_attention_mask,
+ "position_ids": all_position_ids,
+ "input_pixel_values": all_pixel_values,
+ "input_image_sizes": all_image_sizes,
+ "padding_images": all_padding_images,
+ }
+ return data
+
+
+class OmniGenSeparateCollator(OmniGenCollator):
+ def __call__(self, features):
+ mllm_inputs = [f[0] for f in features]
+ cfg_mllm_inputs = [f[1] for f in features]
+ img_cfg_mllm_input = [f[2] for f in features]
+ target_img_size = [f[3] for f in features]
+
+ all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
+
+
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
+ all_padded_input_ids.append(padded_input_ids)
+ all_attention_mask.append(attention_mask)
+ all_position_ids.append(position_ids)
+ all_pixel_values.append(pixel_values)
+ all_image_sizes.append(image_sizes)
+ all_padding_images.append(padding_images)
+
+ if cfg_mllm_inputs[0] is not None:
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(cfg_mllm_inputs, target_img_size)
+ all_padded_input_ids.append(padded_input_ids)
+ all_attention_mask.append(attention_mask)
+ all_position_ids.append(position_ids)
+ all_pixel_values.append(pixel_values)
+ all_image_sizes.append(image_sizes)
+ all_padding_images.append(padding_images)
+ if img_cfg_mllm_input[0] is not None:
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(img_cfg_mllm_input, target_img_size)
+ all_padded_input_ids.append(padded_input_ids)
+ all_attention_mask.append(attention_mask)
+ all_position_ids.append(position_ids)
+ all_pixel_values.append(pixel_values)
+ all_image_sizes.append(image_sizes)
+ all_padding_images.append(padding_images)
+
+ data = {"input_ids": all_padded_input_ids,
+ "attention_mask": all_attention_mask,
+ "position_ids": all_position_ids,
+ "input_pixel_values": all_pixel_values,
+ "input_image_sizes": all_image_sizes,
+ "padding_images": all_padding_images,
+ }
+ return data
diff --git a/diffsynth/prompters/omost.py b/diffsynth/prompters/omost.py
new file mode 100644
index 0000000000000000000000000000000000000000..81828ad79978103eea42389d439847c0877cbd85
--- /dev/null
+++ b/diffsynth/prompters/omost.py
@@ -0,0 +1,323 @@
+from transformers import AutoTokenizer, TextIteratorStreamer
+import difflib
+import torch
+import numpy as np
+import re
+from ..models.model_manager import ModelManager
+from PIL import Image
+
+valid_colors = { # r, g, b
+ 'aliceblue': (240, 248, 255), 'antiquewhite': (250, 235, 215), 'aqua': (0, 255, 255),
+ 'aquamarine': (127, 255, 212), 'azure': (240, 255, 255), 'beige': (245, 245, 220),
+ 'bisque': (255, 228, 196), 'black': (0, 0, 0), 'blanchedalmond': (255, 235, 205), 'blue': (0, 0, 255),
+ 'blueviolet': (138, 43, 226), 'brown': (165, 42, 42), 'burlywood': (222, 184, 135),
+ 'cadetblue': (95, 158, 160), 'chartreuse': (127, 255, 0), 'chocolate': (210, 105, 30),
+ 'coral': (255, 127, 80), 'cornflowerblue': (100, 149, 237), 'cornsilk': (255, 248, 220),
+ 'crimson': (220, 20, 60), 'cyan': (0, 255, 255), 'darkblue': (0, 0, 139), 'darkcyan': (0, 139, 139),
+ 'darkgoldenrod': (184, 134, 11), 'darkgray': (169, 169, 169), 'darkgrey': (169, 169, 169),
+ 'darkgreen': (0, 100, 0), 'darkkhaki': (189, 183, 107), 'darkmagenta': (139, 0, 139),
+ 'darkolivegreen': (85, 107, 47), 'darkorange': (255, 140, 0), 'darkorchid': (153, 50, 204),
+ 'darkred': (139, 0, 0), 'darksalmon': (233, 150, 122), 'darkseagreen': (143, 188, 143),
+ 'darkslateblue': (72, 61, 139), 'darkslategray': (47, 79, 79), 'darkslategrey': (47, 79, 79),
+ 'darkturquoise': (0, 206, 209), 'darkviolet': (148, 0, 211), 'deeppink': (255, 20, 147),
+ 'deepskyblue': (0, 191, 255), 'dimgray': (105, 105, 105), 'dimgrey': (105, 105, 105),
+ 'dodgerblue': (30, 144, 255), 'firebrick': (178, 34, 34), 'floralwhite': (255, 250, 240),
+ 'forestgreen': (34, 139, 34), 'fuchsia': (255, 0, 255), 'gainsboro': (220, 220, 220),
+ 'ghostwhite': (248, 248, 255), 'gold': (255, 215, 0), 'goldenrod': (218, 165, 32),
+ 'gray': (128, 128, 128), 'grey': (128, 128, 128), 'green': (0, 128, 0), 'greenyellow': (173, 255, 47),
+ 'honeydew': (240, 255, 240), 'hotpink': (255, 105, 180), 'indianred': (205, 92, 92),
+ 'indigo': (75, 0, 130), 'ivory': (255, 255, 240), 'khaki': (240, 230, 140), 'lavender': (230, 230, 250),
+ 'lavenderblush': (255, 240, 245), 'lawngreen': (124, 252, 0), 'lemonchiffon': (255, 250, 205),
+ 'lightblue': (173, 216, 230), 'lightcoral': (240, 128, 128), 'lightcyan': (224, 255, 255),
+ 'lightgoldenrodyellow': (250, 250, 210), 'lightgray': (211, 211, 211), 'lightgrey': (211, 211, 211),
+ 'lightgreen': (144, 238, 144), 'lightpink': (255, 182, 193), 'lightsalmon': (255, 160, 122),
+ 'lightseagreen': (32, 178, 170), 'lightskyblue': (135, 206, 250), 'lightslategray': (119, 136, 153),
+ 'lightslategrey': (119, 136, 153), 'lightsteelblue': (176, 196, 222), 'lightyellow': (255, 255, 224),
+ 'lime': (0, 255, 0), 'limegreen': (50, 205, 50), 'linen': (250, 240, 230), 'magenta': (255, 0, 255),
+ 'maroon': (128, 0, 0), 'mediumaquamarine': (102, 205, 170), 'mediumblue': (0, 0, 205),
+ 'mediumorchid': (186, 85, 211), 'mediumpurple': (147, 112, 219), 'mediumseagreen': (60, 179, 113),
+ 'mediumslateblue': (123, 104, 238), 'mediumspringgreen': (0, 250, 154),
+ 'mediumturquoise': (72, 209, 204), 'mediumvioletred': (199, 21, 133), 'midnightblue': (25, 25, 112),
+ 'mintcream': (245, 255, 250), 'mistyrose': (255, 228, 225), 'moccasin': (255, 228, 181),
+ 'navajowhite': (255, 222, 173), 'navy': (0, 0, 128), 'navyblue': (0, 0, 128),
+ 'oldlace': (253, 245, 230), 'olive': (128, 128, 0), 'olivedrab': (107, 142, 35),
+ 'orange': (255, 165, 0), 'orangered': (255, 69, 0), 'orchid': (218, 112, 214),
+ 'palegoldenrod': (238, 232, 170), 'palegreen': (152, 251, 152), 'paleturquoise': (175, 238, 238),
+ 'palevioletred': (219, 112, 147), 'papayawhip': (255, 239, 213), 'peachpuff': (255, 218, 185),
+ 'peru': (205, 133, 63), 'pink': (255, 192, 203), 'plum': (221, 160, 221), 'powderblue': (176, 224, 230),
+ 'purple': (128, 0, 128), 'rebeccapurple': (102, 51, 153), 'red': (255, 0, 0),
+ 'rosybrown': (188, 143, 143), 'royalblue': (65, 105, 225), 'saddlebrown': (139, 69, 19),
+ 'salmon': (250, 128, 114), 'sandybrown': (244, 164, 96), 'seagreen': (46, 139, 87),
+ 'seashell': (255, 245, 238), 'sienna': (160, 82, 45), 'silver': (192, 192, 192),
+ 'skyblue': (135, 206, 235), 'slateblue': (106, 90, 205), 'slategray': (112, 128, 144),
+ 'slategrey': (112, 128, 144), 'snow': (255, 250, 250), 'springgreen': (0, 255, 127),
+ 'steelblue': (70, 130, 180), 'tan': (210, 180, 140), 'teal': (0, 128, 128), 'thistle': (216, 191, 216),
+ 'tomato': (255, 99, 71), 'turquoise': (64, 224, 208), 'violet': (238, 130, 238),
+ 'wheat': (245, 222, 179), 'white': (255, 255, 255), 'whitesmoke': (245, 245, 245),
+ 'yellow': (255, 255, 0), 'yellowgreen': (154, 205, 50)
+}
+
+valid_locations = { # x, y in 90*90
+ 'in the center': (45, 45),
+ 'on the left': (15, 45),
+ 'on the right': (75, 45),
+ 'on the top': (45, 15),
+ 'on the bottom': (45, 75),
+ 'on the top-left': (15, 15),
+ 'on the top-right': (75, 15),
+ 'on the bottom-left': (15, 75),
+ 'on the bottom-right': (75, 75)
+}
+
+valid_offsets = { # x, y in 90*90
+ 'no offset': (0, 0),
+ 'slightly to the left': (-10, 0),
+ 'slightly to the right': (10, 0),
+ 'slightly to the upper': (0, -10),
+ 'slightly to the lower': (0, 10),
+ 'slightly to the upper-left': (-10, -10),
+ 'slightly to the upper-right': (10, -10),
+ 'slightly to the lower-left': (-10, 10),
+ 'slightly to the lower-right': (10, 10)}
+
+valid_areas = { # w, h in 90*90
+ "a small square area": (50, 50),
+ "a small vertical area": (40, 60),
+ "a small horizontal area": (60, 40),
+ "a medium-sized square area": (60, 60),
+ "a medium-sized vertical area": (50, 80),
+ "a medium-sized horizontal area": (80, 50),
+ "a large square area": (70, 70),
+ "a large vertical area": (60, 90),
+ "a large horizontal area": (90, 60)
+}
+
+def safe_str(x):
+ return x.strip(',. ') + '.'
+
+def closest_name(input_str, options):
+ input_str = input_str.lower()
+
+ closest_match = difflib.get_close_matches(input_str, list(options.keys()), n=1, cutoff=0.5)
+ assert isinstance(closest_match, list) and len(closest_match) > 0, f'The value [{input_str}] is not valid!'
+ result = closest_match[0]
+
+ if result != input_str:
+ print(f'Automatically corrected [{input_str}] -> [{result}].')
+
+ return result
+
+class Canvas:
+ @staticmethod
+ def from_bot_response(response: str):
+
+ matched = re.search(r'```python\n(.*?)\n```', response, re.DOTALL)
+ assert matched, 'Response does not contain codes!'
+ code_content = matched.group(1)
+ assert 'canvas = Canvas()' in code_content, 'Code block must include valid canvas var!'
+ local_vars = {'Canvas': Canvas}
+ exec(code_content, {}, local_vars)
+ canvas = local_vars.get('canvas', None)
+ assert isinstance(canvas, Canvas), 'Code block must produce valid canvas var!'
+ return canvas
+
+ def __init__(self):
+ self.components = []
+ self.color = None
+ self.record_tags = True
+ self.prefixes = []
+ self.suffixes = []
+ return
+
+ def set_global_description(self, description: str, detailed_descriptions: list, tags: str,
+ HTML_web_color_name: str):
+ assert isinstance(description, str), 'Global description is not valid!'
+ assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
+ 'Global detailed_descriptions is not valid!'
+ assert isinstance(tags, str), 'Global tags is not valid!'
+
+ HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
+ self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
+
+ self.prefixes = [description]
+ self.suffixes = detailed_descriptions
+
+ if self.record_tags:
+ self.suffixes = self.suffixes + [tags]
+
+ self.prefixes = [safe_str(x) for x in self.prefixes]
+ self.suffixes = [safe_str(x) for x in self.suffixes]
+
+ return
+
+ def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
+ detailed_descriptions: list, tags: str, atmosphere: str, style: str,
+ quality_meta: str, HTML_web_color_name: str):
+ assert isinstance(description, str), 'Local description is wrong!'
+ assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
+ f'The distance_to_viewer for [{description}] is not positive float number!'
+ assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
+ f'The detailed_descriptions for [{description}] is not valid!'
+ assert isinstance(tags, str), f'The tags for [{description}] is not valid!'
+ assert isinstance(atmosphere, str), f'The atmosphere for [{description}] is not valid!'
+ assert isinstance(style, str), f'The style for [{description}] is not valid!'
+ assert isinstance(quality_meta, str), f'The quality_meta for [{description}] is not valid!'
+
+ location = closest_name(location, valid_locations)
+ offset = closest_name(offset, valid_offsets)
+ area = closest_name(area, valid_areas)
+ HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
+
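+        # Map the named location/offset/area onto a (top, bottom, left, right) box in the 90x90 canvas grid,
+        # clamped to the canvas bounds.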
+ xb, yb = valid_locations[location]
+ xo, yo = valid_offsets[offset]
+ w, h = valid_areas[area]
+ rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2)
+ rect = [max(0, min(90, i)) for i in rect]
+ color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
+
+ prefixes = self.prefixes + [description]
+ suffixes = detailed_descriptions
+
+ if self.record_tags:
+ suffixes = suffixes + [tags, atmosphere, style, quality_meta]
+
+ prefixes = [safe_str(x) for x in prefixes]
+ suffixes = [safe_str(x) for x in suffixes]
+
+ self.components.append(dict(
+ rect=rect,
+ distance_to_viewer=distance_to_viewer,
+ color=color,
+ prefixes=prefixes,
+ suffixes=suffixes,
+ location=location,
+ ))
+
+ return
+
+ def process(self):
+ # sort components
+ self.components = sorted(self.components, key=lambda x: x['distance_to_viewer'], reverse=True)
+
+ # compute initial latent
+ # print(self.color)
+ initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color
+
+ for component in self.components:
+ a, b, c, d = component['rect']
+ initial_latent[a:b, c:d] = 0.7 * component['color'] + 0.3 * initial_latent[a:b, c:d]
+
+ initial_latent = initial_latent.clip(0, 255).astype(np.uint8)
+
+ # compute conditions
+
+ bag_of_conditions = [
+            dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes, location="full")
+ ]
+
+ for i, component in enumerate(self.components):
+ a, b, c, d = component['rect']
+ m = np.zeros(shape=(90, 90), dtype=np.float32)
+ m[a:b, c:d] = 1.0
+ bag_of_conditions.append(dict(
+ mask = m,
+ prefixes = component['prefixes'],
+ suffixes = component['suffixes'],
+ location = component['location'],
+ ))
+
+ return dict(
+ initial_latent = initial_latent,
+ bag_of_conditions = bag_of_conditions,
+ )
+
+
+class OmostPromter(torch.nn.Module):
+
+    def __init__(self, model=None, tokenizer=None, template="", device="cpu"):
+        super().__init__()
+        self.model = model
+ self.tokenizer = tokenizer
+ self.device = device
+ if template == "":
+ template = r'''You are a helpful AI assistant to compose images using the below python class `Canvas`:
+ ```python
+ class Canvas:
+ def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
+ pass
+
+ def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
+ assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
+ assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
+ assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
+ assert distance_to_viewer > 0
+ pass
+ ```'''
+ self.template = template
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager):
+ model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True)
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ omost = OmostPromter(
+            model=model,
+            tokenizer=tokenizer,
+            device=model_manager.device
+ )
+ return omost
+
+
+    def __call__(self, prompt_dict: dict):
+        raw_prompt = prompt_dict["prompt"]
+ conversation = [{"role": "system", "content": self.template}]
+ conversation.append({"role": "user", "content": raw_prompt})
+
+ input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True).to(self.device)
+ streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.bfloat16, device=self.device)
+
+ generate_kwargs = dict(
+ input_ids = input_ids,
+ streamer = streamer,
+ # stopping_criteria=stopping_criteria,
+ # max_new_tokens=max_new_tokens,
+ do_sample = True,
+ attention_mask = attention_mask,
+ pad_token_id = self.tokenizer.eos_token_id,
+ # temperature=temperature,
+ # top_p=top_p,
+ )
+ self.model.generate(**generate_kwargs)
+ outputs = []
+ for text in streamer:
+ outputs.append(text)
+ llm_outputs = "".join(outputs)
+
+ canvas = Canvas.from_bot_response(llm_outputs)
+ canvas_output = canvas.process()
+
+ prompts = [" ".join(_["prefixes"]+_["suffixes"][:2]) for _ in canvas_output["bag_of_conditions"]]
+ canvas_output["prompt"] = prompts[0]
+ canvas_output["prompts"] = prompts[1:]
+
+ raw_masks = [_["mask"] for _ in canvas_output["bag_of_conditions"]]
+        masks = []
+        for mask in raw_masks:
+            mask[mask > 0.5] = 255
+ mask = np.stack([mask] * 3, axis=-1).astype("uint8")
+ masks.append(Image.fromarray(mask))
+
+ canvas_output["masks"] = masks
+ prompt_dict.update(canvas_output)
+ print(f"Your prompt is extended by Omost:\n")
+ cnt = 0
+ for component,pmt in zip(canvas_output["bag_of_conditions"],prompts):
+ loc = component["location"]
+ cnt += 1
+ print(f"Component {cnt} - Location : {loc}\nPrompt:{pmt}\n")
+
+ return prompt_dict
+
+
+
+
\ No newline at end of file
diff --git a/diffsynth/prompters/prompt_refiners.py b/diffsynth/prompters/prompt_refiners.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ac19f565b076cccb21d9e05149b604e4bb55854
--- /dev/null
+++ b/diffsynth/prompters/prompt_refiners.py
@@ -0,0 +1,130 @@
+from transformers import AutoTokenizer
+from ..models.model_manager import ModelManager
+import torch
+from .omost import OmostPromter
+
+class BeautifulPrompt(torch.nn.Module):
+ def __init__(self, tokenizer_path=None, model=None, template=""):
+ super().__init__()
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ self.model = model
+ self.template = template
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager):
+ model, model_path = model_manager.fetch_model("beautiful_prompt", require_model_path=True)
+ template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
+ if model_path.endswith("v2"):
+ template = """Converts a simple image description into a prompt. \
+Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
+or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
+but make sure there is a correlation between the input and output.\n\
+### Input: {raw_prompt}\n### Output:"""
+ beautiful_prompt = BeautifulPrompt(
+ tokenizer_path=model_path,
+ model=model,
+ template=template
+ )
+ return beautiful_prompt
+
+
+ def __call__(self, raw_prompt, positive=True, **kwargs):
+ if positive:
+ model_input = self.template.format(raw_prompt=raw_prompt)
+ input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
+ outputs = self.model.generate(
+ input_ids,
+ max_new_tokens=384,
+ do_sample=True,
+ temperature=0.9,
+ top_k=50,
+ top_p=0.95,
+ repetition_penalty=1.1,
+ num_return_sequences=1
+ )
+ prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
+ outputs[:, input_ids.size(1):],
+ skip_special_tokens=True
+ )[0].strip()
+ print(f"Your prompt is refined by BeautifulPrompt: {prompt}")
+ return prompt
+ else:
+ return raw_prompt
+
+
+
+class QwenPrompt(torch.nn.Module):
+ # This class leverages the open-source Qwen model to translate Chinese prompts into English,
+ # with an integrated optimization mechanism for enhanced translation quality.
+ def __init__(self, tokenizer_path=None, model=None, system_prompt=""):
+ super().__init__()
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ self.model = model
+ self.system_prompt = system_prompt
+
+
+ @staticmethod
+    def from_model_manager(model_manager: ModelManager):
+        model, model_path = model_manager.fetch_model("qwen_prompt", require_model_path=True)
+ system_prompt = """You are an English image describer. Here are some example image styles:\n\n1. Extreme close-up: Clear focus on a single object with a blurred background, highlighted under natural sunlight.\n2. Vintage: A photograph of a historical scene, using techniques such as Daguerreotype or cyanotype.\n3. Anime: A stylized cartoon image, emphasizing hyper-realistic portraits and luminous brushwork.\n4. Candid: A natural, unposed shot capturing spontaneous moments, often with cinematic qualities.\n5. Landscape: A photorealistic image of natural scenery, such as a sunrise over the sea.\n6. Design: Colorful and detailed illustrations, often in the style of 2D game art or botanical illustrations.\n7. Urban: An ultrarealistic scene in a modern setting, possibly a cityscape viewed from indoors.\n\nYour task is to translate a given Chinese image description into a concise and precise English description. Ensure that the imagery is vivid and descriptive, and include stylistic elements to enrich the description.\nPlease note the following points:\n\n1. Capture the essence and mood of the Chinese description without including direct phrases or words from the examples provided.\n2. You should add appropriate words to make the images described in the prompt more aesthetically pleasing. If the Chinese description does not specify a style, you need to add some stylistic descriptions based on the essence of the Chinese text.\n3. The generated English description should not exceed 200 words.\n\n"""
+ qwen_prompt = QwenPrompt(
+ tokenizer_path=model_path,
+ model=model,
+ system_prompt=system_prompt
+ )
+ return qwen_prompt
+
+
+ def __call__(self, raw_prompt, positive=True, **kwargs):
+ if positive:
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt
+ }, {
+ 'role': 'user',
+ 'content': raw_prompt
+ }]
+ text = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True
+ )
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
+
+ generated_ids = self.model.generate(
+ model_inputs.input_ids,
+ max_new_tokens=512
+ )
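+ # Keep only the newly generated tokens, dropping the chat-template prompt tokens before decoding.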
+ generated_ids = [
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+ ]
+
+ prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ print(f"Your prompt is refined by Qwen: {prompt}")
+ return prompt
+ else:
+ return raw_prompt
+
+
+
+class Translator(torch.nn.Module):
+ def __init__(self, tokenizer_path=None, model=None):
+ super().__init__()
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ self.model = model
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager):
+ model, model_path = model_manager.fetch_model("translator", require_model_path=True)
+ translator = Translator(tokenizer_path=model_path, model=model)
+ return translator
+
+
+ def __call__(self, prompt, **kwargs):
+ input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
+ output_ids = self.model.generate(input_ids)
+ prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
+ print(f"Your prompt is translated: {prompt}")
+ return prompt
diff --git a/diffsynth/prompters/sd3_prompter.py b/diffsynth/prompters/sd3_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecf9bca30ae53e78822d06d769a65a6c79e8b5d8
--- /dev/null
+++ b/diffsynth/prompters/sd3_prompter.py
@@ -0,0 +1,93 @@
+from .base_prompter import BasePrompter
+from ..models.model_manager import ModelManager
+from ..models import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
+from transformers import CLIPTokenizer, T5TokenizerFast
+import os, torch
+
+
+class SD3Prompter(BasePrompter):
+ def __init__(
+ self,
+ tokenizer_1_path=None,
+ tokenizer_2_path=None,
+ tokenizer_3_path=None
+ ):
+ if tokenizer_1_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_1")
+ if tokenizer_2_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_2")
+ if tokenizer_3_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
+ super().__init__()
+ self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
+ self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
+ self.text_encoder_1: SD3TextEncoder1 = None
+ self.text_encoder_2: SD3TextEncoder2 = None
+ self.text_encoder_3: SD3TextEncoder3 = None
+
+
+ def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: SD3TextEncoder2 = None, text_encoder_3: SD3TextEncoder3 = None):
+ self.text_encoder_1 = text_encoder_1
+ self.text_encoder_2 = text_encoder_2
+ self.text_encoder_3 = text_encoder_3
+
+
+ def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
+ input_ids = tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True
+ ).input_ids.to(device)
+ pooled_prompt_emb, prompt_emb = text_encoder(input_ids)
+ return pooled_prompt_emb, prompt_emb
+
+
+ def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
+ input_ids = tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ ).input_ids.to(device)
+ prompt_emb = text_encoder(input_ids)
+ prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
+
+ return prompt_emb
+
+
+ def encode_prompt(
+ self,
+ prompt,
+ positive=True,
+ device="cuda",
+ t5_sequence_length=77,
+ ):
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ # CLIP
+ pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
+ pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, self.text_encoder_2, self.tokenizer_2, 77, device)
+
+ # T5
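+ # If no T5 text encoder is loaded, fall back to a zero tensor so the prompt embedding keeps a fixed width of 4096.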
+ if self.text_encoder_3 is None:
+ prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], t5_sequence_length, 4096), dtype=prompt_emb_1.dtype, device=device)
+ else:
+ prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, t5_sequence_length, device)
+ prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
+
+ # Merge
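+ # Concatenate the two CLIP embeddings along the channel dimension, zero-pad them from 768 + 1280
+ # up to the T5 width of 4096, then stack them with the T5 embedding along the sequence dimension.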
+ prompt_emb = torch.cat([
+ torch.nn.functional.pad(torch.cat([prompt_emb_1, prompt_emb_2], dim=-1), (0, 4096 - 768 - 1280)),
+ prompt_emb_3
+ ], dim=-2)
+ pooled_prompt_emb = torch.cat([pooled_prompt_emb_1, pooled_prompt_emb_2], dim=-1)
+
+ return prompt_emb, pooled_prompt_emb
diff --git a/diffsynth/prompters/sd_prompter.py b/diffsynth/prompters/sd_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3b31ea2836b3b02edab37d7f610c13f2cf6cead
--- /dev/null
+++ b/diffsynth/prompters/sd_prompter.py
@@ -0,0 +1,73 @@
+from .base_prompter import BasePrompter, tokenize_long_prompt
+from ..models.utils import load_state_dict, search_for_embeddings
+from ..models import SDTextEncoder
+from transformers import CLIPTokenizer
+import torch, os
+
+
+
+class SDPrompter(BasePrompter):
+ def __init__(self, tokenizer_path=None):
+ if tokenizer_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
+ self.text_encoder: SDTextEncoder = None
+ self.textual_inversion_dict = {}
+ self.keyword_dict = {}
+
+
+ def fetch_models(self, text_encoder: SDTextEncoder = None):
+ self.text_encoder = text_encoder
+
+
+ def add_textual_inversions_to_model(self, textual_inversion_dict, text_encoder):
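+ # Append the textual inversion embeddings to the token embedding table and rebuild the embedding layer with the enlarged vocabulary.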
+ dtype = next(iter(text_encoder.parameters())).dtype
+ state_dict = text_encoder.token_embedding.state_dict()
+ token_embeddings = [state_dict["weight"]]
+ for keyword in textual_inversion_dict:
+ _, embeddings = textual_inversion_dict[keyword]
+ token_embeddings.append(embeddings.to(dtype=dtype, device=token_embeddings[0].device))
+ token_embeddings = torch.concat(token_embeddings, dim=0)
+ state_dict["weight"] = token_embeddings
+ text_encoder.token_embedding = torch.nn.Embedding(token_embeddings.shape[0], token_embeddings.shape[1])
+ text_encoder.token_embedding = text_encoder.token_embedding.to(dtype=dtype, device=token_embeddings[0].device)
+ text_encoder.token_embedding.load_state_dict(state_dict)
+
+
+ def add_textual_inversions_to_tokenizer(self, textual_inversion_dict, tokenizer):
+ additional_tokens = []
+ for keyword in textual_inversion_dict:
+ tokens, _ = textual_inversion_dict[keyword]
+ additional_tokens += tokens
+ self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
+ tokenizer.add_tokens(additional_tokens)
+
+
+ def load_textual_inversions(self, model_paths):
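+ # The file name (without extension) is used as the trigger keyword; each embedding row becomes a pseudo-token named "<keyword>_<i>".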
+ for model_path in model_paths:
+ keyword = os.path.splitext(os.path.split(model_path)[-1])[0]
+ state_dict = load_state_dict(model_path)
+
+ # Search for embeddings
+ for embeddings in search_for_embeddings(state_dict):
+ if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
+ tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
+ self.textual_inversion_dict[keyword] = (tokens, embeddings)
+
+ self.add_textual_inversions_to_model(self.textual_inversion_dict, self.text_encoder)
+ self.add_textual_inversions_to_tokenizer(self.textual_inversion_dict, self.tokenizer)
+
+
+ def encode_prompt(self, prompt, clip_skip=1, device="cuda", positive=True):
+ prompt = self.process_prompt(prompt, positive=positive)
+ for keyword in self.keyword_dict:
+ if keyword in prompt:
+ print(f"Textual inversion {keyword} is enabled.")
+ prompt = prompt.replace(keyword, self.keyword_dict[keyword])
+ input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
+ prompt_emb = self.text_encoder(input_ids, clip_skip=clip_skip)
+ prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
+
+ return prompt_emb
\ No newline at end of file
diff --git a/diffsynth/prompters/sdxl_prompter.py b/diffsynth/prompters/sdxl_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d84145402538b89b23d39a98271cbad64c2d9fc3
--- /dev/null
+++ b/diffsynth/prompters/sdxl_prompter.py
@@ -0,0 +1,61 @@
+from .base_prompter import BasePrompter, tokenize_long_prompt
+from ..models.model_manager import ModelManager
+from ..models import SDXLTextEncoder, SDXLTextEncoder2
+from transformers import CLIPTokenizer
+import torch, os
+
+
+
+class SDXLPrompter(BasePrompter):
+ def __init__(
+ self,
+ tokenizer_path=None,
+ tokenizer_2_path=None
+ ):
+ if tokenizer_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
+ if tokenizer_2_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_xl/tokenizer_2")
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
+ self.text_encoder: SDXLTextEncoder = None
+ self.text_encoder_2: SDXLTextEncoder2 = None
+
+
+ def fetch_models(self, text_encoder: SDXLTextEncoder = None, text_encoder_2: SDXLTextEncoder2 = None):
+ self.text_encoder = text_encoder
+ self.text_encoder_2 = text_encoder_2
+
+
+ def encode_prompt(
+ self,
+ prompt,
+ clip_skip=1,
+ clip_skip_2=2,
+ positive=True,
+ device="cuda"
+ ):
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ # 1
+ input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
+ prompt_emb_1 = self.text_encoder(input_ids, clip_skip=clip_skip)
+
+ # 2
+ input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
+ add_text_embeds, prompt_emb_2 = self.text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
+
+ # Merge
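+ # Align the chunk counts from the two tokenizers, concatenate their features along the channel
+ # dimension, and flatten the chunks of a long prompt into a single sequence.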
+ if prompt_emb_1.shape[0] != prompt_emb_2.shape[0]:
+ max_batch_size = min(prompt_emb_1.shape[0], prompt_emb_2.shape[0])
+ prompt_emb_1 = prompt_emb_1[: max_batch_size]
+ prompt_emb_2 = prompt_emb_2[: max_batch_size]
+ prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
+
+ # For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
+ add_text_embeds = add_text_embeds[0:1]
+ prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
+ return add_text_embeds, prompt_emb
diff --git a/diffsynth/prompters/stepvideo_prompter.py b/diffsynth/prompters/stepvideo_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..79d374b1f8a4be2a2298520fcbf87800e0ca91d9
--- /dev/null
+++ b/diffsynth/prompters/stepvideo_prompter.py
@@ -0,0 +1,56 @@
+from .base_prompter import BasePrompter
+from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder
+from ..models.stepvideo_text_encoder import STEP1TextEncoder
+from transformers import BertTokenizer
+import os, torch
+
+
+class StepVideoPrompter(BasePrompter):
+
+ def __init__(
+ self,
+ tokenizer_1_path=None,
+ ):
+ if tokenizer_1_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_1_path = os.path.join(
+ base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
+ super().__init__()
+ self.tokenizer_1 = BertTokenizer.from_pretrained(tokenizer_1_path)
+
+ def fetch_models(self, text_encoder_1: HunyuanDiTCLIPTextEncoder = None, text_encoder_2: STEP1TextEncoder = None):
+ self.text_encoder_1 = text_encoder_1
+ self.text_encoder_2 = text_encoder_2
+
+ def encode_prompt_using_clip(self, prompt, max_length, device):
+ text_inputs = self.tokenizer_1(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ prompt_embeds = self.text_encoder_1(
+ text_inputs.input_ids.to(device),
+ attention_mask=text_inputs.attention_mask.to(device),
+ )
+ return prompt_embeds
+
+ def encode_prompt_using_llm(self, prompt, max_length, device):
+ y, y_mask = self.text_encoder_2(prompt, max_length=max_length, device=device)
+ return y, y_mask
+
+ def encode_prompt(self,
+ prompt,
+ positive=True,
+ device="cuda"):
+
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ clip_embeds = self.encode_prompt_using_clip(prompt, max_length=77, device=device)
+ llm_embeds, llm_mask = self.encode_prompt_using_llm(prompt, max_length=320, device=device)
+
+ llm_mask = torch.nn.functional.pad(llm_mask, (clip_embeds.shape[1], 0), value=1)
+
+ return clip_embeds, llm_embeds, llm_mask
diff --git a/diffsynth/prompters/wan_prompter.py b/diffsynth/prompters/wan_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..01a765d3cb3bf2ee4d06553fd061ed7dd75443b2
--- /dev/null
+++ b/diffsynth/prompters/wan_prompter.py
@@ -0,0 +1,109 @@
+from .base_prompter import BasePrompter
+from ..models.wan_video_text_encoder import WanTextEncoder
+from transformers import AutoTokenizer
+import os, torch
+import ftfy
+import html
+import string
+import regex as re
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+def canonicalize(text, keep_punctuation_exact_string=None):
+ text = text.replace('_', ' ')
+ if keep_punctuation_exact_string:
+ text = keep_punctuation_exact_string.join(
+ part.translate(str.maketrans('', '', string.punctuation))
+ for part in text.split(keep_punctuation_exact_string))
+ else:
+ text = text.translate(str.maketrans('', '', string.punctuation))
+ text = text.lower()
+ text = re.sub(r'\s+', ' ', text)
+ return text.strip()
+
+
+class HuggingfaceTokenizer:
+
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
+ self.name = name
+ self.seq_len = seq_len
+ self.clean = clean
+
+ # init tokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
+ self.vocab_size = self.tokenizer.vocab_size
+
+ def __call__(self, sequence, **kwargs):
+ return_mask = kwargs.pop('return_mask', False)
+
+ # arguments
+ _kwargs = {'return_tensors': 'pt'}
+ if self.seq_len is not None:
+ _kwargs.update({
+ 'padding': 'max_length',
+ 'truncation': True,
+ 'max_length': self.seq_len
+ })
+ _kwargs.update(**kwargs)
+
+ # tokenization
+ if isinstance(sequence, str):
+ sequence = [sequence]
+ if self.clean:
+ sequence = [self._clean(u) for u in sequence]
+ ids = self.tokenizer(sequence, **_kwargs)
+
+ # output
+ if return_mask:
+ return ids.input_ids, ids.attention_mask
+ else:
+ return ids.input_ids
+
+ def _clean(self, text):
+ if self.clean == 'whitespace':
+ text = whitespace_clean(basic_clean(text))
+ elif self.clean == 'lower':
+ text = whitespace_clean(basic_clean(text)).lower()
+ elif self.clean == 'canonicalize':
+ text = canonicalize(basic_clean(text))
+ return text
+
+
+class WanPrompter(BasePrompter):
+
+ def __init__(self, tokenizer_path=None, text_len=512):
+ super().__init__()
+ self.text_len = text_len
+ self.text_encoder = None
+ self.fetch_tokenizer(tokenizer_path)
+
+ def fetch_tokenizer(self, tokenizer_path=None):
+ if tokenizer_path is not None:
+ self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')
+
+ def fetch_models(self, text_encoder: WanTextEncoder = None):
+ self.text_encoder = text_encoder
+
+ def encode_prompt(self, prompt, positive=True, device="cuda"):
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
+ ids = ids.to(device)
+ mask = mask.to(device)
+ seq_lens = mask.gt(0).sum(dim=1).long()
+ prompt_emb = self.text_encoder(ids, mask)
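+ # Zero out embedding positions beyond each prompt's actual token length.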
+ for i, v in enumerate(seq_lens):
+ prompt_emb[:, v:] = 0
+ return prompt_emb
diff --git a/diffsynth/schedulers/__init__.py b/diffsynth/schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ec43257b687c9b5504e08e05763332755566ea5
--- /dev/null
+++ b/diffsynth/schedulers/__init__.py
@@ -0,0 +1,3 @@
+from .ddim import EnhancedDDIMScheduler
+from .continuous_ode import ContinuousODEScheduler
+from .flow_match import FlowMatchScheduler
diff --git a/diffsynth/schedulers/continuous_ode.py b/diffsynth/schedulers/continuous_ode.py
new file mode 100644
index 0000000000000000000000000000000000000000..c73b9e221aa54a8385322b42012c30c598550fcd
--- /dev/null
+++ b/diffsynth/schedulers/continuous_ode.py
@@ -0,0 +1,59 @@
+import torch
+
+
+class ContinuousODEScheduler():
+
+ def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0):
+ self.sigma_max = sigma_max
+ self.sigma_min = sigma_min
+ self.rho = rho
+ self.set_timesteps(num_inference_steps)
+
+
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, **kwargs):
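+ # Karras-style schedule: interpolate between sigma_max and sigma_min in sigma^(1/rho) space,
+ # then map each sigma to the model's timestep conditioning via 0.25 * log(sigma).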
+ ramp = torch.linspace(1-denoising_strength, 1, num_inference_steps)
+ min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
+ max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))
+ self.sigmas = torch.pow(max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho)
+ self.timesteps = torch.log(self.sigmas) * 0.25
+
+
+ def step(self, model_output, timestep, sample, to_final=False):
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ sample *= (sigma*sigma + 1).sqrt()
+ estimated_sample = -sigma / (sigma*sigma + 1).sqrt() * model_output + 1 / (sigma*sigma + 1) * sample
+ if to_final or timestep_id + 1 >= len(self.timesteps):
+ prev_sample = estimated_sample
+ else:
+ sigma_ = self.sigmas[timestep_id + 1]
+ derivative = 1 / sigma * (sample - estimated_sample)
+ prev_sample = sample + derivative * (sigma_ - sigma)
+ prev_sample /= (sigma_*sigma_ + 1).sqrt()
+ return prev_sample
+
+
+ def return_to_timestep(self, timestep, sample, sample_stablized):
+ # This scheduler doesn't support this function.
+ pass
+
+
+ def add_noise(self, original_samples, noise, timestep):
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt()
+ return sample
+
+
+ def training_target(self, sample, noise, timestep):
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise
+ return target
+
+
+ def training_weight(self, timestep):
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ weight = (1 + sigma*sigma).sqrt() / sigma
+ return weight
diff --git a/diffsynth/schedulers/ddim.py b/diffsynth/schedulers/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..da524963c62f662016b1429d5047ebe7b5922604
--- /dev/null
+++ b/diffsynth/schedulers/ddim.py
@@ -0,0 +1,105 @@
+import torch, math
+
+
+class EnhancedDDIMScheduler():
+
+ def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
+ self.num_train_timesteps = num_train_timesteps
+ if beta_schedule == "scaled_linear":
+ betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
+ elif beta_schedule == "linear":
+ betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented")
+ self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
+ if rescale_zero_terminal_snr:
+ self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
+ self.alphas_cumprod = self.alphas_cumprod.tolist()
+ self.set_timesteps(10)
+ self.prediction_type = prediction_type
+
+
+ def rescale_zero_terminal_snr(self, alphas_cumprod):
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
+
+ return alphas_bar
+
+
+ def set_timesteps(self, num_inference_steps, denoising_strength=1.0, **kwargs):
+ # The timesteps are aligned to 999...0, which is different from other implementations,
+ # but I think this implementation is more reasonable in theory.
+ max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
+ num_inference_steps = min(num_inference_steps, max_timestep + 1)
+ if num_inference_steps == 1:
+ self.timesteps = torch.Tensor([max_timestep])
+ else:
+ step_length = max_timestep / (num_inference_steps - 1)
+ self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])
+
+
+ def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
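+ # Deterministic DDIM update (eta = 0): the previous sample is a weighted combination of the
+ # current sample and the model output, with weights derived from alpha_prod at t and t_prev.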
+ if self.prediction_type == "epsilon":
+ weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
+ weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
+ prev_sample = sample * weight_x + model_output * weight_e
+ elif self.prediction_type == "v_prediction":
+ weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
+ weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
+ prev_sample = sample * weight_x + model_output * weight_e
+ else:
+ raise NotImplementedError(f"{self.prediction_type} is not implemented")
+ return prev_sample
+
+
+ def step(self, model_output, timestep, sample, to_final=False):
+ alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.cpu()
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ if to_final or timestep_id + 1 >= len(self.timesteps):
+ alpha_prod_t_prev = 1.0
+ else:
+ timestep_prev = int(self.timesteps[timestep_id + 1])
+ alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
+
+ return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
+
+
+ def return_to_timestep(self, timestep, sample, sample_stablized):
+ alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
+ noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
+ return noise_pred
+
+
+ def add_noise(self, original_samples, noise, timestep):
+ sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+ sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+
+ def training_target(self, sample, noise, timestep):
+ if self.prediction_type == "epsilon":
+ return noise
+ else:
+ sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+ sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+ target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return target
+
+
+ def training_weight(self, timestep):
+ return 1.0
diff --git a/diffsynth/schedulers/flow_match.py b/diffsynth/schedulers/flow_match.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb2405a0069586595a675896876ce6ad37054b3
--- /dev/null
+++ b/diffsynth/schedulers/flow_match.py
@@ -0,0 +1,125 @@
+import torch, math
+
+
+
+class FlowMatchScheduler():
+
+ def __init__(
+ self,
+ num_inference_steps=100,
+ num_train_timesteps=1000,
+ shift=3.0,
+ sigma_max=1.0,
+ sigma_min=0.003/1.002,
+ inverse_timesteps=False,
+ extra_one_step=False,
+ reverse_sigmas=False,
+ exponential_shift=False,
+ exponential_shift_mu=None,
+ shift_terminal=None,
+ ):
+ self.num_train_timesteps = num_train_timesteps
+ self.shift = shift
+ self.sigma_max = sigma_max
+ self.sigma_min = sigma_min
+ self.inverse_timesteps = inverse_timesteps
+ self.extra_one_step = extra_one_step
+ self.reverse_sigmas = reverse_sigmas
+ self.exponential_shift = exponential_shift
+ self.exponential_shift_mu = exponential_shift_mu
+ self.shift_terminal = shift_terminal
+ self.set_timesteps(num_inference_steps)
+
+
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None, exponential_shift_mu=None):
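+ # Build a linear sigma schedule from sigma_start down to sigma_min, then warp it with either the
+ # exponential shift (mu from the argument, from dynamic_shift_len, or the stored default) or the
+ # standard shift: sigma <- shift * sigma / (1 + (shift - 1) * sigma).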
+ if shift is not None:
+ self.shift = shift
+ sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
+ if self.extra_one_step:
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
+ else:
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
+ if self.inverse_timesteps:
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
+ if self.exponential_shift:
+ if exponential_shift_mu is not None:
+ mu = exponential_shift_mu
+ elif dynamic_shift_len is not None:
+ mu = self.calculate_shift(dynamic_shift_len)
+ else:
+ mu = self.exponential_shift_mu
+ self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1))
+ else:
+ self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
+ if self.shift_terminal is not None:
+ one_minus_z = 1 - self.sigmas
+ scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
+ self.sigmas = 1 - (one_minus_z / scale_factor)
+ if self.reverse_sigmas:
+ self.sigmas = 1 - self.sigmas
+ self.timesteps = self.sigmas * self.num_train_timesteps
+ if training:
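+ # Bell-shaped (Gaussian) weighting over timesteps, normalized so the weights sum to num_inference_steps.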
+ x = self.timesteps
+ y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
+ y_shifted = y - y.min()
+ bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
+ self.linear_timesteps_weights = bsmntw_weighing
+ self.training = True
+ else:
+ self.training = False
+
+
+ def step(self, model_output, timestep, sample, to_final=False, **kwargs):
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.cpu()
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ if to_final or timestep_id + 1 >= len(self.timesteps):
+ sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
+ else:
+ sigma_ = self.sigmas[timestep_id + 1]
+ prev_sample = sample + model_output * (sigma_ - sigma)
+ return prev_sample
+
+
+ def return_to_timestep(self, timestep, sample, sample_stablized):
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.cpu()
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ model_output = (sample - sample_stablized) / sigma
+ return model_output
+
+
+ def add_noise(self, original_samples, noise, timestep):
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.cpu()
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ sample = (1 - sigma) * original_samples + sigma * noise
+ return sample
+
+
+ def training_target(self, sample, noise, timestep):
+ target = noise - sample
+ return target
+
+
+ def training_weight(self, timestep):
+ timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
+ weights = self.linear_timesteps_weights[timestep_id]
+ return weights
+
+
+ def calculate_shift(
+ self,
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 8192,
+ base_shift: float = 0.5,
+ max_shift: float = 0.9,
+ ):
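+ # Linearly interpolate mu between base_shift and max_shift according to the image token sequence length.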
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
diff --git a/diffsynth/tokenizer_configs/__init__.py b/diffsynth/tokenizer_configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json b/diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json
new file mode 100644
index 0000000000000000000000000000000000000000..3f5132007c4fcf42b75b65c8b6aa49c7098bcdf4
--- /dev/null
+++ b/diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json
@@ -0,0 +1,102 @@
+{
+ "<extra_id_0>": 32099,
+ "<extra_id_10>": 32089,
+ "<extra_id_11>": 32088,
+ "<extra_id_12>": 32087,
+ "<extra_id_13>": 32086,
+ "<extra_id_14>": 32085,
+ "<extra_id_15>": 32084,
+ "<extra_id_16>": 32083,
+ "<extra_id_17>": 32082,
+ "<extra_id_18>": 32081,
+ "<extra_id_19>": 32080,
+ "<extra_id_1>": 32098,
+ "<extra_id_20>": 32079,
+ "<extra_id_21>": 32078,
+ "<extra_id_22>": 32077,
+ "<extra_id_23>": 32076,
+ "<extra_id_24>": 32075,
+ "<extra_id_25>": 32074,
+ "<extra_id_26>": 32073,
+ "<extra_id_27>": 32072,
+ "<extra_id_28>": 32071,
+ "<extra_id_29>": 32070,
+ "<extra_id_2>": 32097,
+ "<extra_id_30>": 32069,
+ "<extra_id_31>": 32068,
+ "<extra_id_32>": 32067,
+ "<extra_id_33>": 32066,
+ "<extra_id_34>": 32065,
+ "<extra_id_35>": 32064,
+ "<extra_id_36>": 32063,
+ "<extra_id_37>": 32062,
+ "<extra_id_38>": 32061,
+ "<extra_id_39>": 32060,
+ "<extra_id_3>": 32096,
+ "<extra_id_40>": 32059,
+ "<extra_id_41>": 32058,
+ "<extra_id_42>": 32057,
+ "<extra_id_43>": 32056,
+ "<extra_id_44>": 32055,
+ "<extra_id_45>": 32054,
+ "<extra_id_46>": 32053,
+ "<extra_id_47>": 32052,
+ "<extra_id_48>": 32051,
+ "<extra_id_49>": 32050,
+ "<extra_id_4>": 32095,
+ "<extra_id_50>": 32049,
+ "<extra_id_51>": 32048,
+ "<extra_id_52>": 32047,
+ "<extra_id_53>": 32046,
+ "<extra_id_54>": 32045,
+ "<extra_id_55>": 32044,
+ "<extra_id_56>": 32043,
+ "<extra_id_57>": 32042,
+ "<extra_id_58>": 32041,
+ "<extra_id_59>": 32040,
+ "<extra_id_5>": 32094,
+ "<extra_id_60>": 32039,
+ "<extra_id_61>": 32038,
+ "<extra_id_62>": 32037,
+ "<extra_id_63>": 32036,
+ "<extra_id_64>": 32035,
+ "<extra_id_65>": 32034,
+ "<extra_id_66>": 32033,
+ "<extra_id_67>": 32032,
+ "<extra_id_68>": 32031,
+ "<extra_id_69>": 32030,
+ "<extra_id_6>": 32093,
+ "<extra_id_70>": 32029,
+ "<extra_id_71>": 32028,
+ "<extra_id_72>": 32027,
+ "<extra_id_73>": 32026,
+ "<extra_id_74>": 32025,
+ "<extra_id_75>": 32024,
+ "<extra_id_76>": 32023,
+ "<extra_id_77>": 32022,
+ "<extra_id_78>": 32021,
+ "<extra_id_79>": 32020,
+ "<extra_id_7>": 32092,
+ "<extra_id_80>": 32019,
+ "<extra_id_81>": 32018,
+ "<extra_id_82>": 32017,
+ "<extra_id_83>": 32016,
+ "<extra_id_84>": 32015,
+ "<extra_id_85>": 32014,
+ "<extra_id_86>": 32013,
+ "<extra_id_87>": 32012,
+ "<extra_id_88>": 32011,
+ "<extra_id_89>": 32010,
+ "<extra_id_8>": 32091,
+ "<extra_id_90>": 32009,
+ "<extra_id_91>": 32008,
+ "<extra_id_92>": 32007,
+ "<extra_id_93>": 32006,
+ "<extra_id_94>": 32005,
+ "<extra_id_95>": 32004,
+ "<extra_id_96>": 32003,
+ "<extra_id_97>": 32002,
+ "<extra_id_98>": 32001,
+ "<extra_id_99>": 32000,
+ "<extra_id_9>": 32090
+}
diff --git a/diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json b/diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..17ade346a1042cbe0c1436f5bedcbd85c099d582
--- /dev/null
+++ b/diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json
@@ -0,0 +1,125 @@
+{
+ "additional_special_tokens": [
+ "<extra_id_0>",
+ "<extra_id_1>",
+ "<extra_id_2>",
+ "<extra_id_3>",
+ "<extra_id_4>",
+ "<extra_id_5>",
+ "<extra_id_6>",
+ "<extra_id_7>",
+ "<extra_id_8>",
+ "<extra_id_9>",
+ "<extra_id_10>",
+ "<extra_id_11>",
+ "<extra_id_12>",
+ "<extra_id_13>",
+ "<extra_id_14>",
+ "<extra_id_15>",
+ "<extra_id_16>",
+ "<extra_id_17>",
+ "<extra_id_18>",
+ "<extra_id_19>",
+ "<extra_id_20>",
+ "<extra_id_21>",
+ "<extra_id_22>",
+ "<extra_id_23>",
+ "<extra_id_24>",
+ "<extra_id_25>",
+ "<extra_id_26>",
+ "<extra_id_27>",
+ "<extra_id_28>",
+ "<extra_id_29>",
+ "<extra_id_30>",
+ "<extra_id_31>",
+ "<extra_id_32>",
+ "<extra_id_33>",
+ "<extra_id_34>",
+ "<extra_id_35>",
+ "<extra_id_36>",
+ "<extra_id_37>",
+ "<extra_id_38>",
+ "<extra_id_39>",
+ "<extra_id_40>",
+ "<extra_id_41>",
+ "<extra_id_42>",
+ "<extra_id_43>",
+ "<extra_id_44>",
+ "<extra_id_45>",
+ "<extra_id_46>",
+ "<extra_id_47>",
+ "<extra_id_48>",
+ "<extra_id_49>",
+ "<extra_id_50>",
+ "<extra_id_51>",
+ "<extra_id_52>",
+ "<extra_id_53>",
+ "<extra_id_54>",
+ "<extra_id_55>",
+ "<extra_id_56>",
+ "<extra_id_57>",
+ "<extra_id_58>",
+ "<extra_id_59>",
+ "<extra_id_60>",
+ "<extra_id_61>",
+ "<extra_id_62>",
+ "<extra_id_63>",
+ "<extra_id_64>",
+ "<extra_id_65>",
+ "<extra_id_66>",
+ "<extra_id_67>",
+ "<extra_id_68>",
+ "<extra_id_69>",
+ "<extra_id_70>",
+ "<extra_id_71>",
+ "<extra_id_72>",
+ "<extra_id_73>",
+ "<extra_id_74>",
+ "<extra_id_75>",
+ "<extra_id_76>",
+ "<extra_id_77>",
+ "<extra_id_78>",
+ "<extra_id_79>",
+ "<extra_id_80>",
+ "<extra_id_81>",
+ "<extra_id_82>",
+ "<extra_id_83>",
+ "<extra_id_84>",
+ "<extra_id_85>",
+ "<extra_id_86>",
+ "<extra_id_87>",
+ "<extra_id_88>",
+ "<extra_id_89>",
+ "<extra_id_90>",
+ "<extra_id_91>",
+ "<extra_id_92>",
+ "<extra_id_93>",
+ "<extra_id_94>",
+ "<extra_id_95>",
+ "<extra_id_96>",
+ "<extra_id_97>",
+ "<extra_id_98>",
+ "<extra_id_99>"
+ ],
+ "eos_token": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "<pad>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+}
diff --git a/diffsynth/tokenizer_configs/cog/tokenizer/spiece.model b/diffsynth/tokenizer_configs/cog/tokenizer/spiece.model
new file mode 100644
index 0000000000000000000000000000000000000000..317a5ccbde45300f5d1d970d4d449af2108b147e
--- /dev/null
+++ b/diffsynth/tokenizer_configs/cog/tokenizer/spiece.model
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
+size 791656
diff --git a/diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json b/diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..161715af5ee99558c9fcce7b31d3d547a72c349b
--- /dev/null
+++ b/diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json
@@ -0,0 +1,940 @@
+{
+ "add_prefix_space": true,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "<pad>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32000": {
+ "content": "<extra_id_99>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32001": {
+ "content": "<extra_id_98>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32002": {
+ "content": "<extra_id_97>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32003": {
+ "content": "<extra_id_96>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32004": {
+ "content": "<extra_id_95>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32005": {
+ "content": "<extra_id_94>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32006": {
+ "content": "<extra_id_93>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32007": {
+ "content": "<extra_id_92>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32008": {
+ "content": "<extra_id_91>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32009": {
+ "content": "<extra_id_90>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32010": {
+ "content": "<extra_id_89>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32011": {
+ "content": "<extra_id_88>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32012": {
+ "content": "<extra_id_87>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32013": {
+ "content": "<extra_id_86>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32014": {
+ "content": "<extra_id_85>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32015": {
+ "content": "<extra_id_84>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32016": {
+ "content": "<extra_id_83>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32017": {
+ "content": "<extra_id_82>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32018": {
+ "content": "<extra_id_81>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32019": {
+ "content": "<extra_id_80>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32020": {
+ "content": "<extra_id_79>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32021": {
+ "content": "<extra_id_78>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32022": {
+ "content": "<extra_id_77>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32023": {
+ "content": "<extra_id_76>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32024": {
+ "content": "<extra_id_75>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32025": {
+ "content": "<extra_id_74>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32026": {
+ "content": "<extra_id_73>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32027": {
+ "content": "<extra_id_72>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32028": {
+ "content": "<extra_id_71>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32029": {
+ "content": "