diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..3d9bacc435abcf98a8b199699d931b91c6e682df 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,60 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_0_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_0_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_0_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_0_3.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_1_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_2_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_2_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_2_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_2_3.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_3_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_3_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_3_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_3_3.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_4_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_4_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_4_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_4_3.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_5_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_5_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_5_3.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_6_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_6_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_6_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_6_3.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_7_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_7_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_7_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_7_3.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_8_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_8_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_8_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/A_8_3.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/B_0_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/B_0_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/B_0_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/B_1_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/B_1_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/B_1_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/B_2_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/B_2_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/B_2_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/C_0_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/C_0_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/C_0_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/C_0_3.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/C_1_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/C_1_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/C_1_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/C_1_3.gif filter=lfs diff=lfs 
merge=lfs -text +__assets__/videos/D_0_0.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/D_0_1.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/D_0_2.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/D_0_3.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/D_0_4.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/D_0_5.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/D_0_6.gif filter=lfs diff=lfs merge=lfs -text +__assets__/videos/D_0_7.gif filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9548abf7d04b26ad7f3274f4c6d44c28bc457830 --- /dev/null +++ b/README.md @@ -0,0 +1,281 @@ +

MagicTime: Time-lapse Video Generation Models as Metamorphic Simulators

+
If you like our project, please give us a star ⭐ on GitHub for the latest update.
+ +
+ + +[![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://pku-yuangroup.github.io/MagicTime/) +[![arXiv](https://img.shields.io/badge/Arxiv-2404.05014-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2404.05014) +[![Home Page](https://img.shields.io/badge/Project--blue.svg)](https://pku-yuangroup.github.io/MagicTime/) +[![Dataset](https://img.shields.io/badge/Dataset--green)](https://drive.google.com/drive/folders/1WsomdkmSp3ql3ImcNsmzFuSQ9Qukuyr8?usp=sharing) +[![zhihu](https://img.shields.io/badge/-Twitter@AK%20-black?logo=twitter&logoColor=1D9BF0)](https://twitter.com/_akhaliq/status/1777538468043792473) +[![zhihu](https://img.shields.io/badge/-Twitter@Jinfa%20Huang%20-black?logo=twitter&logoColor=1D9BF0)](https://twitter.com/vhjf36495872/status/1777525817087553827?s=61&t=r2HzCsU2AnJKbR8yKSprKw) +[![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/PKU-YuanGroup/MagicTime/blob/main/LICENSE) +![GitHub Repo stars](https://img.shields.io/github/stars/PKU-YuanGroup/MagicTime) + +
+ +
+This repository is the official implementation of MagicTime, a pipeline that generates metamorphic time-lapse videos from the given text prompts. The main idea is to enhance the capacity of video generation models to accurately depict the real world through our proposed methods and dataset.
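+
+For orientation, `app.py` (included later in this diff) assembles the pipeline roughly as sketched below. This is a condensed, untested sketch that keeps only calls appearing in `app.py`: it omits the motion-module, DreamBooth, and Magic-adapter weights that `app.py` additionally loads, so on its own it will not reproduce MagicTime's metamorphic behaviour; use `inference_magictime.py` with the provided configs for actual inference. The checkpoint locations under `./ckpts/` are the defaults `app.py` expects.
+
+```python
+import torch
+from omegaconf import OmegaConf
+from diffusers import AutoencoderKL, DDIMScheduler
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from utils.unet import UNet3DConditionModel
+from utils.pipeline_magictime import MagicTimePipeline
+from utils.util import save_videos_grid
+
+base = "./ckpts/Base_Model/stable-diffusion-v1-5"
+config = OmegaConf.load("./sample_configs/RealisticVision.yaml")[1]  # same index app.py uses
+
+# Build the sub-modules of the pipeline from the Stable Diffusion 1.5 base weights.
+tokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(base, subfolder="text_encoder").cuda()
+vae = AutoencoderKL.from_pretrained(base, subfolder="vae").cuda()
+unet = UNet3DConditionModel.from_pretrained_2d(
+    base, subfolder="unet",
+    unet_additional_kwargs=OmegaConf.to_container(config.unet_additional_kwargs),
+).cuda()
+
+pipeline = MagicTimePipeline(
+    vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
+    scheduler=DDIMScheduler(**OmegaConf.to_container(config.noise_scheduler_kwargs)),
+).to("cuda")
+
+# Sample a 16-frame clip; the prompt and settings are taken from the examples in app.py.
+generator = torch.Generator(device="cuda").manual_seed(42)
+sample = pipeline(
+    "Bean sprouts grow and mature from seeds.",
+    negative_prompt="worst quality, low quality, letterboxed",
+    num_inference_steps=25, guidance_scale=8.0,
+    width=512, height=512, video_length=16, generator=generator,
+).videos
+save_videos_grid(sample, "outputs/sample.mp4")
+```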
+
+## 📣 News
+* ⏳⏳⏳ Training a stronger model with the support of [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) (e.g., 257×512×512).
+* ⏳⏳⏳ Release the training code of MagicTime.
+* **[2024.04.10]** 🔥 We release the inference code, Hugging Face Space, and model weights of MagicTime.
+* **[2024.04.09]** 🔥 We release the arXiv paper for MagicTime; you can click [here](https://arxiv.org/abs/2404.05014) for more details.
+* **[2024.04.08]** 🔥 We released the subset of the ChronoMagic dataset used to train MagicTime. The dataset includes 2,265 metamorphic video-text pairs and can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1WsomdkmSp3ql3ImcNsmzFuSQ9Qukuyr8?usp=sharing).
+* **[2024.04.08]** 🔥 **All code & datasets** are coming soon! Stay tuned 👀!
+
+## 😮 Highlights
+
+MagicTime shows excellent performance in **metamorphic video generation**.
+
+### Metamorphic Videos vs. General Videos
+
+Compared to general videos, metamorphic videos contain physical knowledge, long persistence, and strong variation, making them difficult to generate. The .gif files shown on GitHub are compressed and lose some quality. The general videos are generated by [Animatediff](https://github.com/guoyww/AnimateDiff), and the metamorphic videos are generated by **MagicTime**.
+
+| Type | "Bean sprouts grow and mature from seeds" | "[...] construction in a Minecraft virtual environment" | "Cupcakes baking in an oven [...]" | "[...] transitioning from a tightly closed bud to a fully bloomed state [...]" |
+| --- | --- | --- | --- | --- |
+| General Videos | (GIF) | (GIF) | (GIF) | (GIF) |
+| Metamorphic Videos | (GIF) | (GIF) | (GIF) | (GIF) |
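+
+The cells above embed compressed GIF previews of the generated clips. The repository's own `save_videos_grid` helper (imported in `app.py`, where it writes an `.mp4`) is not shown in this diff, so the following is only a rough, hypothetical sketch of turning a pipeline output tensor, assumed to be shaped `(batch, channels, frames, height, width)` with values in `[0, 1]`, into such a GIF with Pillow:
+
+```python
+import torch
+from PIL import Image
+
+def tensor_to_gif(video: torch.Tensor, path: str, fps: int = 8) -> None:
+    """Hypothetical helper: write the first video of a (B, C, T, H, W) tensor in [0, 1] as a GIF."""
+    frames_tensor = video[0].permute(1, 2, 3, 0)            # (C, T, H, W) -> (T, H, W, C)
+    frames_uint8 = (frames_tensor.clamp(0, 1) * 255).byte().cpu().numpy()
+    frames = [Image.fromarray(f) for f in frames_uint8]
+    frames[0].save(
+        path,
+        save_all=True,
+        append_images=frames[1:],
+        duration=int(1000 / fps),  # per-frame duration in milliseconds
+        loop=0,                    # loop forever, like the previews above
+    )
+
+# e.g. tensor_to_gif(sample, "outputs/sample.gif")
+```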
+ +### Gallery + +We showcase some metamorphic videos generated by **MagicTime**, [MakeLongVideo](https://github.com/xuduo35/MakeLongVideo), [ModelScopeT2V](https://github.com/modelscope), [VideoCrafter](https://github.com/AILab-CVC/VideoCrafter?tab=readme-ov-file), [ZeroScope](https://huggingface.co/cerspense/zeroscope_v2_576w), [LaVie](https://github.com/Vchitect/LaVie), [T2V-Zero](https://github.com/Picsart-AI-Research/Text2Video-Zero), [Latte](https://github.com/Vchitect/Latte) and [Animatediff](https://github.com/guoyww/AnimateDiff) below. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+| Method | "cherry blossoms transitioning [...]" | "dough balls baking process [...]" | "an ice cube is melting [...]" | "a simple modern house's construction [...]" |
+| --- | --- | --- | --- | --- |
+| MakeLongVideo | (GIF) | (GIF) | (GIF) | (GIF) |
+| ModelScopeT2V | (GIF) | (GIF) | (GIF) | (GIF) |
+| VideoCrafter | (GIF) | (GIF) | (GIF) | (GIF) |
+| ZeroScope | (GIF) | (GIF) | (GIF) | (GIF) |
+| LaVie | (GIF) | (GIF) | (GIF) | (GIF) |
+| T2V-Zero | (GIF) | (GIF) | (GIF) | (GIF) |
+| Latte | (GIF) | (GIF) | (GIF) | (GIF) |
+| Animatediff | (GIF) | (GIF) | (GIF) | (GIF) |
+| Ours | (GIF) | (GIF) | (GIF) | (GIF) |
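+
+MagicTime also plugs community DreamBooth checkpoints into the same pipeline to change the visual style, as the next gallery (Realistic, RcnzCartoon, ToonYou) shows. The sketch below mirrors the `update_dreambooth` logic in `app.py` from this diff; the checkpoint filenames are the ones listed in `app.py`'s examples and are expected under `ckpts/DreamBooth/`.
+
+```python
+import copy
+from safetensors import safe_open
+
+from utils.util import (convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint,
+                        convert_ldm_clip_text_model)
+
+def apply_dreambooth_style(vae, unet, base_text_model, ckpt_path):
+    """Overwrite the VAE and U-Net weights and rebuild the text encoder from a DreamBooth
+    .safetensors file such as RealisticVisionV60B1_v51VAE.safetensors, RcnzCartoon.safetensors,
+    or ToonYou_beta6.safetensors."""
+    state_dict = {}
+    with safe_open(ckpt_path, framework="pt", device="cpu") as f:
+        for key in f.keys():
+            state_dict[key] = f.get_tensor(key)
+
+    vae.load_state_dict(convert_ldm_vae_checkpoint(state_dict, vae.config))
+    unet.load_state_dict(convert_ldm_unet_checkpoint(state_dict, unet.config), strict=False)
+    # Rebuild the CLIP text encoder from a fresh copy so repeated swaps do not accumulate edits.
+    text_encoder = convert_ldm_clip_text_model(copy.deepcopy(base_text_model), state_dict)
+    return vae, unet, text_encoder
+
+# e.g.
+# vae, unet, text_encoder = apply_dreambooth_style(
+#     vae, unet, text_model, "./ckpts/DreamBooth/RcnzCartoon.safetensors")
+```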
+ + +We show more metamorphic videos generated by **MagicTime** with the help of [Realistic](https://civitai.com/models/4201/realistic-vision-v20), [ToonYou](https://civitai.com/models/30240/toonyou) and [RcnzCartoon](https://civitai.com/models/66347/rcnz-cartoon-3d). + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+| Style | Prompt 1 | Prompt 2 | Prompt 3 |
+| --- | --- | --- | --- |
+| Realistic | "[...] bean sprouts grow and mature from seeds" | "dough [...] swells and browns in the oven [...]" | "the construction [...] in Minecraft [...]" |
+| RcnzCartoon | "a bud transforms into a yellow flower" | "time-lapse of a plant germinating [...]" | "[...] a modern house being constructed in Minecraft [...]" |
+| ToonYou | "an ice cube is melting" | "bean plant sprouts grow and mature from the soil" | "time-lapse of delicate pink plum blossoms [...]" |
+
+Prompts are trimmed for display; see [here](https://github.com/PKU-YuanGroup/MagicTime/blob/main/__assets__/promtp_unet.txt) for the full prompts.
+
+### Integrate into DiT-based Architecture
+
+The mission of this project is to help reproduce Sora and to provide high-quality video-text data and data annotation pipelines, in support of [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) and other DiT-based T2V models. To this end, we take an initial step and integrate our MagicTime scheme into the DiT-based framework. Specifically, our method supports fine-tuning of Open-Sora-Plan v1.0.0. We first scale up the dataset with additional metamorphic landscape time-lapse videos, annotated in the same framework, to obtain the ChronoMagic-Landscape dataset. Then, we fine-tune Open-Sora-Plan v1.0.0 on ChronoMagic-Landscape to obtain the MagicTime-DiT model. The results are as follows (**257×512×512, 10 s**):
+
+Prompts of the eight MagicTime-DiT samples (one GIF each):
+* "Time-lapse of a coastal landscape [...]"
+* "Display the serene beauty of twilight [...]"
+* "Sunrise Splendor: Capture the breathtaking moment [...]"
+* "Nightfall Elegance: Embrace the tranquil beauty [...]"
+* "The sun descending below the horizon [...]"
+* "[...] daylight fades into the embrace of the night [...]"
+* "Time-lapse of the dynamic formations of clouds [...]"
+* "Capture the dynamic formations of clouds [...]"
+
+Prompts are trimmed for display; see [here](https://github.com/PKU-YuanGroup/MagicTime/blob/main/__assets__/promtp_opensora.txt) for the full prompts.
+
+## 🤗 Demo
+
+### Gradio Web UI
+
+We highly recommend trying out our web demo with the following command; it incorporates all features currently supported by MagicTime. We also provide an [online demo](https://github.com/PKU-YuanGroup/MagicTime) on Hugging Face Spaces.
+
+```bash
+python app.py
+```
+
+## ⚙️ Requirements and Installation
+
+We recommend setting up the environment as follows.
+
+```bash
+git clone https://github.com/PKU-YuanGroup/MagicTime.git
+cd MagicTime
+conda env create -f environment.yml
+conda activate magictime
+```
+
+## 🗝️ Training & Inference
+
+The training code is coming soon! For inference, some examples are shown below:
+
+```bash
+# For [Realistic](https://civitai.com/models/4201/realistic-vision-v20)
+python inference_magictime.py --config sample_configs/RealisticVision.yaml
+# For [ToonYou](https://civitai.com/models/30240/toonyou)
+python inference_magictime.py --config sample_configs/ToonYou.yaml
+# For [RcnzCartoon](https://civitai.com/models/66347/rcnz-cartoon-3d)
+python inference_magictime.py --config sample_configs/RcnzCartoon.yaml
+
+# or you can directly run the .sh script
+sh inference.sh
+```
+
+## 🐳 ChronoMagic Dataset
+ChronoMagic contains 2,265 metamorphic time-lapse videos, each accompanied by a detailed caption. We released the subset of ChronoMagic used to train MagicTime; it can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1WsomdkmSp3ql3ImcNsmzFuSQ9Qukuyr8?usp=sharing). Some samples can be found on our Project Page.
+
+## 👍 Acknowledgement
+* [Animatediff](https://github.com/guoyww/AnimateDiff/tree/main): the codebase we built upon, a strong U-Net-based text-to-video generation model.
+* [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan): the codebase we built upon, a simple and scalable DiT-based text-to-video generation repo for reproducing [Sora](https://openai.com/sora).
+
+## 🔒 License
+* The majority of this project is released under the Apache 2.0 license as found in the [LICENSE](https://github.com/PKU-YuanGroup/MagicTime/blob/main/LICENSE) file.
+* The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violations.
+
+## ✏️ Citation
+If you find our paper and code useful in your research, please consider giving a star :star: and a citation :pencil:.
+
+```BibTeX
+@misc{yuan2024magictime,
+      title={MagicTime: Time-lapse Video Generation Models as Metamorphic Simulators},
+      author={Shenghai Yuan and Jinfa Huang and Yujun Shi and Yongqi Xu and Ruijie Zhu and Bin Lin and Xinhua Cheng and Li Yuan and Jiebo Luo},
+      year={2024},
+      eprint={2404.05014},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV}
+}
+```
+
+## 🤝 Contributors
+
diff --git a/__assets__/promtp_opensora.txt b/__assets__/promtp_opensora.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fed86471088bf60d99261ae5df17cddcaf52ed91
--- /dev/null
+++ b/__assets__/promtp_opensora.txt
@@ -0,0 +1,8 @@
+1. Time-lapse of a coastal landscape transitioning from sunrise to nightfall, with early morning light and soft shadows giving way to a clearer, bright midday sky, and later visible signs of sunset with orange hues and a dimming sky, culminating in a vibrant dusk.
+2. Display the serene beauty of twilight, marking the transition from day to night with subtle changes in lighting.
+3. 
Sunrise Splendor: Capture the breathtaking moment as the sun peeks over the horizon, casting its warm hues across the landscape in a mesmerizing time-lapse. +4. Nightfall Elegance: Embrace the tranquil beauty of dusk as daylight fades into the embrace of the night, unveiling the twinkling stars against the darkening sky in a mesmerizing time-lapse spectacle. +5. The sun descending below the horizon at dusk. The video is a time-lapse showcasing the gradual dimming of daylight, leading to the onset of twilight. +6. Nightfall Elegance: Embrace the tranquil beauty of dusk as daylight fades into the embrace of the night, unveiling the twinkling stars against the darkening sky in a mesmerizing time-lapse spectacle. +7. Time-lapse of the dynamic formations of clouds, showcasing their continuous motion and evolution over the course of the video. +8. Capture the dynamic formations of clouds, showcasing their continuous motion and evolution over the course of the video. \ No newline at end of file diff --git a/__assets__/promtp_unet.txt b/__assets__/promtp_unet.txt new file mode 100644 index 0000000000000000000000000000000000000000..69dca4e43fad745002a9f83dfe8ddf46086927a8 --- /dev/null +++ b/__assets__/promtp_unet.txt @@ -0,0 +1,9 @@ +1. A time-lapse video of bean sprouts grow and mature from seeds. +2. Dough starts smooth, swells and browns in the oven, finishing as fully expanded, baked bread. +3. The construction of a simple modern house in Minecraft. As the construction progresses, the roof and walls are completed, and the area around the house is cleared and shaped. +4. A bud transforms into a yellow flower. +5. Time-lapse of a plant germinating and developing into a young plant with multiple true leaves in a container, showing progressive growth stages from bare soil to a full plant. +6. Time-lapse of a modern house being constructed in Minecraft, beginning with a basic structure and progressively adding roof details, and new sections. +7. An ice cube is melting. +8. Bean plant sprouts grow and mature from the soil. +9. Time-lapse of delicate pink plum blossoms transitioning from tightly closed buds to gently unfurling petals, revealing the intricate details of stamens and pistils within. 
\ No newline at end of file diff --git a/__assets__/videos/A_0_0.gif b/__assets__/videos/A_0_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..963cdcb785f07372556c43d7cda1168c13931f32 --- /dev/null +++ b/__assets__/videos/A_0_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1eac8b22504f00becd4c636d2975bb82163e3fb80cd6e8cd30bb90dc6bbf69a7 +size 6526672 diff --git a/__assets__/videos/A_0_1.gif b/__assets__/videos/A_0_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..7a0c2cae0785e5764f0437d754cde8e89aa01613 --- /dev/null +++ b/__assets__/videos/A_0_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:13e784e532ffd5a614a06ce25a0df601a18b863477c48616418452a4c0b5eede +size 5218645 diff --git a/__assets__/videos/A_0_2.gif b/__assets__/videos/A_0_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..61376e1c2c348ed701e0958c9f47667e80728a5e --- /dev/null +++ b/__assets__/videos/A_0_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc6b63554ee5c353c799abf4ac1779a7d05534cee940d4a4870807220cdaa154 +size 6962696 diff --git a/__assets__/videos/A_0_3.gif b/__assets__/videos/A_0_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..a1ff0186d2011442404a6f51f8bca5b4068ed136 --- /dev/null +++ b/__assets__/videos/A_0_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e4d0bda84b236a36bb4e89b963022612b241c8975a26eb3f75aa018b9808fc38 +size 6616801 diff --git a/__assets__/videos/A_1_0.gif b/__assets__/videos/A_1_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..46dc79eb73cb231dd56ba53eed6e4fe33d9c514f Binary files /dev/null and b/__assets__/videos/A_1_0.gif differ diff --git a/__assets__/videos/A_1_1.gif b/__assets__/videos/A_1_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..81612f66c583e2c821b9203a0af9663eaf10e444 Binary files /dev/null and b/__assets__/videos/A_1_1.gif differ diff --git a/__assets__/videos/A_1_2.gif b/__assets__/videos/A_1_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..d661d2c1482bd33b130a1b6f22de385eca32742b --- /dev/null +++ b/__assets__/videos/A_1_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:392452fabde0504fb191b2bc272fabb52283e278b8377e2df0aa72b09f56622f +size 1144875 diff --git a/__assets__/videos/A_1_3.gif b/__assets__/videos/A_1_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..a00b988211f250aa95a51fd2ff65ef573b40b439 Binary files /dev/null and b/__assets__/videos/A_1_3.gif differ diff --git a/__assets__/videos/A_2_0.gif b/__assets__/videos/A_2_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..2ba6ab312052dbcee519841244120dd1fc35de28 --- /dev/null +++ b/__assets__/videos/A_2_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da5ad193404d30d9cef96e4d1df7ee8e4d5335c0232314117e27ff90309d7b3e +size 1790967 diff --git a/__assets__/videos/A_2_1.gif b/__assets__/videos/A_2_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..8e12b72a4524be6e05fbc3b36f3a548eca952170 --- /dev/null +++ b/__assets__/videos/A_2_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b164d9bc80eea42e85373acfc8c846ac377f2fdb3d33342a40884957fbc94f82 +size 2464424 diff --git a/__assets__/videos/A_2_2.gif b/__assets__/videos/A_2_2.gif new file mode 100644 index 
0000000000000000000000000000000000000000..b25e5794456dbda8f51907ff053882eb666a1f81 --- /dev/null +++ b/__assets__/videos/A_2_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af1d8ffafc579430caa9c7cd58edb2150c902dc8a3cc8d9a2cb80a91fb71c893 +size 2256277 diff --git a/__assets__/videos/A_2_3.gif b/__assets__/videos/A_2_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..efaadc81dbcd2d0db53d30a2afec8360e92784db --- /dev/null +++ b/__assets__/videos/A_2_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:987c4a60e697bdda016b9befcd759db45e49ccd1495dc1aafc33275dc92a8046 +size 1820743 diff --git a/__assets__/videos/A_3_0.gif b/__assets__/videos/A_3_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..fffc681beb39f12241285b5feac4d560193bb44e --- /dev/null +++ b/__assets__/videos/A_3_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:75a95ae5a7e25d80fabfd5049aa9f6bb7502eeee7d2fc1c99c1b2685beafdd04 +size 4658479 diff --git a/__assets__/videos/A_3_1.gif b/__assets__/videos/A_3_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..f9e0df1b7f4db4bb8a7d59bb5cfca3e975c9ef43 --- /dev/null +++ b/__assets__/videos/A_3_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0ee200e69474cdaa17a968de8c748f75babd1c6937256d7980a0bff317cd98f +size 3443377 diff --git a/__assets__/videos/A_3_2.gif b/__assets__/videos/A_3_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..d17b3a7c47affd1acbf940e3cf1e874dcb00dac2 --- /dev/null +++ b/__assets__/videos/A_3_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cdbe67a44656a53f98cddb8ec122464c666b18e0e68d2513279ff3bdbdb8b641 +size 5312763 diff --git a/__assets__/videos/A_3_3.gif b/__assets__/videos/A_3_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..507dd54b29d6452a33d65ff49711074ff2f92711 --- /dev/null +++ b/__assets__/videos/A_3_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:80d65914df12cce5daa9cfb01421bbcf3cfc8264af05ac3e24450866ca968080 +size 2493412 diff --git a/__assets__/videos/A_4_0.gif b/__assets__/videos/A_4_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..4366d19fee0813197b9aad1e81ab5e3d9cb42fea --- /dev/null +++ b/__assets__/videos/A_4_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58dc34187cdbcddb059ab3495a86984c8a1e252f661b3763ab0f53af78538b7b +size 2482440 diff --git a/__assets__/videos/A_4_1.gif b/__assets__/videos/A_4_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..f4112e4e7b9b2fafbe7741d9f05fcbb63d1847bd --- /dev/null +++ b/__assets__/videos/A_4_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6fafd0f8572996f925f60afa0de31ac36e2774ceacb02d896daefa37a939383 +size 2211870 diff --git a/__assets__/videos/A_4_2.gif b/__assets__/videos/A_4_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..32ffe3852b8b1c848acb5803067a7a9a4d7e1713 --- /dev/null +++ b/__assets__/videos/A_4_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cf07c1f97fb3dd13ef0a33bc5210c834dff3f39ef7a5b421015abb7c1d8fd11 +size 1682209 diff --git a/__assets__/videos/A_4_3.gif b/__assets__/videos/A_4_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..cbafbb228ae35c36bb576be0e1306400158281b6 --- /dev/null +++ b/__assets__/videos/A_4_3.gif @@ 
-0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb6fd84e6b2ba48d26691c8e816b15caf4f8d6b10d4218cecbd8573bb911ecd7 +size 2042406 diff --git a/__assets__/videos/A_5_0.gif b/__assets__/videos/A_5_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..cc12a7fc177dab7f21c26533c4a5c3982955da14 --- /dev/null +++ b/__assets__/videos/A_5_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c3f1e3bb98b5a3406153b473fd586eb7b4d91d555b792eb228004b471bdca73 +size 1044986 diff --git a/__assets__/videos/A_5_1.gif b/__assets__/videos/A_5_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..f2585ec98d84162154a309f5287a4513d10b80c3 --- /dev/null +++ b/__assets__/videos/A_5_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:834b814f38a9a610a20298c386454eed71cee1adb5aa667b8949b4296eb386a3 +size 1242803 diff --git a/__assets__/videos/A_5_2.gif b/__assets__/videos/A_5_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..c8e368abb4a163845a24d893f9895206e40bc55f Binary files /dev/null and b/__assets__/videos/A_5_2.gif differ diff --git a/__assets__/videos/A_5_3.gif b/__assets__/videos/A_5_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..802a56084a9f7e3cb2b2c8125179500feb084d96 --- /dev/null +++ b/__assets__/videos/A_5_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a03b64c3bab9e8243a51ef2a5f3e6fef15579009047d581ee4d2393af96d79b3 +size 1116225 diff --git a/__assets__/videos/A_6_0.gif b/__assets__/videos/A_6_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..65f3000f682e22da3e3fd0ca7550be14a849c172 --- /dev/null +++ b/__assets__/videos/A_6_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1efa40f8513a4df3b4a3f8959080a53167385c4b7a96b9afb7db3c95def8296d +size 2870942 diff --git a/__assets__/videos/A_6_1.gif b/__assets__/videos/A_6_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..4cae729d624d18102d1eed9c8a57dddce2df9779 --- /dev/null +++ b/__assets__/videos/A_6_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b563ac805da60714585387d58eff4fe067246278ee913f51602c2fd82b359c3 +size 2228330 diff --git a/__assets__/videos/A_6_2.gif b/__assets__/videos/A_6_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..86b457a30b22e652713b11120ec56a9cb23e3b53 --- /dev/null +++ b/__assets__/videos/A_6_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b07f858c9817847d6ab3590b1ac4514b155a4c38968e11864aaddf88afaf708b +size 1596501 diff --git a/__assets__/videos/A_6_3.gif b/__assets__/videos/A_6_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..d9a3b42d112e16d04a54bc1c2482f5c4e1ecf7dd --- /dev/null +++ b/__assets__/videos/A_6_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:723620f40711810eb836968b17fd964449e1cb0b83e77173bb5a9f591a83204b +size 1663100 diff --git a/__assets__/videos/A_7_0.gif b/__assets__/videos/A_7_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..1685120707d657895a3b5ecf9e96bec329020524 --- /dev/null +++ b/__assets__/videos/A_7_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89e0a4ee53163370ec03d2156fb11b2f3f297260e80304984356330b435c9e00 +size 3576935 diff --git a/__assets__/videos/A_7_1.gif b/__assets__/videos/A_7_1.gif new file mode 100644 index 
0000000000000000000000000000000000000000..c199cd52d6d60f04a8d6314a528dbf5266604ea5 --- /dev/null +++ b/__assets__/videos/A_7_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43f6fd96ecd88e1f71ea9b41ea4a3358ee1470c0e24a36eaffda30651d03391e +size 3299672 diff --git a/__assets__/videos/A_7_2.gif b/__assets__/videos/A_7_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..8f11a2d9a9b775825f07d4ba863975c528983b6c --- /dev/null +++ b/__assets__/videos/A_7_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d95fae4befcb912fff253f173e8aa613c45be7156846a79157fd6862cc830f9 +size 3023564 diff --git a/__assets__/videos/A_7_3.gif b/__assets__/videos/A_7_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..99a0926f558b5a32f62516a44421637d415a38c6 --- /dev/null +++ b/__assets__/videos/A_7_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe4925e0404ec5ff0798a02d0a7096aed6d72ed67458f893f81b66119838e08d +size 2892320 diff --git a/__assets__/videos/A_8_0.gif b/__assets__/videos/A_8_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..eb9c38b41a944c1a28bb9f2945bb27d8912cd69f --- /dev/null +++ b/__assets__/videos/A_8_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:49001b5e95e260b95309f092b41c7b6b1230584e52e530cbe998e8742685321f +size 3365336 diff --git a/__assets__/videos/A_8_1.gif b/__assets__/videos/A_8_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..5000b6d22b73c6f0657781db8decc691c08cfa39 --- /dev/null +++ b/__assets__/videos/A_8_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2aaa2da4ad79a1bdef64751e14f8edc2792b4a2946a20941ec4e658846e73bf6 +size 2568019 diff --git a/__assets__/videos/A_8_2.gif b/__assets__/videos/A_8_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..faaec953a24b2731771d238d2088c5d9c986a4ea --- /dev/null +++ b/__assets__/videos/A_8_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56a9dd65216f8c70ffe929bd0d85f1f70f0d24c8d1d579cfd877c746fcffdfd9 +size 3574173 diff --git a/__assets__/videos/A_8_3.gif b/__assets__/videos/A_8_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..7080125938afd03137d12d37c9f2060f8fbcf046 --- /dev/null +++ b/__assets__/videos/A_8_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b544ee2072a69d559b59f252223a56b36c0a1b5a1d6b535612f497358286a16 +size 2412240 diff --git a/__assets__/videos/B_0_0.gif b/__assets__/videos/B_0_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..d91b7e7ebfcf506d0beedc698604620a23a37a23 --- /dev/null +++ b/__assets__/videos/B_0_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8783cba1f8a2436703694687093dd68bb19f2fb9ed682a9b928e4eda39942da +size 2892677 diff --git a/__assets__/videos/B_0_1.gif b/__assets__/videos/B_0_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..16e801e3d06aa938f0127bc31b8d99329e302994 --- /dev/null +++ b/__assets__/videos/B_0_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:353c385468e8ce61332cb3d7215c5b6646a7b868145cebdbb706e92ddf19c443 +size 2154477 diff --git a/__assets__/videos/B_0_2.gif b/__assets__/videos/B_0_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..6aeea92cf1477b23ca5d8db5841cfc6511428dde --- /dev/null +++ b/__assets__/videos/B_0_2.gif @@ 
-0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:61fe6a79a561e32dd8944caddd4666e4ea541aca704b4d2cf2a06ec49f77fe14 +size 2320644 diff --git a/__assets__/videos/B_1_0.gif b/__assets__/videos/B_1_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..54b289c7bc8caa40ebdc0ebfe405485b56975321 --- /dev/null +++ b/__assets__/videos/B_1_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71fd54477be5294e467ae0709707797a6b1e55744aa6dbe6a1c6665d480320f4 +size 2267673 diff --git a/__assets__/videos/B_1_1.gif b/__assets__/videos/B_1_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..1fc47a2adbe90b5aa12e4c908b2a72102d59ca4c --- /dev/null +++ b/__assets__/videos/B_1_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2522806d32dd1cf169538629d944b406968aa316cf64140063dc96cbcf7d9ed8 +size 2776910 diff --git a/__assets__/videos/B_1_2.gif b/__assets__/videos/B_1_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..911d6aaa223d0eb33c2ddf803c9adbd4a29cb2bb --- /dev/null +++ b/__assets__/videos/B_1_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a20a027ffa6633c4269baef6a9186e59d46f2abcae04491239e778640d9773b +size 1953076 diff --git a/__assets__/videos/B_2_0.gif b/__assets__/videos/B_2_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..2fa89d74ebdd7ae6a74128764661e3731c635192 --- /dev/null +++ b/__assets__/videos/B_2_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7e8a2ce8835f6bad9bd2ec6156ce46087503f82b133b918dd58e84f0d03ca01 +size 1688429 diff --git a/__assets__/videos/B_2_1.gif b/__assets__/videos/B_2_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..3801ba6c7bcbb7a3ffb51176fe7b81e05a5411ea --- /dev/null +++ b/__assets__/videos/B_2_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:985833b13c99f72ab3789046889486b392a2312cc862dd4ccf2e7ce6d76a3e94 +size 2312522 diff --git a/__assets__/videos/B_2_2.gif b/__assets__/videos/B_2_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..5d65877327a2d229fc50465617e2035f5513ed33 --- /dev/null +++ b/__assets__/videos/B_2_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54298f9159ec3fa52cb5d5ec6dcd1160afe40629df87ffc2850a519a4c58d89f +size 2238159 diff --git a/__assets__/videos/C_0_0.gif b/__assets__/videos/C_0_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..7e6217b70242cddaed1235d8449702486475c425 --- /dev/null +++ b/__assets__/videos/C_0_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:827955632b5848a38fe56fb2e1a53f92782549672aabef3e6abf68e207e47e81 +size 12704702 diff --git a/__assets__/videos/C_0_1.gif b/__assets__/videos/C_0_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..4890d1308cd731dbedeb4c2fcdfefbdbcce2f337 --- /dev/null +++ b/__assets__/videos/C_0_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:597a3578e38e5da6aedc933ac018f7c219de702fe4380c58e4c4c4417939dd0c +size 10272195 diff --git a/__assets__/videos/C_0_2.gif b/__assets__/videos/C_0_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..e7fb606024aa27a3677cb4305c50a3829ec1276a --- /dev/null +++ b/__assets__/videos/C_0_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:b31d3a60f21365b642e4c9cdd853c02e9882cf47cb3f3baa9bbf56d52d74f869 +size 10363078 diff --git a/__assets__/videos/C_0_3.gif b/__assets__/videos/C_0_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..4cf01f1ce61e0ce4e355d08d0284e961e68b15bb --- /dev/null +++ b/__assets__/videos/C_0_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64749d3890b81f3d11eade39efdab0c8d849e4f04bccf7ebf415d3a7b4719035 +size 5947407 diff --git a/__assets__/videos/C_1_0.gif b/__assets__/videos/C_1_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..aad341c969d2fe16a0e74a03ba891b0d0931e728 --- /dev/null +++ b/__assets__/videos/C_1_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8193d1dc70d047a7a4d51e28ef06e01485c0955aaaa8b1e9912d28826ff01c24 +size 11574579 diff --git a/__assets__/videos/C_1_1.gif b/__assets__/videos/C_1_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..472c00db0c3a7f8aa0a9fd0743968e350aa452e2 --- /dev/null +++ b/__assets__/videos/C_1_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55e45a6080602861cdcba9850be9e45a56853bb73d94a29d3a777604fea48592 +size 6640919 diff --git a/__assets__/videos/C_1_2.gif b/__assets__/videos/C_1_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..9d2928601b0d154c5d9b13d3ded885680cd6d071 --- /dev/null +++ b/__assets__/videos/C_1_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d19ca8ee2f153b1f42df0d23973aff31151ab9e929b29f4bedf772a310d49b2 +size 10115996 diff --git a/__assets__/videos/C_1_3.gif b/__assets__/videos/C_1_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..cc4dd8b0062c270ff847f28fefcaeeb93deab088 --- /dev/null +++ b/__assets__/videos/C_1_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2371b5125f4a74eb947131c7424268271af5fc42e3d4c14fb1f9c700717551f +size 4296170 diff --git a/__assets__/videos/D_0_0.gif b/__assets__/videos/D_0_0.gif new file mode 100644 index 0000000000000000000000000000000000000000..81d8b4c1312f5690f79de53d61a8a0d4ac18cdc0 --- /dev/null +++ b/__assets__/videos/D_0_0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:918e39283c050016f8feb00ccf8348c40a1a24311f05d72883e4e4caebf468a1 +size 19015097 diff --git a/__assets__/videos/D_0_1.gif b/__assets__/videos/D_0_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..66d83f1f541ab1d660d6cfa87ecbfb7ecfcba9e7 --- /dev/null +++ b/__assets__/videos/D_0_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1622c8c89087406a36a947cf432e6db1baea44557ea65001c38d94cc262888b2 +size 15739042 diff --git a/__assets__/videos/D_0_2.gif b/__assets__/videos/D_0_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..93d31248aac24c6e2eb1d000042e10df03716037 --- /dev/null +++ b/__assets__/videos/D_0_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6cfda3b5e36f31f79eeba591ec4c3eae1dd54d26a8321ec3bcd9a5106d9b71a +size 20727265 diff --git a/__assets__/videos/D_0_3.gif b/__assets__/videos/D_0_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..ec805a4a9fba9418aedab3cc2467ab25517fadce --- /dev/null +++ b/__assets__/videos/D_0_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73e94710459bc78b1d4a4d58a3c76cea81c34aa52759cc6b5e48248f74f8ffa3 +size 21313215 diff --git a/__assets__/videos/D_0_4.gif 
b/__assets__/videos/D_0_4.gif new file mode 100644 index 0000000000000000000000000000000000000000..c1e4177fa05013bc8f93803fe29242389d66035e --- /dev/null +++ b/__assets__/videos/D_0_4.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f152adc7f86e2dc2ee146b3d04eafe64cf46238f0bd81e4743972cb200e76f0c +size 16048005 diff --git a/__assets__/videos/D_0_5.gif b/__assets__/videos/D_0_5.gif new file mode 100644 index 0000000000000000000000000000000000000000..568a8ba94c8d2dc7021b6ca272c7992b1038c2cb --- /dev/null +++ b/__assets__/videos/D_0_5.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2400b3b827e0bbc5bdc920e0cdc03ae89c2243c0b3127f4d25e155a82543cd70 +size 20472620 diff --git a/__assets__/videos/D_0_6.gif b/__assets__/videos/D_0_6.gif new file mode 100644 index 0000000000000000000000000000000000000000..20883b5f4a75fab62ebdf563ad18a29299ef80ca --- /dev/null +++ b/__assets__/videos/D_0_6.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0685ce6746b69c1fa2bac56f5dda30ba8803651f2037edefacf67f32adef8402 +size 18960444 diff --git a/__assets__/videos/D_0_7.gif b/__assets__/videos/D_0_7.gif new file mode 100644 index 0000000000000000000000000000000000000000..ed313d059ff7376e71a36e6e7e36947f8e71ed24 --- /dev/null +++ b/__assets__/videos/D_0_7.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f10b497a76e6d30336d1cb67702abd5b582a39a082858293d8162175b556d2f +size 16191285 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..411535731aa364042e31cf4567e6eea25b55e423 --- /dev/null +++ b/app.py @@ -0,0 +1,246 @@ +import os +import copy +import torch +import random +import gradio as gr +from glob import glob +from omegaconf import OmegaConf +from safetensors import safe_open +from diffusers import AutoencoderKL +from diffusers import EulerDiscreteScheduler, DDIMScheduler +from diffusers.utils.import_utils import is_xformers_available +from transformers import CLIPTextModel, CLIPTokenizer + +from utils.unet import UNet3DConditionModel +from utils.pipeline_magictime import MagicTimePipeline +from utils.util import save_videos_grid, convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint, load_diffusers_lora_unet, convert_ldm_clip_text_model + +pretrained_model_path = "./ckpts/Base_Model/stable-diffusion-v1-5" +inference_config_path = "./sample_configs/RealisticVision.yaml" +magic_adapter_s_path = "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt" +magic_adapter_t_path = "./ckpts/Magic_Weights/magic_adapter_t" +magic_text_encoder_path = "./ckpts/Magic_Weights/magic_text_encoder" + +css = """ +.toolbutton { + margin-buttom: 0em 0em 0em 0em; + max-width: 2.5em; + min-width: 2.5em !important; + height: 2.5em; +} +""" + +examples = [ + # 1-ToonYou + [ + "ToonYou_beta6.safetensors", + "motion_module.ckpt", + "Bean sprouts grow and mature from seeds.", + "worst quality, low quality, letterboxed", + 512, 512, "13204175718326964000" + ], + # 2-RCNZ + [ + "RcnzCartoon.safetensors", + "motion_module.ckpt", + "Time-lapse of a simple modern house's construction in a Minecraft virtual environment: beginning with an avatar laying a white foundation, progressing through wall erection and interior furnishing, to adding roof and exterior details, and completed with landscaping and a tall chimney.", + "worst quality, low quality, letterboxed", + 512, 512, "1268480012" + ], + # 3-RealisticVision + [ + "RealisticVisionV60B1_v51VAE.safetensors", + 
"motion_module.ckpt", + "Cherry blossoms transitioning from tightly closed buds to a peak state of bloom. The progression moves through stages of bud swelling, petal exposure, and gradual opening, culminating in a full and vibrant display of open blossoms.", + "worst quality, low quality, letterboxed", + 512, 512, "2038801077" + ] +] + +# clean Grdio cache +print(f"### Cleaning cached examples ...") +os.system(f"rm -rf gradio_cached_examples/") + + +class MagicTimeController: + def __init__(self): + + # config dirs + self.basedir = os.getcwd() + self.stable_diffusion_dir = os.path.join(self.basedir, "ckpts", "Base_Model") + self.motion_module_dir = os.path.join(self.basedir, "ckpts", "Base_Model", "motion_module") + self.personalized_model_dir = os.path.join(self.basedir, "ckpts", "DreamBooth") + self.savedir = os.path.join(self.basedir, "outputs") + os.makedirs(self.savedir, exist_ok=True) + + self.dreambooth_list = [] + self.motion_module_list = [] + + self.selected_dreambooth = None + self.selected_motion_module = None + + self.refresh_motion_module() + self.refresh_personalized_model() + + # config models + self.inference_config = OmegaConf.load(inference_config_path)[1] + + self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").cuda() + self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda() + self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda() + + self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") + + self.update_dreambooth(self.dreambooth_list[0]) + self.update_motion_module(self.motion_module_list[0]) + + from swift import Swift + magic_adapter_s_state_dict = torch.load(magic_adapter_s_path, map_location="cpu") + self.unet = load_diffusers_lora_unet(self.unet, magic_adapter_s_state_dict, alpha=1.0) + self.unet = Swift.from_pretrained(self.unet, magic_adapter_t_path) + self.text_encoder = Swift.from_pretrained(self.text_encoder, magic_text_encoder_path) + + + def refresh_motion_module(self): + motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt")) + self.motion_module_list = [os.path.basename(p) for p in motion_module_list] + + def refresh_personalized_model(self): + dreambooth_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors")) + self.dreambooth_list = [os.path.basename(p) for p in dreambooth_list] + + def update_dreambooth(self, dreambooth_dropdown): + self.selected_dreambooth = dreambooth_dropdown + + dreambooth_dropdown = os.path.join(self.personalized_model_dir, dreambooth_dropdown) + dreambooth_state_dict = {} + with safe_open(dreambooth_dropdown, framework="pt", device="cpu") as f: + for key in f.keys(): dreambooth_state_dict[key] = f.get_tensor(key) + + converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, self.vae.config) + self.vae.load_state_dict(converted_vae_checkpoint) + + converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, self.unet.config) + self.unet.load_state_dict(converted_unet_checkpoint, strict=False) + + text_model = copy.deepcopy(self.text_model) + self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict) + return gr.Dropdown() + + def update_motion_module(self, motion_module_dropdown): + 
self.selected_motion_module = motion_module_dropdown + motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown) + motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu") + _, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False) + assert len(unexpected) == 0 + return gr.Dropdown() + + + def magictime( + self, + dreambooth_dropdown, + motion_module_dropdown, + prompt_textbox, + negative_prompt_textbox, + width_slider, + height_slider, + seed_textbox, + ): + if self.selected_dreambooth != dreambooth_dropdown: self.update_dreambooth(dreambooth_dropdown) + if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown) + + if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention() + + pipeline = MagicTimePipeline( + vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet, + scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)) + ).to("cuda") + + if int(seed_textbox) > 0: seed = int(seed_textbox) + else: seed = random.randint(1, 1e16) + torch.manual_seed(int(seed)) + + assert seed == torch.initial_seed() + print(f"### seed: {seed}") + + generator = torch.Generator(device="cuda") + generator.manual_seed(seed) + + sample = pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = 25, + guidance_scale = 8., + width = width_slider, + height = height_slider, + video_length = 16, + generator = generator, + ).videos + + save_sample_path = os.path.join(self.savedir, f"sample.mp4") + save_videos_grid(sample, save_sample_path) + + json_config = { + "prompt": prompt_textbox, + "n_prompt": negative_prompt_textbox, + "width": width_slider, + "height": height_slider, + "seed": seed, + "dreambooth": dreambooth_dropdown, + } + return gr.Video(value=save_sample_path), gr.Json(value=json_config) + +controller = MagicTimeController() + + +def ui(): + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ +

+ MagicTime: Time-lapse Video Generation Models as Metamorphic Simulators
+ If you like our project, please give us a star ⭐ on GitHub for the latest update.
+ + [GitHub](https://img.shields.io/github/stars/PKU-YuanGroup/MagicTime) | [arXiv](https://arxiv.org/abs/2404.05014) | [Home Page](https://pku-yuangroup.github.io/MagicTime/) | [Dataset](https://drive.google.com/drive/folders/1WsomdkmSp3ql3ImcNsmzFuSQ9Qukuyr8?usp=sharing) + """ + ) + with gr.Row(): + with gr.Column(): + dreambooth_dropdown = gr.Dropdown( label="DreamBooth Model", choices=controller.dreambooth_list, value=controller.dreambooth_list[0], interactive=True ) + motion_module_dropdown = gr.Dropdown( label="Motion Module", choices=controller.motion_module_list, value=controller.motion_module_list[0], interactive=True ) + + dreambooth_dropdown.change(fn=controller.update_dreambooth, inputs=[dreambooth_dropdown], outputs=[dreambooth_dropdown]) + motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown]) + + prompt_textbox = gr.Textbox( label="Prompt", lines=3 ) + negative_prompt_textbox = gr.Textbox( label="Negative Prompt", lines=3, value="worst quality, low quality, nsfw, logo") + + with gr.Accordion("Advance", open=False): + with gr.Row(): + width_slider = gr.Slider( label="Width", value=512, minimum=256, maximum=1024, step=64 ) + height_slider = gr.Slider( label="Height", value=512, minimum=256, maximum=1024, step=64 ) + with gr.Row(): + seed_textbox = gr.Textbox( label="Seed", value=-1) + seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton") + seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e16)), inputs=[], outputs=[seed_textbox]) + + generate_button = gr.Button( value="Generate", variant='primary' ) + + with gr.Column(): + result_video = gr.Video( label="Generated Animation", interactive=False ) + json_config = gr.Json( label="Config", value=None ) + + inputs = [dreambooth_dropdown, motion_module_dropdown, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, seed_textbox] + outputs = [result_video, json_config] + + generate_button.click( fn=controller.magictime, inputs=inputs, outputs=outputs ) + + gr.Examples( fn=controller.magictime, examples=examples, inputs=inputs, outputs=outputs, cache_examples=True ) + + return demo + + +if __name__ == "__main__": + demo = ui() + demo.queue(max_size=20) + demo.launch() \ No newline at end of file diff --git a/ckpts/Base_Model/base_model_path.txt b/ckpts/Base_Model/base_model_path.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ckpts/DreamBooth/dreambooth_path.txt b/ckpts/DreamBooth/dreambooth_path.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ckpts/Magic_Weights/magic_weights_path.txt b/ckpts/Magic_Weights/magic_weights_path.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..319b2eb63150d52bcc35500a62bb43864604f572 --- /dev/null +++ b/environment.yml @@ -0,0 +1,232 @@ +name: magictime +channels: + - pytorch + - nvidia + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - brotli-python=1.0.9=py310h6a678d5_7 + - bzip2=1.0.8=h5eee18b_5 + - ca-certificates=2024.3.11=h06a4308_0 + - certifi=2024.2.2=py310h06a4308_0 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - cuda-cudart=11.7.99=0 + - cuda-cupti=11.7.101=0 + - 
cuda-libraries=11.7.1=0 + - cuda-nvrtc=11.7.99=0 + - cuda-nvtx=11.7.91=0 + - cuda-runtime=11.7.1=0 + - ffmpeg=4.3=hf484d3e_0 + - freetype=2.12.1=h4a9f257_0 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - idna=3.4=py310h06a4308_0 + - intel-openmp=2023.1.0=hdb19cb5_46306 + - jpeg=9e=h5eee18b_1 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libcublas=11.10.3.66=0 + - libcufft=10.7.2.124=h4fbf590_0 + - libcufile=1.9.0.20=0 + - libcurand=10.3.5.119=0 + - libcusolver=11.4.0.1=0 + - libcusparse=11.7.4.91=0 + - libdeflate=1.17=h5eee18b_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.4=h5eee18b_0 + - libnpp=11.7.4.75=0 + - libnvjpeg=11.8.0.2=0 + - libpng=1.6.39=h5eee18b_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=h6a678d5_0 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.41.5=h5eee18b_0 + - libwebp-base=1.3.2=h5eee18b_0 + - lz4-c=1.9.4=h6a678d5_0 + - mkl=2023.1.0=h213fc3f_46344 + - mkl-service=2.4.0=py310h5eee18b_1 + - mkl_fft=1.3.8=py310h5eee18b_0 + - mkl_random=1.2.4=py310hdb19cb5_0 + - ncurses=6.4=h6a678d5_0 + - nettle=3.7.3=hbbd107a_1 + - numpy=1.26.4=py310h5f9d8c6_0 + - numpy-base=1.26.4=py310hb5e798b_0 + - openh264=2.1.1=h4ff587b_0 + - openjpeg=2.4.0=h3ad879b_0 + - openssl=3.0.13=h7f8727e_0 + - pillow=10.2.0=py310h5eee18b_0 + - pip=23.3.1=py310h06a4308_0 + - pysocks=1.7.1=py310h06a4308_0 + - python=3.10.13=h955ad1f_0 + - pytorch=1.13.1=py3.10_cuda11.7_cudnn8.5.0_0 + - pytorch-cuda=11.7=h778d358_5 + - pytorch-mutex=1.0=cuda + - readline=8.2=h5eee18b_0 + - requests=2.31.0=py310h06a4308_1 + - setuptools=68.2.2=py310h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.12=h1ccaba5_0 + - torchaudio=0.13.1=py310_cu117 + - torchvision=0.14.1=py310_cu117 + - typing_extensions=4.9.0=py310h06a4308_1 + - urllib3=2.1.0=py310h06a4308_1 + - wheel=0.41.2=py310h06a4308_0 + - xz=5.4.6=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - zstd=1.5.5=hc292b87_0 + - pip: + - absl-py==2.1.0 + - accelerate==0.28.0 + - addict==2.4.0 + - aiofiles==23.2.1 + - aiohttp==3.9.3 + - aiosignal==1.3.1 + - aliyun-python-sdk-core==2.15.0 + - aliyun-python-sdk-kms==2.16.2 + - altair==5.2.0 + - annotated-types==0.6.0 + - antlr4-python3-runtime==4.9.3 + - anyio==4.3.0 + - appdirs==1.4.4 + - async-timeout==4.0.3 + - attrs==23.2.0 + - av==12.0.0 + - beautifulsoup4==4.12.3 + - cffi==1.16.0 + - click==8.1.7 + - colorama==0.4.6 + - coloredlogs==15.0.1 + - contourpy==1.2.0 + - crcmod==1.7 + - cryptography==42.0.5 + - cycler==0.12.1 + - dacite==1.8.1 + - datasets==2.18.0 + - decord==0.6.0 + - diffusers==0.11.1 + - dill==0.3.8 + - docker-pycreds==0.4.0 + - docstring-parser==0.16 + - einops==0.7.0 + - exceptiongroup==1.2.0 + - fastapi==0.110.0 + - ffmpy==0.3.2 + - filelock==3.13.1 + - fonttools==4.50.0 + - frozenlist==1.4.1 + - fsspec==2024.2.0 + - gast==0.5.4 + - gdown==5.1.0 + - gitdb==4.0.11 + - gitpython==3.1.42 + - gradio==4.26.0 + - gradio-client==0.15.1 + - grpcio==1.62.1 + - h11==0.14.0 + - httpcore==1.0.4 + - httpx==0.27.0 + - huggingface-hub==0.21.4 + - humanfriendly==10.0 + - imageio==2.27.0 + - imageio-ffmpeg==0.4.9 + - importlib-metadata==7.0.2 + - importlib-resources==6.3.1 + - jieba==0.42.1 + - jinja2==3.1.3 + - jmespath==0.10.0 + - joblib==1.3.2 + - jsonschema==4.21.1 + - jsonschema-specifications==2023.12.1 + - kiwisolver==1.4.5 + - markdown==3.6 + - markdown-it-py==3.0.0 + - markupsafe==2.1.5 + - 
matplotlib==3.8.3 + - mdurl==0.1.2 + - modelscope==1.13.1 + - mpmath==1.3.0 + - ms-swift==1.7.3 + - multidict==6.0.5 + - multiprocess==0.70.16 + - mypy-extensions==1.0.0 + - nltk==3.8.1 + - omegaconf==2.3.0 + - optimum==1.17.1 + - orjson==3.9.15 + - oss2==2.18.4 + - packaging==24.0 + - pandas==2.2.1 + - peft==0.9.0 + - platformdirs==4.2.0 + - protobuf==4.25.3 + - psutil==5.9.8 + - pyarrow==15.0.2 + - pyarrow-hotfix==0.6 + - pycparser==2.21 + - pycryptodome==3.20.0 + - pydantic==2.6.4 + - pydantic-core==2.16.3 + - pydub==0.25.1 + - pygments==2.17.2 + - pyparsing==3.1.2 + - pyre-extensions==0.0.23 + - python-dateutil==2.9.0.post0 + - python-multipart==0.0.9 + - pytz==2024.1 + - pyyaml==6.0.1 + - referencing==0.34.0 + - regex==2023.12.25 + - rich==13.7.1 + - rouge==1.0.1 + - rpds-py==0.18.0 + - ruff==0.3.3 + - safetensors==0.4.2 + - scipy==1.12.0 + - semantic-version==2.10.0 + - sentencepiece==0.2.0 + - sentry-sdk==1.42.0 + - setproctitle==1.3.3 + - shellingham==1.5.4 + - shtab==1.7.1 + - simplejson==3.19.2 + - six==1.16.0 + - smmap==5.0.1 + - sniffio==1.3.1 + - sortedcontainers==2.4.0 + - soupsieve==2.5 + - starlette==0.36.3 + - sympy==1.12 + - tensorboard==2.16.2 + - tensorboard-data-server==0.7.2 + - tokenizers==0.15.2 + - tomli==2.0.1 + - tomlkit==0.12.0 + - toolz==0.12.1 + - tqdm==4.66.2 + - transformers==4.38.2 + - transformers-stream-generator==0.0.5 + - triton==2.2.0 + - trl==0.7.11 + - typer==0.9.0 + - typing-inspect==0.9.0 + - tyro==0.7.3 + - tzdata==2024.1 + - uvicorn==0.28.0 + - wandb==0.16.4 + - websockets==11.0.3 + - werkzeug==3.0.1 + - xformers==0.0.16 + - xxhash==3.4.1 + - yapf==0.40.2 + - yarl==1.9.4 + - zipp==3.18.1 +prefix: /home/ysh/miniconda/envs/ad_new diff --git a/inference.sh b/inference.sh new file mode 100644 index 0000000000000000000000000000000000000000..a4f9754a531bdf333d9f5b82f3bf9cf3c4179134 --- /dev/null +++ b/inference.sh @@ -0,0 +1,2 @@ +CUDA_VISIBLE_DEVICES=0 python inference_magictime.py \ + --config sample_configs/RealisticVision.yaml \ No newline at end of file diff --git a/inference_magictime.py b/inference_magictime.py new file mode 100644 index 0000000000000000000000000000000000000000..c1375c34a4d227e1f179e5918d9fdf3d23aa22d2 --- /dev/null +++ b/inference_magictime.py @@ -0,0 +1,249 @@ +import os +import json +import time +import torch +import random +import inspect +import argparse +import numpy as np +import pandas as pd +from pathlib import Path +from omegaconf import OmegaConf +from transformers import CLIPTextModel, CLIPTokenizer +from diffusers import AutoencoderKL, DDIMScheduler +from diffusers.utils.import_utils import is_xformers_available + +from utils.unet import UNet3DConditionModel +from utils.pipeline_magictime import MagicTimePipeline +from utils.util import save_videos_grid +from utils.util import load_weights + +@torch.no_grad() +def main(args): + *_, func_args = inspect.getargvalues(inspect.currentframe()) + func_args = dict(func_args) + + if 'counter' not in globals(): + globals()['counter'] = 0 + unique_id = globals()['counter'] + globals()['counter'] += 1 + savedir_base = f"{Path(args.config).stem}" + savedir_prefix = "outputs" + savedir = None + if args.save_path: + savedir = os.path.join(savedir_prefix, args.save_path, f"{savedir_base}-{unique_id}") + else: + savedir = os.path.join(savedir_prefix, f"{savedir_base}-{unique_id}") + while os.path.exists(savedir): + unique_id = globals()['counter'] + globals()['counter'] += 1 + if args.save_path: + savedir = os.path.join(savedir_prefix, args.save_path, f"{savedir_base}-{unique_id}") 
+ else: + savedir = os.path.join(savedir_prefix, f"{savedir_base}-{unique_id}") + os.makedirs(savedir) + print(f"The results will be save to {savedir}") + + model_config = OmegaConf.load(args.config)[0] + inference_config = OmegaConf.load(args.config)[1] + + if model_config.magic_adapter_s_path: + print("Use MagicAdapter-S") + if model_config.magic_adapter_t_path: + print("Use MagicAdapter-T") + if model_config.magic_text_encoder_path: + print("Use Magic_Text_Encoder") + + samples = [] + + # create validation pipeline + tokenizer = CLIPTokenizer.from_pretrained(model_config.pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(model_config.pretrained_model_path, subfolder="text_encoder").cuda() + vae = AutoencoderKL.from_pretrained(model_config.pretrained_model_path, subfolder="vae").cuda() + unet = UNet3DConditionModel.from_pretrained_2d(model_config.pretrained_model_path, subfolder="unet", + unet_additional_kwargs=OmegaConf.to_container( + inference_config.unet_additional_kwargs)).cuda() + + # set xformers + if is_xformers_available() and (not args.without_xformers): + unet.enable_xformers_memory_efficient_attention() + + pipeline = MagicTimePipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, + scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), + ).to("cuda") + + pipeline = load_weights( + pipeline, + motion_module_path=model_config.get("motion_module", ""), + dreambooth_model_path=model_config.get("dreambooth_path", ""), + magic_adapter_s_path=model_config.get("magic_adapter_s_path", ""), + magic_adapter_t_path=model_config.get("magic_adapter_t_path", ""), + magic_text_encoder_path=model_config.get("magic_text_encoder_path", ""), + ).to("cuda") + + sample_idx = 0 + if args.human: + sample_idx = 0 # Initialize sample index + while True: + user_prompt = input("Enter your prompt (or type 'exit' to quit): ") + if user_prompt.lower() == "exit": + break + + random_seed = torch.randint(0, 2 ** 32 - 1, (1,)).item() + torch.manual_seed(random_seed) + + print(f"current seed: {random_seed}") + print(f"sampling {user_prompt} ...") + + # Now, you directly use `user_prompt` to generate a video. + # The following is a placeholder call; you need to adapt it to your actual video generation function. 
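+                # Note: this interactive branch reuses steps / guidance_scale / W / H / L from the YAML model_config,
+                # but unlike the default config-driven branch further below it does not pass a negative_prompt.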
+ sample = pipeline( + user_prompt, + num_inference_steps=model_config.steps, + guidance_scale=model_config.guidance_scale, + width=model_config.W, + height=model_config.H, + video_length=model_config.L, + ).videos + + # Adapt the filename to avoid conflicts and properly represent the content + prompt_for_filename = "-".join(user_prompt.replace("/", "").split(" ")[:10]) + save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{random_seed}-{prompt_for_filename}.gif") + print(f"save to {savedir}/sample/{sample_idx}-{random_seed}-{prompt_for_filename}.gif") + + sample_idx += 1 + elif args.run_csv: + print("run_csv") + file_path = args.run_csv + data = pd.read_csv(file_path) + for index, row in data.iterrows(): + user_prompt = row['name'] # Set the user_prompt to the 'name' field of the current row + videoid = row['videoid'] # Extract videoid for filename + + random_seed = torch.randint(0, 2 ** 32 - 1, (1,)).item() + torch.manual_seed(random_seed) + + print(f"current seed: {random_seed}") + print(f"sampling {user_prompt} ...") + + sample = pipeline( + user_prompt, + num_inference_steps=model_config.steps, + guidance_scale=model_config.guidance_scale, + width=model_config.W, + height=model_config.H, + video_length=model_config.L, + ).videos + + # Adapt the filename to avoid conflicts and properly represent the content + save_videos_grid(sample, f"{savedir}/sample/{videoid}.gif") + print(f"save to {savedir}/sample/{videoid}.gif") + elif args.run_json: + print("run_json") + file_path = args.run_json + + with open(file_path, 'r') as file: + data = json.load(file) + + prompts = [] + videoids = [] + senids = [] + + for item in data: + prompts.append(item['caption']) + videoids.append(item['video_id']) + senids.append(item['sen_id']) + + n_prompts = list(model_config.n_prompt) * len(prompts) if len( + model_config.n_prompt) == 1 else model_config.n_prompt + + random_seeds = model_config.get("seed", [-1]) + random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) + random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds + + model_config.random_seed = [] + for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)): + filename = f"MSRVTT/sample/{videoids[prompt_idx]}-{senids[prompt_idx]}.gif" + + if os.path.exists(filename): + print(f"File {filename} already exists, skipping...") + continue + + # manually set random seed for reproduction + if random_seed != -1: + torch.manual_seed(random_seed) + else: + torch.seed() + model_config.random_seed.append(torch.initial_seed()) + + print(f"current seed: {torch.initial_seed()}") + print(f"sampling {prompt} ...") + + sample = pipeline( + prompt, + num_inference_steps=model_config.steps, + guidance_scale=model_config.guidance_scale, + width=model_config.W, + height=model_config.H, + video_length=model_config.L, + ).videos + + # Adapt the filename to avoid conflicts and properly represent the content + save_videos_grid(sample, filename) + print(f"save to {filename}") + else: + prompts = model_config.prompt + n_prompts = list(model_config.n_prompt) * len(prompts) if len( + model_config.n_prompt) == 1 else model_config.n_prompt + + random_seeds = model_config.get("seed", [-1]) + random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) + random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds + + model_config.random_seed = [] + for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, 
n_prompts, random_seeds)): + + # manually set random seed for reproduction + if random_seed != -1: + torch.manual_seed(random_seed) + np.random.seed(random_seed) + random.seed(random_seed) + else: + torch.seed() + model_config.random_seed.append(torch.initial_seed()) + + print(f"current seed: {torch.initial_seed()}") + print(f"sampling {prompt} ...") + sample = pipeline( + prompt, + negative_prompt=n_prompt, + num_inference_steps=model_config.steps, + guidance_scale=model_config.guidance_scale, + width=model_config.W, + height=model_config.H, + video_length=model_config.L, + ).videos + samples.append(sample) + + prompt = "-".join((prompt.replace("/", "").split(" ")[:10])) + save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{random_seed}-{prompt}.gif") + print(f"save to {savedir}/sample/{random_seed}-{prompt}.gif") + + sample_idx += 1 + samples = torch.concat(samples) + save_videos_grid(samples, f"{savedir}/merge_all.gif", n_rows=4) + + OmegaConf.save(model_config, f"{savedir}/model_config.yaml") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--without-xformers", action="store_true") + parser.add_argument("--human", action="store_true", help="Enable human mode for interactive video generation") + parser.add_argument("--run-csv", type=str, default=None) + parser.add_argument("--run-json", type=str, default=None) + parser.add_argument("--save-path", type=str, default=None) + + args = parser.parse_args() + main(args) diff --git a/requirement.txt b/requirement.txt new file mode 100644 index 0000000000000000000000000000000000000000..e1c7fedf51fac0da29edb555c9c598e718661566 --- /dev/null +++ b/requirement.txt @@ -0,0 +1,161 @@ +absl-py==2.1.0 +accelerate==0.28.0 +addict==2.4.0 +aiofiles==23.2.1 +aiohttp==3.9.3 +aiosignal==1.3.1 +aliyun-python-sdk-core==2.15.0 +aliyun-python-sdk-kms==2.16.2 +altair==5.2.0 +annotated-types==0.6.0 +antlr4-python3-runtime==4.9.3 +anyio==4.3.0 +appdirs==1.4.4 +async-timeout==4.0.3 +attrs==23.2.0 +av==12.0.0 +beautifulsoup4==4.12.3 +Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work +certifi @ file:///croot/certifi_1707229174982/work/certifi +cffi==1.16.0 +charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work +click==8.1.7 +colorama==0.4.6 +coloredlogs==15.0.1 +contourpy==1.2.0 +crcmod==1.7 +cryptography==42.0.5 +cycler==0.12.1 +dacite==1.8.1 +datasets==2.18.0 +decord==0.6.0 +diffusers==0.11.1 +dill==0.3.8 +docker-pycreds==0.4.0 +docstring_parser==0.16 +einops==0.7.0 +exceptiongroup==1.2.0 +fastapi==0.110.0 +ffmpy==0.3.2 +filelock==3.13.1 +fonttools==4.50.0 +frozenlist==1.4.1 +fsspec==2024.2.0 +gast==0.5.4 +gdown==5.1.0 +gitdb==4.0.11 +GitPython==3.1.42 +gradio==4.26.0 +gradio_client==0.15.1 +grpcio==1.62.1 +h11==0.14.0 +httpcore==1.0.4 +httpx==0.27.0 +huggingface-hub==0.21.4 +humanfriendly==10.0 +idna @ file:///croot/idna_1666125576474/work +imageio==2.27.0 +imageio-ffmpeg==0.4.9 +importlib_metadata==7.0.2 +importlib_resources==6.3.1 +jieba==0.42.1 +Jinja2==3.1.3 +jmespath==0.10.0 +joblib==1.3.2 +jsonschema==4.21.1 +jsonschema-specifications==2023.12.1 +kiwisolver==1.4.5 +Markdown==3.6 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.8.3 +mdurl==0.1.2 +mkl-fft @ file:///croot/mkl_fft_1695058164594/work +mkl-random @ file:///croot/mkl_random_1695059800811/work +mkl-service==2.4.0 +modelscope==1.13.1 +mpmath==1.3.0 +ms-swift==1.7.3 +multidict==6.0.5 +multiprocess==0.70.16 
+mypy-extensions==1.0.0 +nltk==3.8.1 +numpy @ file:///croot/numpy_and_numpy_base_1708638617955/work/dist/numpy-1.26.4-cp310-cp310-linux_x86_64.whl#sha256=d8cd837ed43e87f77e6efaa08e8de927ca030a1c9c5d04624432d6fb9a74a5ee +omegaconf==2.3.0 +optimum==1.17.1 +orjson==3.9.15 +oss2==2.18.4 +packaging==24.0 +pandas==2.2.1 +peft==0.9.0 +pillow @ file:///croot/pillow_1707233021655/work +platformdirs==4.2.0 +protobuf==4.25.3 +psutil==5.9.8 +pyarrow==15.0.2 +pyarrow-hotfix==0.6 +pycparser==2.21 +pycryptodome==3.20.0 +pydantic==2.6.4 +pydantic_core==2.16.3 +pydub==0.25.1 +Pygments==2.17.2 +pyparsing==3.1.2 +pyre-extensions==0.0.23 +PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work +python-dateutil==2.9.0.post0 +python-multipart==0.0.9 +pytz==2024.1 +PyYAML==6.0.1 +referencing==0.34.0 +regex==2023.12.25 +requests @ file:///croot/requests_1707355572290/work +rich==13.7.1 +rouge==1.0.1 +rpds-py==0.18.0 +ruff==0.3.3 +safetensors==0.4.2 +scipy==1.12.0 +semantic-version==2.10.0 +sentencepiece==0.2.0 +sentry-sdk==1.42.0 +setproctitle==1.3.3 +shellingham==1.5.4 +shtab==1.7.1 +simplejson==3.19.2 +six==1.16.0 +smmap==5.0.1 +sniffio==1.3.1 +sortedcontainers==2.4.0 +soupsieve==2.5 +starlette==0.36.3 +sympy==1.12 +tensorboard==2.16.2 +tensorboard-data-server==0.7.2 +tokenizers==0.15.2 +tomli==2.0.1 +tomlkit==0.12.0 +toolz==0.12.1 +torch==1.13.1 +torchaudio==0.13.1 +torchvision==0.14.1 +tqdm==4.66.2 +transformers==4.38.2 +transformers-stream-generator==0.0.5 +triton==2.2.0 +trl==0.7.11 +typer==0.9.0 +typing-inspect==0.9.0 +typing_extensions @ file:///croot/typing_extensions_1705599297034/work +tyro==0.7.3 +tzdata==2024.1 +urllib3 @ file:///croot/urllib3_1707770551213/work +uvicorn==0.28.0 +wandb==0.16.4 +websockets==11.0.3 +Werkzeug==3.0.1 +xformers==0.0.16 +xxhash==3.4.1 +yapf==0.40.2 +yarl==1.9.4 +zipp==3.18.1 diff --git a/sample_configs/RcnzCartoon.yaml b/sample_configs/RcnzCartoon.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c8a3b91842ed3f862494bdc5c53b832a931619b4 --- /dev/null +++ b/sample_configs/RcnzCartoon.yaml @@ -0,0 +1,49 @@ +- pretrained_model_path: "./ckpts/Base_Model/stable-diffusion-v1-5" + motion_module: "./ckpts/Base_Model/motion_module/motion_module.ckpt" + dreambooth_path: "./ckpts/DreamBooth/RcnzCartoon.safetensors" + magic_adapter_s_path: "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt" + magic_adapter_t_path: "./ckpts/Magic_Weights/magic_adapter_t" + magic_text_encoder_path: "./ckpts/Magic_Weights/magic_text_encoder" + + H: 512 + W: 512 + L: 16 + seed: [1268480012, 3480796026, 3607977321, 1601344133] + steps: 25 + guidance_scale: 8.5 + + prompt: + - "Time-lapse of a simple modern house's construction in a Minecraft virtual environment: beginning with an avatar laying a white foundation, progressing through wall erection and interior furnishing, to adding roof and exterior details, and completed with landscaping and a tall chimney." + - "Time-lapse of a simple modern house's construction in a Minecraft virtual environment: beginning with an avatar laying a white foundation, progressing through wall erection and interior furnishing, to adding roof and exterior details, and completed with landscaping and a tall chimney." + - "Bean sprouts grow and mature from seeds." + - "Time-lapse of a yellow ranunculus flower transitioning from a tightly closed bud to a fully bloomed state, with measured petal separation and unfurling observed across the sequence." 
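+  # inference_magictime.py broadcasts a single n_prompt entry to all of the prompts above (it is used as the negative prompt in the default branch).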
+ + n_prompt: + - "worst quality, low quality, letterboxed" + +- unet_additional_kwargs: + use_inflated_groupnorm: true + use_motion_module: true + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + zero_initialize: true + noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: linear + steps_offset: 1 + clip_sample: false \ No newline at end of file diff --git a/sample_configs/RealisticVision.yaml b/sample_configs/RealisticVision.yaml new file mode 100644 index 0000000000000000000000000000000000000000..38a5b5bfd51f60deb1e15462588411b9938ca8db --- /dev/null +++ b/sample_configs/RealisticVision.yaml @@ -0,0 +1,49 @@ +- pretrained_model_path: "./ckpts/Base_Model/stable-diffusion-v1-5" + motion_module: "./ckpts/Base_Model/motion_module/motion_module.ckpt" + dreambooth_path: "./ckpts/DreamBooth/RealisticVisionV60B1_v51VAE.safetensors" + magic_adapter_s_path: "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt" + magic_adapter_t_path: "./ckpts/Magic_Weights/magic_adapter_t" + magic_text_encoder_path: "./ckpts/Magic_Weights/magic_text_encoder" + + H: 512 + W: 512 + L: 16 + seed: [1587796317, 2883629116, 3068368949, 2038801077] + steps: 25 + guidance_scale: 8.5 + + prompt: + - "Time-lapse of dough balls transforming into bread rolls: Begins with smooth, proofed dough, gradually expands in early baking, becomes taut and voluminous, and finally browns and fully expands to signal the baking's completion." + - "Time-lapse of cupcakes progressing through the baking process: starting from liquid batter in cupcake liners, gradually rising with the formation of domes, to fully baked cupcakes with golden, crackled domes." + - "Cherry blossoms transitioning from tightly closed buds to a peak state of bloom. The progression moves through stages of bud swelling, petal exposure, and gradual opening, culminating in a full and vibrant display of open blossoms." + - "Cherry blossoms transitioning from tightly closed buds to a peak state of bloom. The progression moves through stages of bud swelling, petal exposure, and gradual opening, culminating in a full and vibrant display of open blossoms." 
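+  # The four seeds above pair one-to-one with the four prompts; a single seed is broadcast to every prompt, and -1 draws a fresh seed at run time.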
+ + n_prompt: + - "worst quality, low quality, letterboxed" + +- unet_additional_kwargs: + use_inflated_groupnorm: true + use_motion_module: true + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + zero_initialize: true + noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: linear + steps_offset: 1 + clip_sample: false \ No newline at end of file diff --git a/sample_configs/ToonYou.yaml b/sample_configs/ToonYou.yaml new file mode 100644 index 0000000000000000000000000000000000000000..67fdc8b4f9e4e1d037f23578f8593b590b4e1ece --- /dev/null +++ b/sample_configs/ToonYou.yaml @@ -0,0 +1,49 @@ +- pretrained_model_path: "./ckpts/Base_Model/stable-diffusion-v1-5" + motion_module: "./ckpts/Base_Model/motion_module/motion_module.ckpt" + dreambooth_path: "./ckpts/DreamBooth/ToonYou_beta6.safetensors" + magic_adapter_s_path: "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt" + magic_adapter_t_path: "./ckpts/Magic_Weights/magic_adapter_t" + magic_text_encoder_path: "./ckpts/Magic_Weights/magic_text_encoder" + + H: 512 + W: 512 + L: 16 + seed: [3832738942, 153403692, 10789633, 1496541313] + steps: 25 + guidance_scale: 8.5 + + prompt: + - "An ice cube is melting." + - "A mesmerizing time-lapse showcasing the elegant unfolding of pink plum buds blossoms, capturing the gradual bloom from tightly sealed buds to fully open flowers." + - "Time-lapse of a yellow ranunculus flower transitioning from a tightly closed bud to a fully bloomed state, with measured petal separation and unfurling observed across the sequence." + - "Bean sprouts grow and mature from seeds." 
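+  # steps and guidance_scale above map to num_inference_steps and guidance_scale in MagicTimePipeline.__call__; H, W and L set height, width and video_length.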
+ + n_prompt: + - "worst quality, low quality, letterboxed" + +- unet_additional_kwargs: + use_inflated_groupnorm: true + use_motion_module: true + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + zero_initialize: true + noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: linear + steps_offset: 1 + clip_sample: false \ No newline at end of file diff --git a/utils/__pycache__/pipeline_magictime.cpython-310.pyc b/utils/__pycache__/pipeline_magictime.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce0c92b8fbc741c80f660e6f83323d1127db40ca Binary files /dev/null and b/utils/__pycache__/pipeline_magictime.cpython-310.pyc differ diff --git a/utils/__pycache__/unet.cpython-310.pyc b/utils/__pycache__/unet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd12967154a0f7384c2f180fb541d25732b100ff Binary files /dev/null and b/utils/__pycache__/unet.cpython-310.pyc differ diff --git a/utils/__pycache__/unet_blocks.cpython-310.pyc b/utils/__pycache__/unet_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2002e7f0a9c840722265a758b4f701efed4670ef Binary files /dev/null and b/utils/__pycache__/unet_blocks.cpython-310.pyc differ diff --git a/utils/__pycache__/util.cpython-310.pyc b/utils/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2de4c735caa328e744d48e17f3d1d6bf4138409f Binary files /dev/null and b/utils/__pycache__/util.cpython-310.pyc differ diff --git a/utils/dataset.py b/utils/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e521d09f55e6bb96aef3c59fe91745b5901b078d --- /dev/null +++ b/utils/dataset.py @@ -0,0 +1,101 @@ +import os, csv, random +import numpy as np +from decord import VideoReader +import torch +import torchvision.transforms as transforms +from torch.utils.data.dataset import Dataset + + +class ChronoMagic(Dataset): + def __init__( + self, + csv_path, video_folder, + sample_size=512, sample_stride=4, sample_n_frames=16, + is_image=False, + is_uniform=True, + ): + with open(csv_path, 'r') as csvfile: + self.dataset = list(csv.DictReader(csvfile)) + self.length = len(self.dataset) + + self.video_folder = video_folder + self.sample_stride = sample_stride + self.sample_n_frames = sample_n_frames + self.is_image = is_image + self.is_uniform = is_uniform + + sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) + self.pixel_transforms = transforms.Compose([ + transforms.RandomHorizontalFlip(), + transforms.Resize(sample_size[0], interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + def _get_frame_indices_adjusted(self, video_length, n_frames): + indices = list(range(video_length)) + additional_frames_needed = n_frames - video_length + + repeat_indices = [] + for i in range(additional_frames_needed): + index_to_repeat = i % video_length + repeat_indices.append(indices[index_to_repeat]) + + all_indices = indices + repeat_indices + all_indices.sort() + + return all_indices + + def 
_generate_frame_indices(self, video_length, n_frames, sample_stride, is_transmit): + prob_execute_original = 1 if int(is_transmit) == 0 else 0 + + # Generate a random number to decide which block of code to execute + if random.random() < prob_execute_original: + if video_length <= n_frames: + return self._get_frame_indices_adjusted(video_length, n_frames) + else: + interval = (video_length - 1) / (n_frames - 1) + indices = [int(round(i * interval)) for i in range(n_frames)] + indices[-1] = video_length - 1 + return indices + else: + if video_length <= n_frames: + return self._get_frame_indices_adjusted(video_length, n_frames) + else: + clip_length = min(video_length, (n_frames - 1) * sample_stride + 1) + start_idx = random.randint(0, video_length - clip_length) + return np.linspace(start_idx, start_idx + clip_length - 1, n_frames, dtype=int).tolist() + + def get_batch(self, idx): + video_dict = self.dataset[idx] + videoid, name, is_transmit = video_dict['videoid'], video_dict['name'], video_dict['is_transmit'] + + video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") + video_reader = VideoReader(video_dir, num_threads=0) + video_length = len(video_reader) + + batch_index = self._generate_frame_indices(video_length, self.sample_n_frames, self.sample_stride, is_transmit) if not self.is_image else [random.randint(0, video_length - 1)] + + pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2) / 255. + del video_reader + + if self.is_image: + pixel_values = pixel_values[0] + + return pixel_values, name, videoid + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + pixel_values, name, videoid = self.get_batch(idx) + break + + except Exception as e: + idx = random.randint(0, self.length-1) + + pixel_values = self.pixel_transforms(pixel_values) + sample = dict(pixel_values=pixel_values, text=name, id=videoid) + return sample \ No newline at end of file diff --git a/utils/pipeline_magictime.py b/utils/pipeline_magictime.py new file mode 100644 index 0000000000000000000000000000000000000000..394387fcfcefbd4452885cbb745d83c10b16d773 --- /dev/null +++ b/utils/pipeline_magictime.py @@ -0,0 +1,421 @@ +# Adapted from https://github.com/guoyww/AnimateDiff/animatediff/pipelines/pipeline_animation.py + +import torch +import inspect +import numpy as np +from tqdm import tqdm +from einops import rearrange +from packaging import version +from dataclasses import dataclass +from typing import Callable, List, Optional, Union +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers.utils import is_accelerate_available, deprecate, logging, BaseOutput +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) + +from .unet import UNet3DConditionModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +@dataclass +class MagicTimePipelineOutput(BaseOutput): + videos: Union[torch.Tensor, np.ndarray] + +class MagicTimePipeline(DiffusionPipeline): + _optional_components = [] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + 
EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. 
If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def enable_vae_slicing(self): + self.vae.enable_slicing() + + def disable_vae_slicing(self): + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + + @property + def _execution_device(self): + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
+ ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def decode_latents(self, latents): + video_length = latents.shape[2] + latents = 1 / 0.18215 * latents + latents = rearrange(latents, "b c f h w -> (b f) c h w") + # video = self.vae.decode(latents).sample + video = [] + for frame_idx in tqdm(range(latents.shape[0])): + video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) + video = torch.cat(video) + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + return video + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = shape + # shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + video_length: Optional[int], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + # Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # Check inputs. 
Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # Define call parameters + # batch_size = 1 if isinstance(prompt, str) else len(prompt) + batch_size = 1 + if latents is not None: + batch_size = latents.shape[0] + if isinstance(prompt, list): + batch_size = len(prompt) + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # Encode input prompt + prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size + if negative_prompt is not None: + negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size + text_embeddings = self._encode_prompt( + prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + video_length, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + latents_dtype = latents.dtype + + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + down_block_additional_residuals = mid_block_additional_residual = None + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals = down_block_additional_residuals, + mid_block_additional_residual = mid_block_additional_residual, + ).sample.to(dtype=latents_dtype) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # Post-processing + video = self.decode_latents(latents) + + # Convert to tensor + if output_type == "tensor": + video = torch.from_numpy(video) + + if not return_dict: + return video + + return MagicTimePipelineOutput(videos=video) diff --git a/utils/unet.py b/utils/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..e72bd8b5ee05a7bf34750bcb698f02d921035e1e --- /dev/null +++ b/utils/unet.py @@ -0,0 +1,513 @@ +# Adapted from https://github.com/guoyww/AnimateDiff/animatediff/models/unet.py +import os +import json +import pdb +from dataclasses import dataclass +from typing import 
List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + InflatedConv3d, + InflatedGroupNorm, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + + use_inflated_groupnorm=False, + + # Additional + use_motion_module = False, + motion_module_resolutions = ( 1,2,4,8 ), + motion_module_mid_block = False, + motion_module_decoder_only = False, + motion_module_type = None, + motion_module_kwargs = {}, + unet_use_cross_frame_attention = False, + unet_use_temporal_attention = False, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, 
down_block_type in enumerate(down_block_types): + res = 2 ** i + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module and motion_module_mid_block, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + res = 2 ** (3 - i) + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + 
only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module and (res in motion_module_resolutions), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if use_inflated_groupnorm: + self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. 
+        # Any child that exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+
+        # support controlnet
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
+
+        return_dict: bool = True,
+    ) -> Union[UNet3DConditionOutput, Tuple]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
+
+        Returns:
+            [`UNet3DConditionOutput`] or `tuple`:
+            [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+        """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
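+        # Example: with three upsampling blocks the overall factor is 2**3 = 8, so 64x64 latents divide evenly,
+        # while e.g. a 60x60 input would trigger the forced-interpolation path below.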
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) + + down_block_res_samples += res_samples + + # support controlnet + down_block_res_samples = list(down_block_res_samples) + if down_block_additional_residuals is not None: + for i, down_block_additional_residual in enumerate(down_block_additional_residuals): + if down_block_additional_residual.dim() == 4: # boardcast + down_block_additional_residual = down_block_additional_residual.unsqueeze(2) + down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + # support controlnet + if mid_block_additional_residual is not None: + if mid_block_additional_residual.dim() == 4: # boardcast + mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2) + sample = sample + mid_block_additional_residual + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples 
= down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, + ) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + from diffusers.utils import WEIGHTS_NAME + model = cls.from_config(config, **unet_additional_kwargs) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + + params = [p.numel() if "motion_modules." 
in n else 0 for n, p in model.named_parameters()] + print(f"### Motion Module Parameters: {sum(params) / 1e6} M") + + return model diff --git a/utils/unet_blocks.py b/utils/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..4c4208ec3136792d3dc01c25d175c8f2d1db50dd --- /dev/null +++ b/utils/unet_blocks.py @@ -0,0 +1,1549 @@ +# Adapted from https://github.com/guoyww/AnimateDiff/animatediff/models/unet_blocks.py +import torch +from torch import nn +import torch.nn.functional as F + +import math +from typing import Optional +from einops import rearrange, repeat +from dataclasses import dataclass + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm + +# Attention +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
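+        # Video frames are folded into the batch below so the 2D spatial attention runs per frame,
+        # e.g. (b, c, f, h, w) = (2, 320, 16, 32, 32) becomes (32, 320, 32, 32); the text embedding is
+        # repeated once per frame to match.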
+ video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + + unet_use_cross_frame_attention = None, + unet_use_temporal_attention = None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.unet_use_cross_frame_attention = unet_use_cross_frame_attention + self.unet_use_temporal_attention = unet_use_temporal_attention + + # SC-Attn + assert unet_use_cross_frame_attention is not None + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + # Temp-Attn + assert unet_use_temporal_attention is not None + if unet_use_temporal_attention: + self.attn_temp = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + 
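+            # Zero the output projection weight so the temporal attention contributes (almost) nothing at
+            # initialisation and the block starts out behaving like the underlying 2D model.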
nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.unet_use_cross_frame_attention: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # Temporal-Attention + if self.unet_use_temporal_attention: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + ) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + +# Resnet +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class InflatedGroupNorm(nn.GroupNorm): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + 
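+            # A 3x3 InflatedConv3d: spatial-only convolution applied to every frame after the
+            # nearest-neighbour upsample in forward().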
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + # if self.use_conv: + # if self.name == "conv": + # hidden_states = self.conv(hidden_states) + # else: + # hidden_states = self.Conv2d_0(hidden_states) + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + use_inflated_groupnorm=False, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + assert use_inflated_groupnorm != None + if use_inflated_groupnorm: + self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + else: + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, 
time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + if use_inflated_groupnorm: + self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + else: + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) + + +# Animatediff_motion_module +def zero_module(module): + # Zero out the parameters of a module and return it. 
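+    # Used on the motion module's output projection below so a newly added temporal module starts as a no-op
+    # and the pretrained image UNet's behaviour is preserved at the start of training.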
+ for p in module.parameters(): + p.detach().zero_() + return module + + +@dataclass +class TemporalTransformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +def get_motion_module( + in_channels, + motion_module_type: str, + motion_module_kwargs: dict +): + if motion_module_type == "Vanilla": + return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) + else: + raise ValueError + + +class VanillaTemporalModule(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads = 8, + num_transformer_block = 2, + attention_block_types =( "Temporal_Self", "Temporal_Self" ), + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + temporal_attention_dim_div = 1, + zero_initialize = True, + ): + super().__init__() + + self.temporal_transformer = TemporalTransformer3DModel( + in_channels=in_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, + num_layers=num_transformer_block, + attention_block_types=attention_block_types, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + + if zero_initialize: + self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) + + def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): + hidden_states = input_tensor + hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) + + output = hidden_states + return output + +class TemporalTransformer3DModel(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads, + attention_head_dim, + + num_layers, + attention_block_types = ( "Temporal_Self", "Temporal_Self", ), + dropout = 0.0, + norm_num_groups = 32, + cross_attention_dim = 768, + activation_fn = "geglu", + attention_bias = False, + upcast_attention = False, + + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + ): + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + attention_block_types=attention_block_types, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + for d in range(num_layers) + ] + ) + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
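+        # Frames are folded into the batch for the spatial reshape here; VersatileAttention later regroups
+        # tokens as (b*h*w, f, c) so self-attention runs along the frame axis at each spatial location.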
+ video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Transformer Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) + + # output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + + return output + + +class TemporalTransformerBlock(nn.Module): + def __init__( + self, + dim, + num_attention_heads, + attention_head_dim, + attention_block_types = ( "Temporal_Self", "Temporal_Self", ), + dropout = 0.0, + norm_num_groups = 32, + cross_attention_dim = 768, + activation_fn = "geglu", + attention_bias = False, + upcast_attention = False, + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + ): + super().__init__() + + attention_blocks = [] + norms = [] + + for block_name in attention_block_types: + attention_blocks.append( + VersatileAttention( + attention_mode=block_name.split("_")[0], + cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, + + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + ) + norms.append(nn.LayerNorm(dim)) + + self.attention_blocks = nn.ModuleList(attention_blocks) + self.norms = nn.ModuleList(norms) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.ff_norm = nn.LayerNorm(dim) + + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + for attention_block, norm in zip(self.attention_blocks, self.norms): + norm_hidden_states = norm(hidden_states) + hidden_states = attention_block( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, + video_length=video_length, + ) + hidden_states + + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output + + +class PositionalEncoding(nn.Module): + def __init__( + self, + d_model, + dropout = 0., + max_len = 24 + ): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return self.dropout(x) + + +class VersatileAttention(CrossAttention): + def __init__( + self, + attention_mode = None, + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + 
*args, **kwargs + ): + super().__init__(*args, **kwargs) + assert attention_mode == "Temporal" + + self.attention_mode = attention_mode + self.is_cross_attention = kwargs["cross_attention_dim"] is not None + + self.pos_encoder = PositionalEncoding( + kwargs["query_dim"], + dropout=0., + max_len=temporal_position_encoding_max_len + ) if (temporal_position_encoding and attention_mode == "Temporal") else None + + def extra_repr(self): + return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + batch_size, sequence_length, _ = hidden_states.shape + + if self.attention_mode == "Temporal": + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + + if self.pos_encoder is not None: + hidden_states = self.pos_encoder(hidden_states) + + encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states + else: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + if self.attention_mode == "Temporal": + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + +# UNet_block +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + + use_motion_module=None, + + motion_module_type=None, + motion_module_kwargs=None, +): + down_block_type = 
down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + 
cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + + use_motion_module=None, + + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ] + attentions = [] + motion_modules = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=in_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], 
self.motion_modules): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + + use_motion_module=None, + + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + output_states = () + + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = 
torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + # add motion module + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + + use_inflated_groupnorm=False, + + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, motion_module in zip(self.resnets, self.motion_modules): + if self.training and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + else: + hidden_states = resnet(hidden_states, temb) + + # add motion module + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + 
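+        # output_states holds one skip tensor per resnet plus the downsampled output; the matching up block
+        # consumes them in reverse order as res_hidden_states_tuple.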
return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + + use_motion_module=None, + + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = 
torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + # add motion module + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + + use_inflated_groupnorm=False, + + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,): + for resnet, motion_module in zip(self.resnets, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + + if self.upsamplers is 
not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..4f23906e14945abda23ef7acbf85144f675e800f --- /dev/null +++ b/utils/util.py @@ -0,0 +1,717 @@ +import os +import imageio +import numpy as np +from tqdm import tqdm +from typing import Union +from einops import rearrange +from safetensors import safe_open +from transformers import CLIPTextModel +import torch +import torchvision +import torch.distributed as dist + +def zero_rank_print(s): + if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.mimsave(path, outputs, fps=fps) + +# DDIM Inversion +@torch.no_grad() +def init_prompt(prompt, pipeline): + uncond_input = pipeline.tokenizer( + [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, + return_tensors="pt" + ) + uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] + text_input = pipeline.tokenizer( + [prompt], + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] + context = torch.cat([uncond_embeddings, text_embeddings]) + + return context + +def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): + timestep, next_timestep = min( + timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep + alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod + alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction + return next_sample + +def get_noise_pred_single(latents, t, context, unet): + noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] + return noise_pred + +@torch.no_grad() +def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): + context = init_prompt(prompt, pipeline) + uncond_embeddings, cond_embeddings = context.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + for i in tqdm(range(num_inv_steps)): + t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] + noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) + latent = next_step(noise_pred, t, latent, ddim_scheduler) + all_latent.append(latent) + return all_latent + +@torch.no_grad() +def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): + ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) + return
ddim_latents + +def load_weights( + magictime_pipeline, + motion_module_path = "", + dreambooth_model_path = "", + magic_adapter_s_path = "", + magic_adapter_t_path = "", + magic_text_encoder_path = "", +): + # motion module + unet_state_dict = {} + if motion_module_path != "": + print(f"load motion module from {motion_module_path}") + try: + motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") + if "state_dict" in motion_module_state_dict: + motion_module_state_dict = motion_module_state_dict["state_dict"] + for name, param in motion_module_state_dict.items(): + if "motion_modules." in name: + modified_name = name.removeprefix('module.') if name.startswith('module.') else name + unet_state_dict[modified_name] = param + except Exception as e: + print(f"Error loading motion module: {e}") + try: + missing, unexpected = magictime_pipeline.unet.load_state_dict(unet_state_dict, strict=False) + assert len(unexpected) == 0, f"Unexpected keys in state_dict: {unexpected}" + del unet_state_dict + except Exception as e: + print(f"Error loading state dict into UNet: {e}") + + # base model + if dreambooth_model_path != "": + print(f"load dreambooth model from {dreambooth_model_path}") + if dreambooth_model_path.endswith(".safetensors"): + dreambooth_state_dict = {} + with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: + for key in f.keys(): + dreambooth_state_dict[key] = f.get_tensor(key) + elif dreambooth_model_path.endswith(".ckpt"): + dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") + + # 1. vae + converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, magictime_pipeline.vae.config) + magictime_pipeline.vae.load_state_dict(converted_vae_checkpoint) + # 2. unet + converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, magictime_pipeline.unet.config) + magictime_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) + # 3. text_model + magictime_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) + del dreambooth_state_dict + + # MagicAdapter and MagicTextEncoder + if magic_adapter_s_path != "": + print(f"load domain lora from {magic_adapter_s_path}") + magic_adapter_s_state_dict = torch.load(magic_adapter_s_path, map_location="cpu") + magictime_pipeline = load_diffusers_lora(magictime_pipeline, magic_adapter_s_state_dict, alpha=1.0) + + if magic_adapter_t_path != "" or magic_text_encoder_path != "": + from swift import Swift + + if magic_adapter_t_path != "": + print("load lora from swift for Unet") + Swift.from_pretrained(magictime_pipeline.unet, magic_adapter_t_path) + + if magic_text_encoder_path != "": + print("load lora from swift for text encoder") + Swift.from_pretrained(magictime_pipeline.text_encoder, magic_text_encoder_path) + + return magictime_pipeline + +def load_diffusers_lora(pipeline, state_dict, alpha=1.0): + # directly update weight in diffusers model + for key in state_dict: + # only process lora down key + if "up." 
in key: continue + + up_key = key.replace(".down.", ".up.") + model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") + model_key = model_key.replace("to_out.", "to_out.0.") + layer_infos = model_key.split(".")[:-1] + + curr_layer = pipeline.unet + while len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + curr_layer = curr_layer.__getattr__(temp_name) + + weight_down = state_dict[key] * 2 + weight_up = state_dict[up_key] * 2 + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) + + return pipeline + +def load_diffusers_lora_unet(unet, state_dict, alpha=1.0): + # directly update weight in diffusers model + for key in state_dict: + # only process lora down key + if "up." in key: continue + + up_key = key.replace(".down.", ".up.") + model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") + model_key = model_key.replace("to_out.", "to_out.0.") + layer_infos = model_key.split(".")[:-1] + + curr_layer = unet + while len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + curr_layer = curr_layer.__getattr__(temp_name) + + weight_down = state_dict[key] * 2 + weight_up = state_dict[up_key] * 2 + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) + + return unet + +def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): + visited = [] + + # directly update weight in diffusers model + for key in state_dict: + # it is suggested to print out the key, it usually will be something like below + # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + + # as we have set the alpha beforehand, so just skip + if ".alpha" in key or key in visited: + continue + + if "text" in key: + layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = pipeline.unet + + # find the target layer + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + pair_keys = [] + if "lora_down" in key: + pair_keys.append(key.replace("lora_down", "lora_up")) + pair_keys.append(key) + else: + pair_keys.append(key) + pair_keys.append(key.replace("lora_up", "lora_down")) + + # update weight + if len(state_dict[pair_keys[0]].shape) == 4: + weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) + weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) + else: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) + + # update visited list + for item in pair_keys: + visited.append(item) + + return pipeline + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. 
+ """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + mapping.append({"old": old_item, "new": new_item}) + return mapping + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. 
+ if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... 
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, 
unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. 
+ if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + +def convert_ldm_clip_checkpoint(checkpoint): + from transformers import CLIPTextModel + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") + + keys = list(checkpoint.keys()) + keys.remove("cond_stage_model.transformer.text_model.embeddings.position_ids") + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + text_model.load_state_dict(text_model_dict) + + return text_model + +def convert_ldm_clip_text_model(text_model, checkpoint): + keys = list(checkpoint.keys()) + keys.remove("cond_stage_model.transformer.text_model.embeddings.position_ids") + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + text_model.load_state_dict(text_model_dict) + + return text_model + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." 
+ keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, 
new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + return new_checkpoint \ No newline at end of file
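
A minimal usage sketch of save_videos_grid as defined in utils/util.py above. The tensor shape, output path, and frame rate are illustrative placeholders; the helper itself expects a (batch, channels, frames, height, width) tensor with values in [0, 1], or in [-1, 1] when rescale=True.

```python
# Illustrative only: exercises save_videos_grid from utils/util.py with random data.
# The shape (1, 3, 16, 256, 256) = (batch, channels, frames, height, width) and the
# output path "outputs/demo.gif" are placeholders, not values used anywhere in this patch.
import torch
from utils.util import save_videos_grid

videos = torch.rand(1, 3, 16, 256, 256)  # values in [0, 1]; pass rescale=True for [-1, 1] inputs
save_videos_grid(videos, "outputs/demo.gif", rescale=False, n_rows=1, fps=8)
```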
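
A small worked example of the key-renaming helpers: shave_segments drops dot-separated segments from either end of a parameter path, and renew_resnet_paths performs only the local LDM-to-diffusers renames (in_layers.0 -> norm1 and so on); the block-level renaming happens later inside assign_to_checkpoint via additional_replacements. The parameter names below are typical LDM UNet keys chosen purely for illustration.

```python
# Worked example for the conversion helpers in utils/util.py above.
from utils.util import shave_segments, renew_resnet_paths

# Positive n drops leading segments, negative n drops trailing segments.
assert shave_segments("output_blocks.2.1.conv.weight", 2) == "1.conv.weight"
assert shave_segments("output_blocks.2.1.conv.weight", -1) == "output_blocks.2.1.conv"

# Only the intra-resnet names are rewritten here; "input_blocks.1.0" is mapped to
# "down_blocks.0.resnets.0" later, inside assign_to_checkpoint, via additional_replacements.
mapping = renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"])
assert mapping == [{"old": "input_blocks.1.0.in_layers.0.weight",
                    "new": "input_blocks.1.0.norm1.weight"}]
```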
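
Finally, a hedged sketch of how load_weights might be wired at inference time. Everything outside the load_weights call is assumed: build_pipeline() stands in for whatever constructs a pipeline exposing .unet, .vae, and .text_encoder, and every checkpoint path is a placeholder; only the load_weights signature comes from the code above. Empty strings skip the corresponding loading branch.

```python
# Hypothetical wiring of load_weights; build_pipeline() and all paths are placeholders
# introduced for this sketch, not defined by this patch.
from utils.util import load_weights

magictime_pipeline = build_pipeline()  # assumed: returns an object with .unet, .vae, .text_encoder

magictime_pipeline = load_weights(
    magictime_pipeline,
    motion_module_path="ckpts/motion_module.ckpt",         # placeholder path
    dreambooth_model_path="ckpts/base_model.safetensors",  # placeholder path
    magic_adapter_s_path="",    # optional adapters left empty here
    magic_adapter_t_path="",
    magic_text_encoder_path="",
)
```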