diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..dd75a945359aa9217dc72f770edd831f38eed5fe 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+resources/CogVideoX.pdf filter=lfs diff=lfs merge=lfs -text
+resources/videos/2.mp4 filter=lfs diff=lfs merge=lfs -text
+resources/web_demo.png filter=lfs diff=lfs merge=lfs -text
+tools/caption/assests/cogvlm2-video-example.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..19271ef8a55b08c9ad52acdd1a0c2a76dd1e2990
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.yaml
@@ -0,0 +1,51 @@
+name: "\U0001F41B Bug Report"
+description: Submit a bug report to help us improve CogVideoX / 提交一个 Bug 问题报告来帮助我们改进 CogVideoX 开源模型
+body:
+ - type: textarea
+ id: system-info
+ attributes:
+ label: System Info / 系統信息
+ description: Your operating environment / 您的运行环境信息
+ placeholder: Include the CUDA version, Diffusers version, Python version, operating system, and hardware information (if you suspect a hardware problem)... / 包括CUDA版本,Diffusers版本,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)...
+ validations:
+ required: true
+
+ - type: checkboxes
+ id: information-scripts-examples
+ attributes:
+ label: Information / 问题信息
+ description: 'The problem arises when using: / 问题出现在'
+ options:
+ - label: "The official example scripts / 官方的示例脚本"
+ - label: "My own modified scripts / 我自己修改的脚本和任务"
+
+ - type: textarea
+ id: reproduction
+ validations:
+ required: true
+ attributes:
+ label: Reproduction / 复现过程
+ description: |
+ Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.
+ If you have code snippets, error messages, stack traces, please provide them here as well.
+ Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
+ Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.
+
+ 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
+ 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
+ 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
+ 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
+ placeholder: |
+ Steps to reproduce the behavior/复现Bug的步骤:
+
+ 1.
+ 2.
+ 3.
+
+ - type: textarea
+ id: expected-behavior
+ validations:
+ required: true
+ attributes:
+ label: Expected behavior / 期待表现
+ description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/feature-request.yaml b/.github/ISSUE_TEMPLATE/feature-request.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7e09beeecdc35671c0f3c30cab5c3375be4dd78d
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature-request.yaml
@@ -0,0 +1,34 @@
+name: "\U0001F680 Feature request"
+description: Submit a request for a new CogVideoX feature / 提交一个新的 CogVideoX开源模型的功能建议
+labels: [ "feature" ]
+body:
+ - type: textarea
+ id: feature-request
+ validations:
+ required: true
+ attributes:
+ label: Feature request / 功能建议
+ description: |
+ A brief description of the proposed feature. Links to relevant papers and code are welcome.
+ 对功能建议的简述。最好提供对应的论文和代码链接。
+
+ - type: textarea
+ id: motivation
+ validations:
+ required: true
+ attributes:
+ label: Motivation / 动机
+ description: |
+ Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
+ 您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。
+
+ - type: textarea
+ id: contribution
+ validations:
+ required: true
+ attributes:
+ label: Your contribution / 您的贡献
+ description: |
+
+ A link to your PR, or any other link showing how you can help.
+ 您的PR链接或者其他您能提供帮助的链接。
\ No newline at end of file
diff --git a/.github/PULL_REQUEST_TEMPLATE/pr_template.md b/.github/PULL_REQUEST_TEMPLATE/pr_template.md
new file mode 100644
index 0000000000000000000000000000000000000000..0c3140a2e451de157dd075038c79a461604ada60
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/pr_template.md
@@ -0,0 +1,34 @@
+# Raise valuable PR / 提出有价值的PR
+
+## Caution / 注意事项:
+Users should keep the following points in mind when submitting PRs:
+
+1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md).
+2. The proposed PR should be focused; if you have multiple ideas or optimizations, split them into separate PRs.
+
+用户在提交PR时候应该注意以下几点:
+
+1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。
+2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。
+
+## PRs that should not be proposed / 不应该提出的PR
+
+If a developer proposes a PR that falls into any of the following categories, it may be closed or rejected directly.
+
+1. PRs that do not describe the proposed improvement.
+2. PRs that combine multiple issues of different types.
+3. PRs that largely duplicate existing PRs.
+
+如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。
+
+1. 没有说明改进方案的。
+2. 多个不同类型的问题合并在一个PR中的。
+3. 提出的PR与已经存在的PR高度重复的。
+
+
+# Check your PR / 检查您的PR
+- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
+- [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
+- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
+- [ ] Did you write new required tests? / 您是否编写了新的必要测试?
+- [ ] Is your PR limited to a single issue? / 您的PR是否仅针对一个问题
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..6ff20ae38c4181891eeca6610e30481f8aa843c5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+output/
+*__pycache__/
+samples*/
+runs/
+checkpoints/
+master_ip
+logs/
+*.DS_Store
+.idea
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f98e413cf55b56dbabdf693559055d79800f63c9
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2024 CogVideo Model Team @ Zhipu AI
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/Model_License b/Model_License
new file mode 100644
index 0000000000000000000000000000000000000000..3ca0c74848a189e77f466c542122bc09aab94381
--- /dev/null
+++ b/Model_License
@@ -0,0 +1,71 @@
+The CogVideoX License
+
+1. Definitions
+
+“Licensor” means the CogVideoX Model Team that distributes its Software.
+
+“Software” means the CogVideoX model parameters made available under this license.
+
+2. License Grant
+
+Under the terms and conditions of this license, the licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. The intellectual property rights of the generated content belong to the user to the extent permitted by applicable local laws.
+This license allows you to freely use all open-source models in this repository for academic research. Users who wish to use the models for commercial purposes must register and obtain a basic commercial license in https://open.bigmodel.cn/mla/form .
+Users who have registered and obtained the basic commercial license can use the models for commercial activities for free, but must comply with all terms and conditions of this license. Additionally, the number of service users (visits) for your commercial activities must not exceed 1 million visits per month.
+If the number of service users (visits) for your commercial activities exceeds 1 million visits per month, you need to contact our business team to obtain more commercial licenses.
+The above copyright statement and this license statement should be included in all copies or significant portions of this software.
+
+3. Restriction
+
+You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes.
+
+You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
+
+4. Disclaimer
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+5. Limitation of Liability
+
+EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
+
+6. Dispute Resolution
+
+This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
+
+Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn.
+
+1. 定义
+
+“许可方”是指分发其软件的 CogVideoX 模型团队。
+
+“软件”是指根据本许可提供的 CogVideoX 模型参数。
+
+2. 许可授予
+
+根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。生成内容的知识产权所属,可根据适用当地法律的规定,在法律允许的范围内由用户享有生成内容的知识产权或其他权利。
+本许可允许您免费使用本仓库中的所有开源模型进行学术研究。对于希望将模型用于商业目的的用户,需在 https://open.bigmodel.cn/mla/form 完成登记并获得基础商用授权。
+
+经过登记并获得基础商用授权的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。
+在本许可证下,您的商业活动的服务用户数量(访问量)不得超过100万人次访问 / 每月。如果超过,您需要与我们的商业团队联系以获得更多的商业许可。
+上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。
+
+3.限制
+
+您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。
+
+您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。
+
+4.免责声明
+
+本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。
+在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。
+
+5. 责任限制
+
+除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。
+
+6.争议解决
+
+本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。
+
+请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。
\ No newline at end of file
diff --git a/README.md b/README.md
index b285f960bcccfde5ad26dfa899f5468d8ca5fa98..7eca7c1f482010d388c32d189c2800a0c0e0a1e1 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,163 @@
---
title: CogVideo
-emoji: 🐨
-colorFrom: blue
-colorTo: red
+app_file: gradio_demo.py
sdk: gradio
sdk_version: 4.41.0
-app_file: app.py
-pinned: false
---
+# CogVideo && CogVideoX
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+[中文阅读](./README_zh.md)
+
+
+📍 Visit [清影](https://chatglm.cn/video?fr=osm_cogvideox) and the API Platform to experience larger-scale commercial video generation models.
+
+
+## Update and News
+
+- 🔥 **News**: ``2024/8/6``: We have also open-sourced the **3D Causal VAE** used in **CogVideoX-2B**, which can reconstruct
+  videos almost losslessly.
+- 🔥 **News**: ``2024/8/6``: We have open-sourced **CogVideoX-2B**, the first model in the CogVideoX series of video
+  generation models.
+- 🌱 **Source**: ``2022/5/19``: We open-sourced CogVideo (now available in the `CogVideo` branch), the **first** open-source pretrained text-to-video model. See the [ICLR'23 CogVideo paper](https://arxiv.org/abs/2205.15868) for technical details.
+
+**More powerful models with larger parameter sizes are on the way~ Stay tuned!**
+
+## CogVideoX-2B Gallery
+
+A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
+
+The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.
+
+A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.
+
+In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
+
+## Model Introduction
+
+CogVideoX is an open-source video generation model that shares the same origins
+as [清影](https://chatglm.cn/video?fr=osm_cogvideox).
+
+The table below lists the video generation models we currently provide, along with
+their basic information:
+
+| Model Name | CogVideoX-2B |
+|-------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| Prompt Language | English |
+| GPU Memory Required for Inference (FP16) | 18GB if using [SAT](https://github.com/THUDM/SwissArmyTransformer); 36GB if using diffusers (will be optimized before the PR is merged) |
+| GPU Memory Required for Fine-tuning (bs=1) | 40GB |
+| Prompt Max Length | 226 Tokens |
+| Video Length | 6 seconds |
+| Frames Per Second | 8 frames |
+| Resolution | 720 * 480 |
+| Quantized Inference | Not Supported |
+| Multi-card Inference | Not Supported |
+| Download Link (HF diffusers Model) | 🤗 [Huggingface](https://huggingface.co/THUDM/CogVideoX-2B) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) [💫 WiseModel](https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b) |
+| Download Link (SAT Model) | [SAT](./sat/README.md) |
+
+## Project Structure
+
+This repository helps developers quickly get started with the basic usage and fine-tuning examples
+of the open-source **CogVideoX** model.
+
+### Inference
+
++ [cli_demo](inference/cli_demo.py): A commented walkthrough of the inference code that explains the most common parameters (a minimal usage sketch follows this list).
++ [cli_vae_demo](inference/cli_vae_demo.py): Runs the VAE inference code on its own; this currently requires 71GB of GPU memory, but it will be optimized in the future.
++ [convert_demo](inference/convert_demo.py): Converts user input into a prompt format suitable for CogVideoX. Because CogVideoX is trained on long captions, an LLM is used to rewrite the input text so that it matches the training distribution. The script uses GLM-4 by default, but it can be replaced with any other LLM such as GPT or Gemini.
++ [web_demo](inference/web_demo.py): A simple Streamlit web application demonstrating how to use the CogVideoX-2B model to generate videos.
+
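+A minimal sketch of the diffusers flow that these scripts build on (assuming `diffusers>=0.30.0` and enough GPU memory, see the table above; the prompt below is only an illustration, and `cli_demo.py` remains the reference implementation):
+
+```python
+import torch
+from diffusers import CogVideoXPipeline
+
+# Load the published CogVideoX-2B weights and move the pipeline to the GPU.
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
+
+# Generate a 6-second clip (8 fps, 720*480) from a text prompt.
+video = pipe(
+    prompt="A golden retriever runs across a sunlit meadow, slow motion, shallow depth of field.",
+    num_inference_steps=50,
+    guidance_scale=6.0,
+).frames[0]
+
+# cli_demo.py then writes the frames with imageio to avoid the "green screen" issue seen with OpenCV export.
+```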
+
+### sat
+
++ [sat_demo](sat/README.md): Contains the inference and fine-tuning code for the SAT weights. We recommend building on
+  the CogVideoX model structure; innovative researchers can use this code for rapid iteration and development.
+
+### Tools
+
+This folder contains some tools for model conversion / caption generation, etc.
+
++ [convert_weight_sat2hf](tools/convert_weight_sat2hf.py): Convert SAT model weights to Huggingface model weights.
++ [caption_demo](tools/caption): A captioning tool based on a model that understands videos and describes them in text.
+
+## Project Plan
+
+- [x] Open source CogVideoX model
+ - [x] Open source 3D Causal VAE used in CogVideoX.
+ - [x] CogVideoX model inference example (CLI / Web Demo)
+ - [x] CogVideoX online experience demo (Huggingface Space)
+ - [x] CogVideoX open source model API interface example (Huggingface)
+ - [x] CogVideoX model fine-tuning example (SAT)
+ - [ ] CogVideoX model fine-tuning example (Huggingface / SAT)
+ - [ ] Open source CogVideoX-Pro (adapted for CogVideoX-2B suite)
+ - [x] Release CogVideoX technical report
+
+We welcome your contributions. You can click [here](resources/contribute.md) for more information.
+
+## Model License
+
+The code in this repository is released under the [Apache 2.0 License](LICENSE).
+
+The model weights and implementation code are released under the [CogVideoX LICENSE](Model_License).
+
+## CogVideo (ICLR'23)
+
+The official repo for the paper [CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers](https://arxiv.org/abs/2205.15868) is on the [CogVideo branch](https://github.com/THUDM/CogVideo/tree/CogVideo).
+
+**CogVideo is able to generate relatively high-frame-rate videos.**
+A 4-second clip of 32 frames is shown below.
+
+The demo for CogVideo is at [https://models.aminer.cn/cogvideo](https://models.aminer.cn/cogvideo/), where you can get hands-on practice on text-to-video generation. *The original input is in Chinese.*
+
+
+## Citation
+
+🌟 If you find our work helpful, please leave us a star and cite our paper.
+
+```
+@article{yang2024cogvideox,
+ title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
+ author={Zhuoyi Yang and Jiayan Teng and Wendi Zheng and Ming Ding and Shiyu Huang and JiaZheng Xu and Yuanming Yang and Xiaohan Zhang and Xiaotao Gu and Guanyu Feng and Da Yin and Wenyi Hong and Weihan Wang and Yean Cheng and Yuxuan Zhang and Ting Liu and Bin Xu and Yuxiao Dong and Jie Tang},
+ year={2024},
+}
+@article{hong2022cogvideo,
+ title={CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers},
+ author={Hong, Wenyi and Ding, Ming and Zheng, Wendi and Liu, Xinghan and Tang, Jie},
+ journal={arXiv preprint arXiv:2205.15868},
+ year={2022}
+}
+```
diff --git a/README_zh.md b/README_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..13b509d77ad3ab5a2986c008b34b571b1030ae60
--- /dev/null
+++ b/README_zh.md
@@ -0,0 +1,149 @@
+# CogVideo && CogVideoX
+
+[Read this in English](./README.md)
+
+A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
+
+The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.
+
+A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.
+
+In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
+
+CogVideo的demo网站在[https://models.aminer.cn/cogvideo](https://models.aminer.cn/cogvideo/)。您可以在这里体验文本到视频生成。*原始输入为中文。*
+
+## 引用
+
+🌟 如果您发现我们的工作有所帮助,欢迎引用我们的文章,留下宝贵的stars
+
+```
+@article{yang2024cogvideox,
+ title={CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer},
+ author={Zhuoyi Yang and Jiayan Teng and Wendi Zheng and Ming Ding and Shiyu Huang and JiaZheng Xu and Yuanming Yang and Xiaohan Zhang and Xiaotao Gu and Guanyu Feng and Da Yin and Wenyi Hong and Weihan Wang and Yean Cheng and Yuxuan Zhang and Ting Liu and Bin Xu and Yuxiao Dong and Jie Tang},
+ year={2024},
+}
+@article{hong2022cogvideo,
+ title={CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers},
+ author={Hong, Wenyi and Ding, Ming and Zheng, Wendi and Liu, Xinghan and Tang, Jie},
+ journal={arXiv preprint arXiv:2205.15868},
+ year={2022}
+}
+```
\ No newline at end of file
diff --git a/gradio_demo.py b/gradio_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..65eeb4896ae4da6d0c7a7e42fce87cdf3a00709e
--- /dev/null
+++ b/gradio_demo.py
@@ -0,0 +1,254 @@
+import os
+import tempfile
+import threading
+import time
+
+import gradio as gr
+import numpy as np
+import torch
+from diffusers import CogVideoXPipeline
+from datetime import datetime, timedelta
+from openai import OpenAI
+import spaces
+import imageio
+import moviepy.editor as mp
+from typing import List, Union
+import PIL
+
+dtype = torch.bfloat16
+device = "cuda" if torch.cuda.is_available() else "cpu"
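+# Load the CogVideoX-2B pipeline once at startup and keep it resident on the selected device.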
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=dtype).to(device)
+
+sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
+
+For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
+There are a few rules to follow:
+
+You will only ever output a single video description per user request.
+
+When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
+Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
+
+Video descriptions must have the same num of words as examples below. Extra words will be ignored.
+"""
+
+
+def export_to_video_imageio(
+ video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
+) -> str:
+ """
+ Export video frames to a video file using the imageio library, to avoid the "green screen" issue seen with some players (for example with CogVideoX output).
+ """
+ if output_video_path is None:
+ output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
+
+ if isinstance(video_frames[0], PIL.Image.Image):
+ video_frames = [np.array(frame) for frame in video_frames]
+
+ with imageio.get_writer(output_video_path, fps=fps) as writer:
+ for frame in video_frames:
+ writer.append_data(frame)
+
+ return output_video_path
+
+
+def convert_prompt(prompt: str, retry_times: int = 3) -> str:
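+ # Prompt enhancement needs an OpenAI-compatible API key; without one, return the prompt unchanged.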
+ if not os.environ.get("OPENAI_API_KEY"):
+ return prompt
+ client = OpenAI()
+ text = prompt.strip()
+
+ for i in range(retry_times):
+ response = client.chat.completions.create(
+ messages=[
+ {"role": "system", "content": sys_prompt},
+ {"role": "user",
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"'},
+ {"role": "assistant",
+ "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance."},
+ {"role": "user",
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"'},
+ {"role": "assistant",
+ "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field."},
+ {"role": "user",
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"'},
+ {"role": "assistant",
+ "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background."},
+ {"role": "user",
+ "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"'},
+ ],
+ model="glm-4-0520",
+ temperature=0.01,
+ top_p=0.7,
+ stream=False,
+ max_tokens=250,
+ )
+ if response.choices:
+ return response.choices[0].message.content
+ return prompt
+
+
+@spaces.GPU(duration=240)
+def infer(
+ prompt: str,
+ num_inference_steps: int,
+ guidance_scale: float,
+ progress=gr.Progress(track_tqdm=True)
+):
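+ # Free unused cached GPU memory before starting a new generation.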
+ torch.cuda.empty_cache()
+
+ prompt_embeds, _ = pipe.encode_prompt(
+ prompt=prompt,
+ negative_prompt=None,
+ do_classifier_free_guidance=True,
+ num_videos_per_prompt=1,
+ max_sequence_length=226,
+ device=device,
+ dtype=dtype,
+ )
+
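+ # Text negative prompts are not supported here; classifier-free guidance is driven by
+ # an all-zeros negative embedding with the same shape as the prompt embedding.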
+ video = pipe(
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=torch.zeros_like(prompt_embeds),
+ ).frames[0]
+
+
+ return video
+
+
+def save_video(tensor):
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ video_path = f"./output/{timestamp}.mp4"
+ os.makedirs(os.path.dirname(video_path), exist_ok=True)
+ export_to_video_imageio(tensor[1:], video_path)
+ return video_path
+
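+# Produce a small GIF preview of the generated clip: 8 fps, downscaled to a height of 240 px.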
+def convert_to_gif(video_path):
+ clip = mp.VideoFileClip(video_path)
+ clip = clip.set_fps(8)
+ clip = clip.resize(height=240)
+ gif_path = video_path.replace('.mp4', '.gif')
+ clip.write_gif(gif_path, fps=8)
+ return gif_path
+
+
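+# Background housekeeping: every 10 minutes, remove generated files older than 10 minutes from ./output.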
+def delete_old_files():
+ while True:
+ now = datetime.now()
+ cutoff = now - timedelta(minutes=10)
+ output_dir = './output'
+ for filename in os.listdir(output_dir):
+ file_path = os.path.join(output_dir, filename)
+ if os.path.isfile(file_path):
+ file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
+ if file_mtime < cutoff:
+ os.remove(file_path)
+ time.sleep(600) # Sleep for 10 minutes
+
+
+threading.Thread(target=delete_old_files, daemon=True).start()
+
+with gr.Blocks() as demo:
+ gr.Markdown("""
+
+ ⚠️ This demo is for academic research and experiential use only.
+ Users should strictly adhere to local laws and ethics.
+
+ """)
+ with gr.Row():
+ with gr.Column():
+ prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
+ with gr.Row():
+ gr.Markdown(
+ "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.")
+ enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
+
+ with gr.Column():
+ gr.Markdown("**Optional Parameters** (default values are recommended). "
+ "Increase the number of inference steps for a more detailed video, at the cost of speed. "
+ "50 steps are recommended for most cases and take about 120 seconds of inference. ")
+ with gr.Row():
+ num_inference_steps = gr.Number(label="Inference Steps", value=50)
+ guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
+ generate_button = gr.Button("🎬 Generate Video")
+
+ with gr.Column():
+ video_output = gr.Video(label="CogVideoX Generated Video", width=720, height=480)
+ with gr.Row():
+ download_video_button = gr.File(label="📥 Download Video", visible=False)
+ download_gif_button = gr.File(label="📥 Download GIF", visible=False)
+
+ gr.Markdown("""
+ **Example prompts** (Prompt / Video URL / Inference Steps / Guidance Scale):
+
+ 1. A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
+ 2. The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from it’s tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.
+ 3. A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.
+ 4. In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
+ """)
+
+
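+ # End-to-end flow: run inference, save the MP4, derive a GIF preview, and reveal the download widgets.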
+ def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
+ tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
+ video_path = save_video(tensor)
+ video_update = gr.update(visible=True, value=video_path)
+ gif_path = convert_to_gif(video_path)
+ gif_update = gr.update(visible=True, value=gif_path)
+
+ return video_path, video_update, gif_update
+
+
+ def enhance_prompt_func(prompt):
+ return convert_prompt(prompt, retry_times=1)
+
+
+ generate_button.click(
+ generate,
+ inputs=[prompt, num_inference_steps, guidance_scale],
+ outputs=[video_output, download_video_button, download_gif_button]
+ )
+
+ enhance_button.click(
+ enhance_prompt_func,
+ inputs=[prompt],
+ outputs=[prompt]
+ )
+
+if __name__ == "__main__":
+ demo.launch(server_name="127.0.0.1", server_port=7870, share=True)
diff --git a/inference/cli_demo.py b/inference/cli_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..c480d43918ba412232a4a953c73031953365bc0e
--- /dev/null
+++ b/inference/cli_demo.py
@@ -0,0 +1,127 @@
+"""
+This script demonstrates how to generate a video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline.
+
+Note:
+ This script requires the `diffusers>=0.30.0` library to be installed.
+ If a video exported using OpenCV appears “completely green” and cannot be viewed, please switch to a different player to watch it. This is a normal phenomenon.
+
+Run the script:
+ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-2b
+
+"""
+
+import argparse
+import tempfile
+from typing import Union, List
+
+import PIL
+import imageio
+import numpy as np
+import torch
+from diffusers import CogVideoXPipeline
+
+
+def export_to_video_imageio(
+ video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
+) -> str:
+ """
+ Export video frames to a video file using the imageio library, to avoid the "green screen" issue seen with some players (for example with CogVideoX output).
+ """
+ if output_video_path is None:
+ output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
+ if isinstance(video_frames[0], PIL.Image.Image):
+ video_frames = [np.array(frame) for frame in video_frames]
+ with imageio.get_writer(output_video_path, fps=fps) as writer:
+ for frame in video_frames:
+ writer.append_data(frame)
+ return output_video_path
+
+
+def generate_video(
+ prompt: str,
+ model_path: str,
+ output_path: str = "./output.mp4",
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: int = 1,
+ device: str = "cuda",
+ dtype: torch.dtype = torch.float16,
+):
+ """
+ Generates a video based on the given prompt and saves it to the specified path.
+
+ Parameters:
+ - prompt (str): The description of the video to be generated.
+ - model_path (str): The path of the pre-trained model to be used.
+ - output_path (str): The path where the generated video will be saved.
+ - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
+ - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
+ - num_videos_per_prompt (int): Number of videos to generate per prompt.
+ - device (str): The device to use for computation (e.g., "cuda" or "cpu").
+ - dtype (torch.dtype): The data type for computation (default is torch.float16).
+ """
+
+ # Load the pre-trained CogVideoX pipeline with the specified precision (float16) and move it to the specified device
+ pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
+
+ # Encode the prompt to get the prompt embeddings
+ prompt_embeds, _ = pipe.encode_prompt(
+ prompt=prompt, # The textual description for video generation
+ negative_prompt=None, # The negative prompt to guide the video generation
+ do_classifier_free_guidance=True, # Whether to use classifier-free guidance
+ num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
+ max_sequence_length=226, # Maximum length of the sequence, must be 226
+ device=device, # Device to use for computation
+ dtype=dtype, # Data type for computation
+ )
+
+ # Generate the video frames using the pipeline
+ video = pipe(
+ num_inference_steps=num_inference_steps, # Number of inference steps
+ guidance_scale=guidance_scale, # Guidance scale for classifier-free guidance
+ prompt_embeds=prompt_embeds, # Encoded prompt embeddings
+ negative_prompt_embeds=torch.zeros_like(prompt_embeds),  # Negative prompts are not supported, so zeros are passed
+ ).frames[0]
+
+ # Export the generated frames to a video file. fps must be 8
+ export_to_video_imageio(video, output_path, fps=8)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
+ parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
+ parser.add_argument(
+ "--model_path", type=str, default="THUDM/CogVideoX-2b", help="The path of the pre-trained model to be used"
+ )
+ parser.add_argument(
+ "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
+ )
+ parser.add_argument(
+ "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
+ )
+ parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
+ parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
+ parser.add_argument(
+ "--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
+ )
+
+ parser.add_argument(
+ "--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
+ )
+
+ args = parser.parse_args()
+
+ # Convert the dtype argument to torch.dtype; bfloat16 is not recommended here.
+ dtype = torch.float16 if args.dtype == "float16" else torch.float32
+
+ # main function to generate video.
+ generate_video(
+ prompt=args.prompt,
+ model_path=args.model_path,
+ output_path=args.output_path,
+ num_inference_steps=args.num_inference_steps,
+ guidance_scale=args.guidance_scale,
+ num_videos_per_prompt=args.num_videos_per_prompt,
+ device=args.device,
+ dtype=dtype,
+ )
diff --git a/inference/cli_vae_demo.py b/inference/cli_vae_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..b133f2069e1f1839d5d1f1492797813ebb8bad88
--- /dev/null
+++ b/inference/cli_vae_demo.py
@@ -0,0 +1,103 @@
+"""
+This script demonstrates how to encode video frames using a pre-trained CogVideoX model with 🤗 Huggingface Diffusers.
+
+Note:
+ This script requires the `diffusers>=0.30.0` library to be installed.
+ If the video appears “completely green” and cannot be viewed, please switch to a different player to watch it. This is a normal phenomenon.
+ Encoding a 6-second video at 720p resolution costs about 71GB of GPU memory.
+
+Run the script:
+ $ python cli_vae_demo.py --model_path THUDM/CogVideoX-2b --video_path path/to/video.mp4 --output_path path/to/output
+
+"""
+
+import argparse
+import torch
+import imageio
+import numpy as np
+from diffusers import AutoencoderKLCogVideoX
+from torchvision import transforms
+
+
+def vae_demo(model_path, video_path, dtype, device):
+ """
+ Loads a pre-trained AutoencoderKLCogVideoX model and encodes the video frames.
+
+ Parameters:
+ - model_path (str): The path to the pre-trained model.
+ - video_path (str): The path to the video file.
+ - dtype (torch.dtype): The data type for computation.
+ - device (str): The device to use for computation (e.g., "cuda" or "cpu").
+
+ Returns:
+ - torch.Tensor: The encoded video frames.
+ """
+ # Load the pre-trained model
+ model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
+
+ # Load video frames
+ video_reader = imageio.get_reader(video_path, "ffmpeg")
+ frames = []
+ for frame in video_reader:
+ frames.append(frame)
+ video_reader.close()
+
+ # Transform frames to Tensor
+ transform = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ ]
+ )
+ frames_tensor = torch.stack([transform(frame) for frame in frames]).to(device)
+
+ # Add batch dimension and reshape to [1, 3, 49, 480, 720]
+ frames_tensor = frames_tensor.permute(1, 0, 2, 3).unsqueeze(0).to(dtype).to(device)
+
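+ # The entire clip goes through the VAE in a single forward pass (see the GPU memory note in the module docstring).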
+ # Run the model with Encoder and Decoder
+ with torch.no_grad():
+ output = model(frames_tensor)
+
+ return output
+
+
+def save_video(tensor, output_path):
+ """
+ Saves the encoded video frames to a video file.
+
+ Parameters:
+ - tensor (torch.Tensor): The encoded video frames.
+ - output_path (str): The path to save the output video.
+ """
+ # Remove batch dimension and permute back to [49, 480, 720, 3]
+ frames = tensor[0].squeeze(0).permute(1, 2, 3, 0).cpu().numpy()
+
+ # Clip values to [0, 1] and convert to uint8
+ frames = np.clip(frames, 0, 1)
+ frames = (frames * 255).astype(np.uint8)
+
+ # Save frames to video
+ writer = imageio.get_writer(output_path + "/output.mp4", fps=30)
+ for frame in frames:
+ writer.append_data(frame)
+ writer.close()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Convert a CogVideoX model to Diffusers")
+ parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
+ parser.add_argument("--video_path", type=str, required=True, help="The path to the video file")
+ parser.add_argument("--output_path", type=str, default="./", help="The path to save the output video")
+ parser.add_argument(
+ "--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
+ )
+ parser.add_argument(
+ "--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
+ )
+ args = parser.parse_args()
+
+ # Set device and dtype
+ device = torch.device(args.device)
+ dtype = torch.float16 if args.dtype == "float16" else torch.float32
+
+ output = vae_demo(args.model_path, args.video_path, dtype, device)
+ save_video(output, args.output_path)
diff --git a/inference/convert_demo.py b/inference/convert_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..97815e7ca169949b1c205266709d64d11a849134
--- /dev/null
+++ b/inference/convert_demo.py
@@ -0,0 +1,92 @@
+"""
+
+The CogVideoX model is pre-trained and fine-tuned with long, detailed prompts, so it requires highly granular and descriptive prompts as input. This script transforms user input into prompts suitable for CogVideoX, enabling better video generation.
+
+This step is not mandatory; the model will still function correctly and without errors even if the prompts are not refined using this script. However, we strongly recommend using it to ensure the generation of high-quality videos.
+
+Note:
+Please set the OPENAI_API_KEY (and OPENAI_BASE_URL, if needed) environment variables before running this script.
+
+Run the script:
+ $ python convert_demo.py --prompt "A girl riding a bike."  # Using OpenAI's API
+"""
+
+import argparse
+
+from openai import OpenAI
+
+
+sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
+
+For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
+There are a few rules to follow:
+
+You will only ever output a single video description per user request.
+
+When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
+Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
+
+Video descriptions must have the same num of words as examples below. Extra words will be ignored.
+"""
+
+
+def convert_prompt(prompt: str, retry_times: int = 3):
+ """
+ Convert a prompt to a format that can be used by the model for inference
+ """
+
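+ # OpenAI() reads OPENAI_API_KEY and OPENAI_BASE_URL from the environment, so this also
+ # works with OpenAI-compatible endpoints that serve GLM-4.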
+ client = OpenAI()
+ text = prompt.strip()
+
+ for i in range(retry_times):
+ response = client.chat.completions.create(
+ messages=[
+ {"role": "system", "content": f"{sys_prompt}"},
+ {
+ "role": "user",
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " a girl is on the beach"',
+ },
+ {
+ "role": "assistant",
+ "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
+ },
+ {
+ "role": "user",
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A man jogging on a football field"',
+ },
+ {
+ "role": "assistant",
+ "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
+ },
+ {
+ "role": "user",
+ "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
+ },
+ {
+ "role": "assistant",
+ "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
+ },
+ {
+ "role": "user",
+ "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: " {text} "',
+ },
+ ],
+ model="glm-4-0520", # glm-4-0520 and gpt-4o have be tested
+ temperature=0.01,
+ top_p=0.7,
+ stream=False,
+ max_tokens=250,
+ )
+ if response.choices:
+ return response.choices[0].message.content
+ return prompt
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--prompt", type=str, required=True, help="Prompt to convert")
+ parser.add_argument("--retry_times", type=int, default=3, help="Number of times to retry the conversion")
+ args = parser.parse_args()
+
+ converted_prompt = convert_prompt(args.prompt, args.retry_times)
+ print(converted_prompt)
diff --git a/inference/web_demo.py b/inference/web_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..8695975e4fc49cd41eb5ae6770a3fbcd4bda6da8
--- /dev/null
+++ b/inference/web_demo.py
@@ -0,0 +1,214 @@
+"""
+This script is used to create a Streamlit web application for generating videos using the CogVideoX model.
+
+Run the script using Streamlit:
+    $ export OPENAI_API_KEY=<your OpenAI or ZhipuAI API key>
+    $ export OPENAI_BASE_URL=https://open.bigmodel.cn/api/paas/v4/  # required when using ZhipuAI; omit this when using OpenAI
+ $ streamlit run web_demo.py
+"""
+
+import base64
+import json
+import os
+import time
+from datetime import datetime
+from typing import List
+
+import imageio
+import numpy as np
+import streamlit as st
+import torch
+from convert_demo import convert_prompt
+from diffusers import CogVideoXPipeline
+
+
+model_path: str = "THUDM/CogVideoX-2b"
+
+
+# Load the model at the start
+@st.cache_resource
+def load_model(model_path: str, dtype: torch.dtype, device: str) -> CogVideoXPipeline:
+ """
+ Load the CogVideoX model.
+
+ Args:
+ - model_path (str): Path to the model.
+ - dtype (torch.dtype): Data type for model.
+ - device (str): Device to load the model on.
+
+ Returns:
+ - CogVideoXPipeline: Loaded model pipeline.
+ """
+ return CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
+
+
+# Define a function to generate video based on the provided prompt and model path
+def generate_video(
+ pipe: CogVideoXPipeline,
+ prompt: str,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: int = 1,
+ device: str = "cuda",
+ dtype: torch.dtype = torch.float16,
+) -> List[np.ndarray]:
+ """
+ Generate a video based on the provided prompt and model path.
+
+ Args:
+ - pipe (CogVideoXPipeline): The pipeline for generating videos.
+ - prompt (str): Text prompt for video generation.
+ - num_inference_steps (int): Number of inference steps.
+ - guidance_scale (float): Guidance scale for generation.
+ - num_videos_per_prompt (int): Number of videos to generate per prompt.
+ - device (str): Device to run the generation on.
+ - dtype (torch.dtype): Data type for the model.
+
+ Returns:
+ - List[np.ndarray]: Generated video frames.
+ """
+ prompt_embeds, _ = pipe.encode_prompt(
+ prompt=prompt,
+ negative_prompt=None,
+ do_classifier_free_guidance=True,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=226,
+ device=device,
+ dtype=dtype,
+ )
+
+ # Generate video
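+    # Zeroed text embeddings act as the unconditional (negative) prompt for classifier-free guidance.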
+ video = pipe(
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=torch.zeros_like(prompt_embeds),
+ ).frames[0]
+ return video
+
+
+def save_video(video: List[np.ndarray], path: str, fps: int = 8) -> None:
+ """
+ Save the generated video to a file.
+
+ Args:
+ - video (List[np.ndarray]): Video frames.
+ - path (str): Path to save the video.
+ - fps (int): Frames per second for the video.
+ """
+ # Remove the first frame
+ video = video[1:]
+
+ writer = imageio.get_writer(path, fps=fps, codec="libx264")
+ for frame in video:
+ np_frame = np.array(frame)
+ writer.append_data(np_frame)
+
+ writer.close()
+
+
+def save_metadata(
+ prompt: str,
+ converted_prompt: str,
+ num_inference_steps: int,
+ guidance_scale: float,
+ num_videos_per_prompt: int,
+ path: str,
+) -> None:
+ """
+ Save metadata to a JSON file.
+
+ Args:
+ - prompt (str): Original prompt.
+ - converted_prompt (str): Converted prompt.
+ - num_inference_steps (int): Number of inference steps.
+ - guidance_scale (float): Guidance scale.
+ - num_videos_per_prompt (int): Number of videos per prompt.
+ - path (str): Path to save the metadata.
+ """
+ metadata = {
+ "prompt": prompt,
+ "converted_prompt": converted_prompt,
+ "num_inference_steps": num_inference_steps,
+ "guidance_scale": guidance_scale,
+ "num_videos_per_prompt": num_videos_per_prompt,
+ }
+ with open(path, "w") as f:
+ json.dump(metadata, f, indent=4)
+
+
+def main() -> None:
+ """
+ Main function to run the Streamlit web application.
+ """
+ st.set_page_config(page_title="CogVideoX-Demo", page_icon="🎥", layout="wide")
+ st.write("# CogVideoX 🎥")
+ dtype: torch.dtype = torch.float16
+ device: str = "cuda"
+
+ global pipe
+ pipe = load_model(model_path, dtype, device)
+
+ with st.sidebar:
+ st.info("It will take some time to generate a video (~90 seconds per videos in 50 steps).", icon="ℹ️")
+ num_inference_steps: int = st.number_input("Inference Steps", min_value=1, max_value=100, value=50)
+ guidance_scale: float = st.number_input("Guidance Scale", min_value=0.0, max_value=20.0, value=6.0)
+ num_videos_per_prompt: int = st.number_input("Videos per Prompt", min_value=1, max_value=10, value=1)
+
+ share_links_container = st.empty()
+
+ prompt: str = st.chat_input("Prompt")
+
+ if prompt:
+        # Optional step: refine the prompt with an LLM for better results
+ with st.spinner("Refining prompts..."):
+ converted_prompt = convert_prompt(prompt=prompt, retry_times=1)
+ if converted_prompt is None:
+ st.error("Failed to Refining the prompt, Using origin one.")
+
+ st.info(f"**Origin prompt:** \n{prompt} \n \n**Convert prompt:** \n{converted_prompt}")
+ torch.cuda.empty_cache()
+
+ with st.spinner("Generating Video..."):
+ start_time = time.time()
+ video_paths = []
+
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ output_dir = f"./output/{timestamp}"
+ os.makedirs(output_dir, exist_ok=True)
+
+ metadata_path = os.path.join(output_dir, "config.json")
+ save_metadata(
+ prompt, converted_prompt, num_inference_steps, guidance_scale, num_videos_per_prompt, metadata_path
+ )
+
+ for i in range(num_videos_per_prompt):
+ video_path = os.path.join(output_dir, f"output_{i + 1}.mp4")
+
+ video = generate_video(
+ pipe, converted_prompt or prompt, num_inference_steps, guidance_scale, 1, device, dtype
+ )
+ save_video(video, video_path, fps=8)
+ video_paths.append(video_path)
+ with open(video_path, "rb") as video_file:
+ video_bytes: bytes = video_file.read()
+ st.video(video_bytes, autoplay=True, loop=True, format="video/mp4")
+ torch.cuda.empty_cache()
+
+ used_time: float = time.time() - start_time
+ st.success(f"Videos generated in {used_time:.2f} seconds.")
+
+ # Create download links in the sidebar
+ with share_links_container:
+ st.sidebar.write("### Download Links:")
+ for video_path in video_paths:
+ video_name = os.path.basename(video_path)
+ with open(video_path, "rb") as f:
+ video_bytes: bytes = f.read()
+ b64_video = base64.b64encode(video_bytes).decode()
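+                        # Embed the video as a base64 data URI so it can be downloaded directly from the sidebar link.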
+                        href = f'<a href="data:video/mp4;base64,{b64_video}" download="{video_name}">Download {video_name}</a>'
+ st.sidebar.markdown(href, unsafe_allow_html=True)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..09bc849926ce8a8868cf7853a4f58e9229176bfa
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,27 @@
+[tool.ruff]
+line-length = 119
+
+[tool.ruff.lint]
+# Never enforce `E501` (line length violations).
+ignore = ["C901", "E501", "E741", "F402", "F823"]
+select = ["C", "E", "F", "I", "W"]
+
+# Ignore import violations in all `__init__.py` files.
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6f89c1ca7c11d70c301e61f4ba61f3cfb0a06433
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+git+https://github.com/huggingface/diffusers.git@d1c575ad7ee0390c2735f50cc59a79aae666567a#egg=diffusers
+torch==2.4.0
+torchvision==0.19.0
+streamlit==1.37.0
+opencv-python
+imageio-ffmpeg==0.5.1
+openai==1.38.0
+transformers==4.43.4
+accelerate==0.33.0
+sentencepiece==0.2.0
+pillow==9.5.0
diff --git a/resources/CogVideoX.pdf b/resources/CogVideoX.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..3b9317fcbfad10fb04791b462470cb6f967a53dd
--- /dev/null
+++ b/resources/CogVideoX.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25ba30aafcd9604178c6d7adbd17f2bf1b251f3d29d1d29498e576075cb67c4e
+size 31028426
diff --git a/resources/WECHAT.md b/resources/WECHAT.md
new file mode 100644
index 0000000000000000000000000000000000000000..7f9620d24141549da2a9a8c0edcb7fa3ada96b73
--- /dev/null
+++ b/resources/WECHAT.md
@@ -0,0 +1,7 @@
+
+
+扫码关注公众号,加入「 CogVideoX 交流群」
+
+Scan the QR code to follow the official account and join the "CogVideoX Discussion Group"
+
+
diff --git a/resources/contribute.md b/resources/contribute.md
new file mode 100644
index 0000000000000000000000000000000000000000..780ec49f7b5102a88a568a5f23b4d63e58efdf93
--- /dev/null
+++ b/resources/contribute.md
@@ -0,0 +1,50 @@
+# Contribution Guide
+
+There may still be many incomplete aspects in this project.
+
+We look forward to your contributions in the following areas. If you complete work in one of these areas, are willing to
+submit a PR, and share it with the community, we will acknowledge your contribution on the project homepage after review.
+
+## Model Algorithms
+
+- Support for model quantization inference (Int4, Int8, etc. quantization engineering)
+- Support for multi-card inference / model inference concurrency engineering
+- Support for non-CUDA architecture inference devices
+
+## Model Engineering / Secondary Development
+
+- Model fine-tuning examples / best prompt practices
+- Video super-resolution/frame interpolation for enhancing video generation quality.
+- Any peripheral tools for the model
+- Any minimal complete open-source projects using the CogVideoX open-source model
+
+## Code Standards
+
+Good code style is an art. We have prepared a `pyproject.toml` configuration file for the project to standardize code
+style. You can organize the code according to the following specifications:
+
+1. Install the `ruff` tool
+
+```shell
+pip install ruff
+```
+
+Then, run `ruff` to check the code style:
+
+```shell
+ruff check tools sat inference
+```
+
+If there are style issues, you can fix formatting automatically with the `ruff format` command:
+
+```shell
+ruff format tools sat inference
+```
+
+Once your code meets the standard, there should be no errors.
+
+## Naming Conventions
+1. Please use English names; do not use Pinyin or names in other languages. All comments should be in English.
+2. Please strictly follow the PEP 8 specification and use underscores to separate words. Do not use single-letter names like a, b, c.
+
diff --git a/resources/contribute_zh.md b/resources/contribute_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..b0dcf0977b94752b6f87fa0ebc120fbb082f4b2a
--- /dev/null
+++ b/resources/contribute_zh.md
@@ -0,0 +1,45 @@
+# 贡献指南
+
+本项目可能还存在很多不完善的内容。 我们期待您在以下方面与我们共建仓库, 如果您完成了上述工作并愿意PR和分享到社区,在通过审核后,我们将在项目首页感谢您的贡献。
+
+## 模型算法
+
+- 模型量化推理支持 (Int4,Int8等量化工程)
+- 模型多卡推理支持 / 模型推理并发工程
+- 非 CUDA 架构 推理设备支持
+
+## 模型工程 / 模型二次开发
+
+- 模型微调示例 / 最佳提示词实践
+- 视频超分/插帧,用于美化视频生成效果。
+- 任何模型周边工具
+- 任何使用CogVideoX开源模型制作的最小完整开源项目
+
+## 代码规范
+
+良好的代码风格是一种艺术,我们已经为项目准备好了`pyproject.toml`配置文件,用于规范代码风格。您可以按照以下规范梳理代码:
+
+1. 安装`ruff`工具
+
+```shell
+pip install ruff
+```
+
+接着,运行`ruff`工具
+
+```shell
+ruff check tools sat inference
+```
+
+检查代码风格,如果有问题,您可以通过`ruff format`命令自动修复。
+
+```shell
+ruff format tools sat inference
+```
+
+如果您的代码符合规范,应该不会出现任何的错误。
+
+## 命名规范
+
+- 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。
+- 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。
\ No newline at end of file
diff --git a/resources/logo.svg b/resources/logo.svg
new file mode 100644
index 0000000000000000000000000000000000000000..68333bea7e0ff28cc28732142be7d767bca49666
--- /dev/null
+++ b/resources/logo.svg
@@ -0,0 +1,298 @@
+
+
+
diff --git a/resources/videos/1.mp4 b/resources/videos/1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..b34aeb5a7bd70a0af2e672aa7bd9c4060ca3e364
Binary files /dev/null and b/resources/videos/1.mp4 differ
diff --git a/resources/videos/2.mp4 b/resources/videos/2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..9309a260bf8fd8cadc3e60b43c40e829ed81eb83
--- /dev/null
+++ b/resources/videos/2.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e738926f262a28b3b9af6573987905457ea82cdcadb0ec04ad9ab134324f5cc
+size 1683616
diff --git a/resources/videos/3.mp4 b/resources/videos/3.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0f203099fda1f85c133a5ed53ba2a1c3c1eaa002
Binary files /dev/null and b/resources/videos/3.mp4 differ
diff --git a/resources/videos/4.mp4 b/resources/videos/4.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..ed6fd25c119cccab9af759f27edc278e27c0b3fd
Binary files /dev/null and b/resources/videos/4.mp4 differ
diff --git a/resources/web_demo.png b/resources/web_demo.png
new file mode 100644
index 0000000000000000000000000000000000000000..90780eb292f18a86b9d9ef4c41c9d4a8ac703e54
--- /dev/null
+++ b/resources/web_demo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ac281bdbebe756ea9840cd5c13f04aafa3f05c2a16de1f75a45a6f31079e340
+size 4808873
diff --git a/resources/wechat.jpg b/resources/wechat.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4974535bdda7c0de5eeafcb930c22ef43538d135
Binary files /dev/null and b/resources/wechat.jpg differ
diff --git a/sat/README.md b/sat/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8d9ac8f9418457135aafd2faa4554c9c53a4415f
--- /dev/null
+++ b/sat/README.md
@@ -0,0 +1,182 @@
+# SAT CogVideoX-2B
+
+This folder contains the inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights and the
+fine-tuning code for SAT weights.
+
+This code is the framework used by the team to train the model. It has few comments and requires careful study.
+
+## Inference Model
+
+1. Ensure that you have correctly installed the dependencies required by this folder.
+
+```shell
+pip install -r requirements.txt
+```
+
+2. Download the model weights
+
+First, download the model files from the SAT mirror.
+
+```shell
+mkdir CogVideoX-2b-sat
+cd CogVideoX-2b-sat
+wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
+mv 'index.html?dl=1' vae.zip
+unzip vae.zip
+wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1
+mv 'index.html?dl=1' transformer.zip
+unzip transformer.zip
+```
+
+After unzipping, the model directory structure should look like this:
+
+```
+.
+├── transformer
+│ ├── 1000
+│ │ └── mp_rank_00_model_states.pt
+│ └── latest
+└── vae
+ └── 3d-vae.pt
+```
+
+Next, clone the T5 model. It is not trained or fine-tuned, but it is required as the text encoder.
+
+```shell
+git lfs install
+git clone https://huggingface.co/google/t5-v1_1-xxl.git
+```
+
+**We do not need the tf_model.h5 file**; it can be deleted.
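+
+For example, assuming the clone location used above, it can be removed with:
+
+```shell
+rm t5-v1_1-xxl/tf_model.h5
+```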
+
+3. Modify the file `configs/cogvideox_2b_infer.yaml`.
+
+```yaml
+load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer model path
+
+conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ - is_trainable: false
+ input_key: txt
+ ucg_rate: 0.1
+ target: sgm.modules.encoders.modules.FrozenT5Embedder
+ params:
+ model_dir: "google/t5-v1_1-xxl" ## T5 model path
+ max_length: 226
+
+first_stage_config:
+ target: sgm.models.autoencoder.VideoAutoencoderInferenceWrapper
+ params:
+ cp_size: 1
+ ckpt_path: "{your_CogVideoX-2b-sat_path}/vae/3d-vae.pt" ## VAE model path
+```
+
++ If you use a txt file to provide multiple prompts, please refer to `configs/test.txt` for the format: one prompt per
+  line. If you are not sure how to write prompts, you can first use [this code](../inference/convert_demo.py) to have an
+  LLM refine them (see the example after this list).
++ If using the command line as input, modify
+
+```yaml
+input_type: cli
+```
+
+so that prompts can be entered from the command line.
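+
+As a quick sketch (the prompt text is only an example, and OPENAI_API_KEY must be set first), a prompt can be refined before inference like this:
+
+```shell
+python ../inference/convert_demo.py --prompt "A girl riding a bike."
+```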
+
+If you want to change the output video directory, you can modify:
+
+```yaml
+output_dir: outputs/
+```
+
+By default, the output is saved in the `outputs/` folder.
+
+4. Run the inference script to start inference:
+
+```shell
+bash inference.sh
+```
+
+## Fine-Tuning the Model
+
+### Preparing the Dataset
+
+The dataset format should be as follows:
+
+```
+.
+├── labels
+│ ├── 1.txt
+│ ├── 2.txt
+│ ├── ...
+└── videos
+ ├── 1.mp4
+ ├── 2.mp4
+ ├── ...
+```
+
+Each txt file should have the same name as its corresponding video file and contain the label for that video. Videos
+and labels should correspond one to one; typically, a video should not have multiple labels.
+
+For style fine-tuning, please prepare at least 50 videos and labels with a similar style to make fitting easier.
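+
+For illustration, a label file is simply a plain-text caption. A hypothetical `labels/1.txt` might contain a single line such as:
+
+```
+A wooden toy ship glides over a plush blue carpet that mimics ocean waves.
+```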
+
+### Modifying the Configuration File
+
+We support both `LoRA` and full-parameter fine-tuning. Note that both methods fine-tune only the `transformer` part; the `VAE` is not modified, and `T5` is used only as an encoder.
+
+Modify the `configs/cogvideox_2b_sft.yaml` file (for full-parameter fine-tuning) as follows.
+
+```yaml
+ # checkpoint_activations: True ## using gradient checkpointing (both checkpoint_activations in the configuration file need to be set to True)
+ model_parallel_size: 1 # Model parallel size
+ experiment_name: lora-disney # Experiment name (do not change)
+ mode: finetune # Mode (do not change)
+ load: "{your_CogVideoX-2b-sat_path}/transformer" # Transformer model path
+  no_load_rng: True # Do not load the RNG state from the checkpoint
+ train_iters: 1000 # Number of training iterations
+ eval_iters: 1 # Number of evaluation iterations
+ eval_interval: 100 # Evaluation interval
+ eval_batch_size: 1 # Batch size for evaluation
+ save: ckpts # Model save path
+ save_interval: 100 # Model save interval
+ log_interval: 20 # Log output interval
+ train_data: [ "your train data path" ]
+ valid_data: [ "your val data path" ] # Training and validation sets can be the same
+ split: 1,0,0 # Ratio of training, validation, and test sets
+ num_workers: 8 # Number of worker threads for data loading
+```
+
+If you wish to use LoRA fine-tuning, you also need to modify:
+
+```yaml
+model:
+ scale_factor: 1.15258426
+ disable_first_stage_autocast: true
+ not_trainable_prefixes: [ 'all' ] ## Uncomment
+ log_keys:
+    - txt
+
+ lora_config: ## Uncomment
+ target: sat.model.finetune.lora2.LoraMixin
+ params:
+ r: 256
+```
+
+### Fine-Tuning and Validation
+
+1. Run the fine-tuning script to start fine-tuning.
+
+```shell
+bash finetune.sh
+```
+
+### Converting to Huggingface Diffusers Supported Weights
+
+The SAT weight format is different from Huggingface's weight format and needs to be converted. Please run:
+
+```shell
+python ../tools/convert_weight_sat2hf.py
+```
+
+**Note**: This content has not yet been tested with LoRA fine-tuned models.
\ No newline at end of file
diff --git a/sat/README_zh.md b/sat/README_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..ba301c8feb0154e52082c6b5f67eb155a73748b9
--- /dev/null
+++ b/sat/README_zh.md
@@ -0,0 +1,180 @@
+# SAT CogVideoX-2B
+
+本文件夹包含了使用 [SAT](https://github.com/THUDM/SwissArmyTransformer) 权重的推理代码,以及 SAT 权重的微调代码。
+
+该代码是团队训练模型时使用的框架。注释较少,需要认真研究。
+
+## 推理模型
+
+1. 确保你已经正确安装本文件夹中的要求的依赖
+
+```shell
+pip install -r requirements.txt
+```
+
+2. 下载模型权重
+
+首先,前往 SAT 镜像下载模型文件。
+
+```shell
+mkdir CogVideoX-2b-sat
+cd CogVideoX-2b-sat
+wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
+mv 'index.html?dl=1' vae.zip
+unzip vae.zip
+wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1
+mv 'index.html?dl=1' transformer.zip
+unzip transformer.zip
+```
+
+解压后,模型目录结构应该如下:
+
+```
+.
+├── transformer
+│ ├── 1000
+│ │ └── mp_rank_00_model_states.pt
+│ └── latest
+└── vae
+ └── 3d-vae.pt
+```
+
+接着,克隆 T5 模型。该模型不参与训练和微调,但作为文本编码器是必需的。
+
+```shell
+git lfs install
+git clone https://huggingface.co/google/t5-v1_1-xxl.git
+```
+
+**我们不需要使用tf_model.h5**文件。该文件可以删除。
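+
+例如,假设按上文位置克隆,可以这样删除:
+
+```shell
+rm t5-v1_1-xxl/tf_model.h5
+```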
+
+3. 修改`configs/cogvideox_2b_infer.yaml`中的文件。
+
+```yaml
+load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer 模型路径
+
+conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ - is_trainable: false
+ input_key: txt
+ ucg_rate: 0.1
+ target: sgm.modules.encoders.modules.FrozenT5Embedder
+ params:
+ model_dir: "google/t5-v1_1-xxl" ## T5 模型路径
+ max_length: 226
+
+first_stage_config:
+ target: sgm.models.autoencoder.VideoAutoencoderInferenceWrapper
+ params:
+ cp_size: 1
+ ckpt_path: "{your_CogVideoX-2b-sat_path}/vae/3d-vae.pt" ## VAE 模型路径
+
+```
+
++ 如果使用 txt 保存多个提示词,请参考`configs/test.txt`
+  进行修改。每行一个提示词。如果您不知道如何书写提示词,可以先使用[此代码](../inference/convert_demo.py)调用 LLM 进行润色(示例见下文)。
++ 如果使用命令行作为输入,请修改
+
+```yaml
+input_type: cli
+```
+
+这样就可以从命令行输入提示词。
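+
+作为示例(提示词仅为演示,运行前请先设置 OPENAI_API_KEY),可以这样预先润色提示词:
+
+```shell
+python ../inference/convert_demo.py --prompt "A girl riding a bike."
+```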
+
+如果你希望修改输出视频的地址,你可以修改:
+
+```yaml
+output_dir: outputs/
+```
+
+默认保存在`outputs/`文件夹下。
+
+4. 运行推理代码,即可推理
+
+```shell
+bash inference.sh
+```
+
+## 微调模型
+
+### 准备数据集
+
+数据集格式应该如下:
+
+```
+.
+├── labels
+│ ├── 1.txt
+│ ├── 2.txt
+│ ├── ...
+└── videos
+ ├── 1.mp4
+ ├── 2.mp4
+ ├── ...
+```
+
+每个 txt 与视频同名,为视频的标签。视频与标签应该一一对应。通常情况下,不使用一个视频对应多个标签。
+
+如果进行风格微调,请准备至少 50 条风格相似的视频和标签,以便于拟合。
+
+### 修改配置文件
+
+我们支持 `LoRA` 和全参数微调两种方式。请注意,两种微调方式都只对 `transformer` 部分进行微调,不改动 `VAE` 部分,`T5` 仅作为 Encoder 使用。
+
+请按照以下方式修改 `configs/cogvideox_2b_sft.yaml`(全量微调)文件。
+
+```yaml
+ # checkpoint_activations: True ## using gradient checkpointing (配置文件中的两个checkpoint_activations都需要设置为True)
+ model_parallel_size: 1 # 模型并行大小
+ experiment_name: lora-disney # 实验名称(不要改动)
+ mode: finetune # 模式(不要改动)
+ load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer 模型路径
+ no_load_rng: True # 是否加载随机数种子
+ train_iters: 1000 # 训练迭代次数
+ eval_iters: 1 # 验证迭代次数
+ eval_interval: 100 # 验证间隔
+ eval_batch_size: 1 # 验证集 batch size
+ save: ckpts # 模型保存路径
+ save_interval: 100 # 模型保存间隔
+ log_interval: 20 # 日志输出间隔
+ train_data: [ "your train data path" ]
+ valid_data: [ "your val data path" ] # 训练集和验证集可以相同
+ split: 1,0,0 # 训练集,验证集,测试集比例
+ num_workers: 8 # 数据加载器的工作线程数
+```
+
+如果你希望使用 Lora 微调,你还需要修改:
+
+```yaml
+model:
+ scale_factor: 1.15258426
+ disable_first_stage_autocast: true
+ not_trainable_prefixes: [ 'all' ] ## 解除注释
+ log_keys:
+    - txt
+
+ lora_config: ## 解除注释
+ target: sat.model.finetune.lora2.LoraMixin
+ params:
+ r: 256
+```
+
+### 微调和验证
+
+1. 运行微调代码,即可开始微调。
+
+```shell
+bash finetune.sh
+```
+
+### 转换到 Huggingface Diffusers 库支持的权重
+
+SAT 权重格式与 Huggingface 的权重格式不同,需要转换。请运行
+
+```shell
+python ../tools/convert_weight_sat2hf.py
+```
+
+**注意**:本内容暂未测试 LoRA 微调模型。
\ No newline at end of file
diff --git a/sat/arguments.py b/sat/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..44767d3c8a71194c4a0b43acd8808a89a09ce0bf
--- /dev/null
+++ b/sat/arguments.py
@@ -0,0 +1,281 @@
+import argparse
+import os
+import torch
+import json
+import warnings
+import omegaconf
+from omegaconf import OmegaConf
+from sat.helpers import print_rank0
+from sat import mpu
+from sat.arguments import set_random_seed
+from sat.arguments import add_training_args, add_evaluation_args, add_data_args
+import torch.distributed
+
+
+def add_model_config_args(parser):
+ """Model arguments"""
+
+ group = parser.add_argument_group("model", "model configuration")
+ group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
+ group.add_argument(
+ "--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
+ )
+ group.add_argument("--force-pretrain", action="store_true")
+ group.add_argument("--device", type=int, default=-1)
+ group.add_argument("--debug", action="store_true")
+ group.add_argument("--log-image", type=bool, default=True)
+
+ return parser
+
+
+def add_sampling_config_args(parser):
+ """Sampling configurations"""
+
+ group = parser.add_argument_group("sampling", "Sampling Configurations")
+ group.add_argument("--output-dir", type=str, default="samples")
+ group.add_argument("--input-dir", type=str, default=None)
+ group.add_argument("--input-type", type=str, default="cli")
+ group.add_argument("--input-file", type=str, default="input.txt")
+ group.add_argument("--final-size", type=int, default=2048)
+ group.add_argument("--sdedit", action="store_true")
+ group.add_argument("--grid-num-rows", type=int, default=1)
+ group.add_argument("--force-inference", action="store_true")
+ group.add_argument("--lcm_steps", type=int, default=None)
+ group.add_argument("--sampling-num-frames", type=int, default=32)
+ group.add_argument("--sampling-fps", type=int, default=8)
+ group.add_argument("--only-save-latents", type=bool, default=False)
+ group.add_argument("--only-log-video-latents", type=bool, default=False)
+ group.add_argument("--latent-channels", type=int, default=32)
+ group.add_argument("--image2video", action="store_true")
+
+ return parser
+
+
+def get_args(args_list=None, parser=None):
+ """Parse all the args."""
+ if parser is None:
+ parser = argparse.ArgumentParser(description="sat")
+ else:
+ assert isinstance(parser, argparse.ArgumentParser)
+ parser = add_model_config_args(parser)
+ parser = add_sampling_config_args(parser)
+ parser = add_training_args(parser)
+ parser = add_evaluation_args(parser)
+ parser = add_data_args(parser)
+
+ import deepspeed
+
+ parser = deepspeed.add_config_arguments(parser)
+
+ args = parser.parse_args(args_list)
+ args = process_config_to_args(args)
+
+ if not args.train_data:
+ print_rank0("No training data specified", level="WARNING")
+
+ assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
+ if args.train_iters is None and args.epochs is None:
+ args.train_iters = 10000 # default 10k iters
+ print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
+
+ args.cuda = torch.cuda.is_available()
+
+ args.rank = int(os.getenv("RANK", "0"))
+ args.world_size = int(os.getenv("WORLD_SIZE", "1"))
+ if args.local_rank is None:
+ args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun
+
+ if args.device == -1:
+ if torch.cuda.device_count() == 0:
+ args.device = "cpu"
+ elif args.local_rank is not None:
+ args.device = args.local_rank
+ else:
+ args.device = args.rank % torch.cuda.device_count()
+
+ if args.local_rank != args.device and args.mode != "inference":
+ raise ValueError(
+ "LOCAL_RANK (default 0) and args.device inconsistent. "
+ "This can only happens in inference mode. "
+ "Please use CUDA_VISIBLE_DEVICES=x for single-GPU training. "
+ )
+
+ if args.rank == 0:
+ print_rank0("using world size: {}".format(args.world_size))
+
+ if args.train_data_weights is not None:
+ assert len(args.train_data_weights) == len(args.train_data)
+
+ if args.mode != "inference": # training with deepspeed
+ args.deepspeed = True
+ if args.deepspeed_config is None: # not specified
+ deepspeed_config_path = os.path.join(
+ os.path.dirname(__file__), "training", f"deepspeed_zero{args.zero_stage}.json"
+ )
+ with open(deepspeed_config_path) as file:
+ args.deepspeed_config = json.load(file)
+ override_deepspeed_config = True
+ else:
+ override_deepspeed_config = False
+
+ assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16."
+
+ if args.zero_stage > 0 and not args.fp16 and not args.bf16:
+ print_rank0("Automatically set fp16=True to use ZeRO.")
+ args.fp16 = True
+ args.bf16 = False
+
+ if args.deepspeed:
+ if args.checkpoint_activations:
+ args.deepspeed_activation_checkpointing = True
+ else:
+ args.deepspeed_activation_checkpointing = False
+ if args.deepspeed_config is not None:
+ deepspeed_config = args.deepspeed_config
+
+ if override_deepspeed_config: # not specify deepspeed_config, use args
+ if args.fp16:
+ deepspeed_config["fp16"]["enabled"] = True
+ elif args.bf16:
+ deepspeed_config["bf16"]["enabled"] = True
+ deepspeed_config["fp16"]["enabled"] = False
+ else:
+ deepspeed_config["fp16"]["enabled"] = False
+ deepspeed_config["train_micro_batch_size_per_gpu"] = args.batch_size
+ deepspeed_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
+ optimizer_params_config = deepspeed_config["optimizer"]["params"]
+ optimizer_params_config["lr"] = args.lr
+ optimizer_params_config["weight_decay"] = args.weight_decay
+ else: # override args with values in deepspeed_config
+ if args.rank == 0:
+ print_rank0("Will override arguments with manually specified deepspeed_config!")
+ if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
+ args.fp16 = True
+ else:
+ args.fp16 = False
+ if "bf16" in deepspeed_config and deepspeed_config["bf16"]["enabled"]:
+ args.bf16 = True
+ else:
+ args.bf16 = False
+ if "train_micro_batch_size_per_gpu" in deepspeed_config:
+ args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
+ if "gradient_accumulation_steps" in deepspeed_config:
+ args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
+ else:
+ args.gradient_accumulation_steps = None
+ if "optimizer" in deepspeed_config:
+ optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
+ args.lr = optimizer_params_config.get("lr", args.lr)
+ args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
+ args.deepspeed_config = deepspeed_config
+
+ # initialize distributed and random seed because it always seems to be necessary.
+ initialize_distributed(args)
+ args.seed = args.seed + mpu.get_data_parallel_rank()
+ set_random_seed(args.seed)
+ return args
+
+
+def initialize_distributed(args):
+ """Initialize torch.distributed."""
+ if torch.distributed.is_initialized():
+ if mpu.model_parallel_is_initialized():
+ if args.model_parallel_size != mpu.get_model_parallel_world_size():
+ raise ValueError(
+ "model_parallel_size is inconsistent with prior configuration."
+ "We currently do not support changing model_parallel_size."
+ )
+ return False
+ else:
+ if args.model_parallel_size > 1:
+ warnings.warn(
+ "model_parallel_size > 1 but torch.distributed is not initialized via SAT."
+ "Please carefully make sure the correctness on your own."
+ )
+ mpu.initialize_model_parallel(args.model_parallel_size)
+ return True
+ # the automatic assignment of devices has been moved to arguments.py
+ if args.device == "cpu":
+ pass
+ else:
+ torch.cuda.set_device(args.device)
+ # Call the init process
+ init_method = "tcp://"
+ args.master_ip = os.getenv("MASTER_ADDR", "localhost")
+
+ if args.world_size == 1:
+ from sat.helpers import get_free_port
+
+ default_master_port = str(get_free_port())
+ else:
+ default_master_port = "6000"
+ args.master_port = os.getenv("MASTER_PORT", default_master_port)
+ init_method += args.master_ip + ":" + args.master_port
+ torch.distributed.init_process_group(
+ backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
+ )
+
+ # Set the model-parallel / data-parallel communicators.
+ mpu.initialize_model_parallel(args.model_parallel_size)
+
+ # Set vae context parallel group equal to model parallel group
+ from sgm.util import set_context_parallel_group, initialize_context_parallel
+
+ if args.model_parallel_size <= 2:
+ set_context_parallel_group(args.model_parallel_size, mpu.get_model_parallel_group())
+ else:
+ initialize_context_parallel(2)
+ # mpu.initialize_model_parallel(1)
+ # Optional DeepSpeed Activation Checkpointing Features
+ if args.deepspeed:
+ import deepspeed
+
+ deepspeed.init_distributed(
+ dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
+ )
+ # # It seems that it has no negative influence to configure it even without using checkpointing.
+ # deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
+ else:
+ # in model-only mode, we don't want to init deepspeed, but we still need to init the rng tracker for model_parallel, just because we save the seed by default when dropout.
+ try:
+ import deepspeed
+ from deepspeed.runtime.activation_checkpointing.checkpointing import (
+ _CUDA_RNG_STATE_TRACKER,
+ _MODEL_PARALLEL_RNG_TRACKER_NAME,
+ )
+
+ _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1) # default seed 1
+ except Exception as e:
+ from sat.helpers import print_rank0
+
+ print_rank0(str(e), level="DEBUG")
+
+ return True
+
+
+def process_config_to_args(args):
+ """Fetch args from only --base"""
+
+ configs = [OmegaConf.load(cfg) for cfg in args.base]
+ config = OmegaConf.merge(*configs)
+
+ args_config = config.pop("args", OmegaConf.create())
+ for key in args_config:
+ if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
+ arg = OmegaConf.to_object(args_config[key])
+ else:
+ arg = args_config[key]
+ if hasattr(args, key):
+ setattr(args, key, arg)
+
+ if "model" in config:
+ model_config = config.pop("model", OmegaConf.create())
+ args.model_config = model_config
+ if "deepspeed" in config:
+ deepspeed_config = config.pop("deepspeed", OmegaConf.create())
+ args.deepspeed_config = OmegaConf.to_object(deepspeed_config)
+ if "data" in config:
+ data_config = config.pop("data", OmegaConf.create())
+ args.data_config = data_config
+
+ return args
diff --git a/sat/configs/cogvideox_2b_infer.yaml b/sat/configs/cogvideox_2b_infer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..adf9de21b42ffdaa3a2c2eddad770700ab758c70
--- /dev/null
+++ b/sat/configs/cogvideox_2b_infer.yaml
@@ -0,0 +1,166 @@
+args:
+ latent_channels: 16
+ mode: inference
+ load: "CogVideoX-2b-sat/transformer"
+ batch_size: 1
+ input_type: txt
+ input_file: test.txt
+ sampling_num_frames: 13 # Must be 13, 11 or 9
+ sampling_fps: 8
+ fp16: True
+ output_dir: outputs/
+ force_inference: True
+
+model:
+ scale_factor: 1.15258426
+ disable_first_stage_autocast: true
+ log_keys:
+ - txt
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+ quantize_c_noise: False
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+ params:
+ shift_scale: 3.0
+
+ network_config:
+ target: dit_video_concat.DiffusionTransformer
+ params:
+ time_embed_dim: 512
+ elementwise_affine: True
+ num_frames: 49
+ time_compressed_rate: 4
+ latent_width: 90
+ latent_height: 60
+ num_layers: 30
+ patch_size: 2
+ in_channels: 16
+ out_channels: 16
+ hidden_size: 1920
+ adm_in_channels: 256
+ num_attention_heads: 30
+
+ transformer_args:
+ vocab_size: 1
+ max_sequence_length: 64
+ layernorm_order: pre
+ skip_init: false
+ model_parallel_size: 1
+ is_decoder: false
+
+ modules:
+ pos_embed_config:
+ target: dit_video_concat.Basic3DPositionEmbeddingMixin
+ params:
+ text_length: 226
+ height_interpolation: 1.875
+ width_interpolation: 1.875
+
+ patch_embed_config:
+ target: dit_video_concat.ImagePatchEmbeddingMixin
+ params:
+ text_hidden_size: 4096
+
+ adaln_layer_config:
+ target: dit_video_concat.AdaLNMixin
+ params:
+ qk_ln: True
+
+ final_layer_config:
+ target: dit_video_concat.FinalLayerMixin
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ - is_trainable: false
+ input_key: txt
+ ucg_rate: 0.1
+ target: sgm.modules.encoders.modules.FrozenT5Embedder
+ params:
+ model_dir: "google/t5-v1_1-xxl"
+ max_length: 226
+
+ first_stage_config:
+ target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
+ params:
+ cp_size: 1
+ ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt"
+ ignore_keys: [ 'loss' ]
+
+ loss_config:
+ target: torch.nn.Identity
+
+ regularizer_config:
+ target: vae_modules.regularizers.DiagonalGaussianRegularizer
+
+ encoder_config:
+ target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
+ params:
+ double_z: true
+ z_channels: 16
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1, 2, 2, 4 ]
+ attn_resolutions: [ ]
+ num_res_blocks: 3
+ dropout: 0.0
+ gather_norm: True
+
+ decoder_config:
+ target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
+ params:
+ double_z: True
+ z_channels: 16
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1, 2, 2, 4 ]
+ attn_resolutions: [ ]
+ num_res_blocks: 3
+ dropout: 0.0
+ gather_norm: false
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
+ params:
+ offset_noise_level: 0
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+ params:
+ uniform_sampling: True
+ num_idx: 1000
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+ params:
+ shift_scale: 3.0
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
+ params:
+ num_steps: 50
+ verbose: True
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+ params:
+ shift_scale: 3.0
+
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.DynamicCFG
+ params:
+ scale: 6
+ exp: 5
+ num_steps: 50
\ No newline at end of file
diff --git a/sat/configs/cogvideox_2b_sft.yaml b/sat/configs/cogvideox_2b_sft.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1cac09b079d946f117d44a0eb2f0f66fc9601281
--- /dev/null
+++ b/sat/configs/cogvideox_2b_sft.yaml
@@ -0,0 +1,225 @@
+args:
+ checkpoint_activations: True ## using gradient checkpointing
+ model_parallel_size: 1
+ experiment_name: lora-disney
+ mode: finetune
+ load: "CogVideoX-2b-sat/transformer"
+ no_load_rng: True
+ train_iters: 1000
+ eval_iters: 1
+ eval_interval: 100
+ eval_batch_size: 1
+ save: ckpts
+ save_interval: 100
+ log_interval: 20
+ train_data: ["disney"]
+ valid_data: ["disney"]
+ split: 1,0,0
+ num_workers: 8
+ force_train: True
+ only_log_video_latents: True
+
+data:
+ target: data_video.SFTDataset
+ params:
+ video_size: [480, 720]
+ fps: 8
+ max_num_frames: 49
+ skip_frms_num: 3.
+
+deepspeed:
+ train_micro_batch_size_per_gpu: 1
+ gradient_accumulation_steps: 1
+ steps_per_print: 50
+ gradient_clipping: 0.1
+ zero_optimization:
+ stage: 2
+ cpu_offload: false
+ contiguous_gradients: false
+ overlap_comm: true
+ reduce_scatter: true
+ reduce_bucket_size: 1000000000
+ allgather_bucket_size: 1000000000
+ load_from_fp32_weights: false
+ zero_allow_untested_optimizer: true
+ bf16:
+ enabled: False
+ fp16:
+ enabled: True
+ loss_scale: 0
+ loss_scale_window: 400
+ hysteresis: 2
+ min_loss_scale: 1
+ optimizer:
+ type: sat.ops.FusedEmaAdam
+ params:
+ lr: 0.0002
+ betas: [0.9, 0.95]
+ eps: 1e-8
+ weight_decay: 1e-4
+ activation_checkpointing:
+ partition_activations: false
+ contiguous_memory_optimization: false
+ wall_clock_breakdown: false
+
+
+model:
+ scale_factor: 1.15258426
+ disable_first_stage_autocast: true
+ not_trainable_prefixes: ['all'] ## Using Lora
+ log_keys:
+ - txt
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+ quantize_c_noise: False
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+ params:
+ shift_scale: 3.0
+
+ network_config:
+ target: dit_video_concat.DiffusionTransformer
+ params:
+ time_embed_dim: 512
+ elementwise_affine: True
+ num_frames: 49
+ time_compressed_rate: 4
+ latent_width: 90
+ latent_height: 60
+ num_layers: 30
+ patch_size: 2
+ in_channels: 16
+ out_channels: 16
+ hidden_size: 1920
+ adm_in_channels: 256
+ num_attention_heads: 30
+
+ transformer_args:
+ checkpoint_activations: True ## using gradient checkpointing
+ vocab_size: 1
+ max_sequence_length: 64
+ layernorm_order: pre
+ skip_init: false
+ model_parallel_size: 1
+ is_decoder: false
+
+ modules:
+ pos_embed_config:
+ target: dit_video_concat.Basic3DPositionEmbeddingMixin
+ params:
+ text_length: 226
+ height_interpolation: 1.875
+ width_interpolation: 1.875
+
+ lora_config: ## Using Lora
+ target: sat.model.finetune.lora2.LoraMixin
+ params:
+ r: 128
+
+ patch_embed_config:
+ target: dit_video_concat.ImagePatchEmbeddingMixin
+ params:
+ text_hidden_size: 4096
+
+ adaln_layer_config:
+ target: dit_video_concat.AdaLNMixin
+ params:
+ qk_ln: True
+
+ final_layer_config:
+ target: dit_video_concat.FinalLayerMixin
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ - is_trainable: false
+ input_key: txt
+ ucg_rate: 0.1
+ target: sgm.modules.encoders.modules.FrozenT5Embedder
+ params:
+ model_dir: "google/t5-v1_1-xxl"
+ max_length: 226
+
+ first_stage_config:
+ target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
+ params:
+ cp_size: 1
+ ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt"
+ ignore_keys: [ 'loss' ]
+
+ loss_config:
+ target: torch.nn.Identity
+
+ regularizer_config:
+ target: vae_modules.regularizers.DiagonalGaussianRegularizer
+
+ encoder_config:
+ target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
+ params:
+ double_z: true
+ z_channels: 16
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1, 2, 2, 4 ]
+ attn_resolutions: [ ]
+ num_res_blocks: 3
+ dropout: 0.0
+ gather_norm: True
+
+ decoder_config:
+ target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
+ params:
+ double_z: True
+ z_channels: 16
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1, 2, 2, 4 ]
+ attn_resolutions: [ ]
+ num_res_blocks: 3
+ dropout: 0.0
+ gather_norm: false
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
+ params:
+ offset_noise_level: 0
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+ params:
+ uniform_sampling: True
+ num_idx: 1000
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+ params:
+ shift_scale: 3.0
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
+ params:
+ num_steps: 50
+ verbose: True
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+ params:
+ shift_scale: 3.0
+
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.DynamicCFG
+ params:
+ scale: 6
+ exp: 5
+ num_steps: 50
\ No newline at end of file
diff --git a/sat/configs/test.txt b/sat/configs/test.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b732bbd4e16e1e9696d7fda4d7021b8cae5f6d4b
--- /dev/null
+++ b/sat/configs/test.txt
@@ -0,0 +1,3 @@
+In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
+The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.
+A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
\ No newline at end of file
diff --git a/sat/data_video.py b/sat/data_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccfea461570ac2f35c11bd2e6cf75d43f1e43a8e
--- /dev/null
+++ b/sat/data_video.py
@@ -0,0 +1,451 @@
+import io
+import os
+import sys
+from functools import partial
+import math
+import torchvision.transforms as TT
+from sgm.webds import MetaDistributedWebDataset
+import random
+from fractions import Fraction
+from typing import Union, Optional, Dict, Any, Tuple
+from torchvision.io.video import av
+import numpy as np
+import torch
+from torchvision.io import _video_opt
+from torchvision.io.video import _check_av_available, _read_from_stream, _align_audio_frames
+from torchvision.transforms.functional import center_crop, resize
+from torchvision.transforms import InterpolationMode
+import decord
+from decord import VideoReader
+from torch.utils.data import Dataset
+
+
+def read_video(
+ filename: str,
+ start_pts: Union[float, Fraction] = 0,
+ end_pts: Optional[Union[float, Fraction]] = None,
+ pts_unit: str = "pts",
+ output_format: str = "THWC",
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
+ """
+ Reads a video from a file, returning both the video frames and the audio frames
+
+ Args:
+ filename (str): path to the video file
+ start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
+ The start presentation time of the video
+ end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
+ The end presentation time
+ pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
+ either 'pts' or 'sec'. Defaults to 'pts'.
+ output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
+
+ Returns:
+ vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
+ aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
+ info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
+ """
+
+ output_format = output_format.upper()
+ if output_format not in ("THWC", "TCHW"):
+ raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
+
+ _check_av_available()
+
+ if end_pts is None:
+ end_pts = float("inf")
+
+ if end_pts < start_pts:
+ raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
+
+ info = {}
+ audio_frames = []
+ audio_timebase = _video_opt.default_timebase
+
+ with av.open(filename, metadata_errors="ignore") as container:
+ if container.streams.audio:
+ audio_timebase = container.streams.audio[0].time_base
+ if container.streams.video:
+ video_frames = _read_from_stream(
+ container,
+ start_pts,
+ end_pts,
+ pts_unit,
+ container.streams.video[0],
+ {"video": 0},
+ )
+ video_fps = container.streams.video[0].average_rate
+ # guard against potentially corrupted files
+ if video_fps is not None:
+ info["video_fps"] = float(video_fps)
+
+ if container.streams.audio:
+ audio_frames = _read_from_stream(
+ container,
+ start_pts,
+ end_pts,
+ pts_unit,
+ container.streams.audio[0],
+ {"audio": 0},
+ )
+ info["audio_fps"] = container.streams.audio[0].rate
+
+ aframes_list = [frame.to_ndarray() for frame in audio_frames]
+
+ vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
+
+ if aframes_list:
+ aframes = np.concatenate(aframes_list, 1)
+ aframes = torch.as_tensor(aframes)
+ if pts_unit == "sec":
+ start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
+ if end_pts != float("inf"):
+ end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
+ aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
+ else:
+ aframes = torch.empty((1, 0), dtype=torch.float32)
+
+ if output_format == "TCHW":
+ # [T,H,W,C] --> [T,C,H,W]
+ vframes = vframes.permute(0, 3, 1, 2)
+
+ return vframes, aframes, info
+
+
+def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
+ if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+ arr = resize(
+ arr,
+ size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+ interpolation=InterpolationMode.BICUBIC,
+ )
+ else:
+ arr = resize(
+ arr,
+ size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+ interpolation=InterpolationMode.BICUBIC,
+ )
+
+ h, w = arr.shape[2], arr.shape[3]
+ arr = arr.squeeze(0)
+
+ delta_h = h - image_size[0]
+ delta_w = w - image_size[1]
+
+ if reshape_mode == "random" or reshape_mode == "none":
+ top = np.random.randint(0, delta_h + 1)
+ left = np.random.randint(0, delta_w + 1)
+ elif reshape_mode == "center":
+ top, left = delta_h // 2, delta_w // 2
+ else:
+ raise NotImplementedError
+ arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+ return arr
+
+
+def pad_last_frame(tensor, num_frames):
+ # T, H, W, C
+ if tensor.shape[0] < num_frames:
+        # repeat the trailing frames so the clip reaches num_frames (frame count is dim 0)
+        last_frame = tensor[-int(num_frames - tensor.shape[0]) :]
+ padded_tensor = torch.cat([tensor, last_frame], dim=0)
+ return padded_tensor
+ else:
+ return tensor[:num_frames]
+
+
+def load_video(
+ video_data,
+ sampling="uniform",
+ duration=None,
+ num_frames=4,
+ wanted_fps=None,
+ actual_fps=None,
+ skip_frms_num=0.0,
+ nb_read_frames=None,
+):
+ decord.bridge.set_bridge("torch")
+ vr = VideoReader(uri=video_data, height=-1, width=-1)
+ if nb_read_frames is not None:
+ ori_vlen = nb_read_frames
+ else:
+ ori_vlen = min(int(duration * actual_fps) - 1, len(vr))
+
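+    # Latest frame index from which `num_frames` frames can still be sampled at `wanted_fps`
+    # before running into the skipped trailing frames / end of the clip.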
+ max_seek = int(ori_vlen - skip_frms_num - num_frames / wanted_fps * actual_fps)
+    start = random.randint(int(skip_frms_num), max_seek)
+ end = int(start + num_frames / wanted_fps * actual_fps)
+ n_frms = num_frames
+
+ if sampling == "uniform":
+ indices = np.arange(start, end, (end - start) / n_frms).astype(int)
+ else:
+ raise NotImplementedError
+
+ # get_batch -> T, H, W, C
+ temp_frms = vr.get_batch(np.arange(start, end))
+ assert temp_frms is not None
+ tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+ tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
+
+ return pad_last_frame(tensor_frms, num_frames)
+
+
+import threading
+
+
+def load_video_with_timeout(*args, **kwargs):
+ video_container = {}
+
+ def target_function():
+ video = load_video(*args, **kwargs)
+ video_container["video"] = video
+
+ thread = threading.Thread(target=target_function)
+ thread.start()
+ timeout = 20
+ thread.join(timeout)
+
+ if thread.is_alive():
+ print("Loading video timed out")
+ raise TimeoutError
+ return video_container.get("video", None).contiguous()
+
+
+def process_video(
+ video_path,
+ image_size=None,
+ duration=None,
+ num_frames=4,
+ wanted_fps=None,
+ actual_fps=None,
+ skip_frms_num=0.0,
+ nb_read_frames=None,
+):
+ """
+ video_path: str or io.BytesIO
+    image_size: target (height, width) used for resizing and center-cropping; None keeps the original size.
+    duration: known duration of the video, used to speed up loading by seeking to a sampled start. TODO: bypass if unknown.
+    num_frames: number of frames wanted.
+    wanted_fps: target sampling fps.
+ skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
+ """
+
+ video = load_video_with_timeout(
+ video_path,
+ duration=duration,
+ num_frames=num_frames,
+ wanted_fps=wanted_fps,
+ actual_fps=actual_fps,
+ skip_frms_num=skip_frms_num,
+ nb_read_frames=nb_read_frames,
+ )
+
+ # --- copy and modify the image process ---
+ video = video.permute(0, 3, 1, 2) # [T, C, H, W]
+
+ # resize
+ if image_size is not None:
+ video = resize_for_rectangle_crop(video, image_size, reshape_mode="center")
+
+ return video
+
+
+def process_fn_video(src, image_size, fps, num_frames, skip_frms_num=0.0, txt_key="caption"):
+ while True:
+ r = next(src)
+ if "mp4" in r:
+ video_data = r["mp4"]
+ elif "avi" in r:
+ video_data = r["avi"]
+ else:
+ print("No video data found")
+ continue
+
+ if txt_key not in r:
+ txt = ""
+ else:
+ txt = r[txt_key]
+
+ if isinstance(txt, bytes):
+ txt = txt.decode("utf-8")
+ else:
+ txt = str(txt)
+
+ duration = r.get("duration", None)
+ if duration is not None:
+ duration = float(duration)
+ else:
+ continue
+
+ actual_fps = r.get("fps", None)
+ if actual_fps is not None:
+ actual_fps = float(actual_fps)
+ else:
+ continue
+
+ required_frames = num_frames / fps * actual_fps + 2 * skip_frms_num
+ required_duration = num_frames / fps + 2 * skip_frms_num / actual_fps
+
+ if duration is not None and duration < required_duration:
+ continue
+
+ try:
+ frames = process_video(
+ io.BytesIO(video_data),
+ num_frames=num_frames,
+ wanted_fps=fps,
+ image_size=image_size,
+ duration=duration,
+ actual_fps=actual_fps,
+ skip_frms_num=skip_frms_num,
+ )
+ frames = (frames - 127.5) / 127.5
+ except Exception as e:
+ print(e)
+ continue
+
+ item = {
+ "mp4": frames,
+ "txt": txt,
+ "num_frames": num_frames,
+ "fps": fps,
+ }
+
+ yield item
+
+
+class VideoDataset(MetaDistributedWebDataset):
+ def __init__(
+ self,
+ path,
+ image_size,
+ num_frames,
+ fps,
+ skip_frms_num=0.0,
+ nshards=sys.maxsize,
+ seed=1,
+ meta_names=None,
+ shuffle_buffer=1000,
+ include_dirs=None,
+ txt_key="caption",
+ **kwargs,
+ ):
+ if seed == -1:
+ seed = random.randint(0, 1000000)
+ if meta_names is None:
+ meta_names = []
+
+ if path.startswith(";"):
+ path, include_dirs = path.split(";", 1)
+ super().__init__(
+ path,
+ partial(
+ process_fn_video, num_frames=num_frames, image_size=image_size, fps=fps, skip_frms_num=skip_frms_num
+ ),
+ seed,
+ meta_names=meta_names,
+ shuffle_buffer=shuffle_buffer,
+ nshards=nshards,
+ include_dirs=include_dirs,
+ )
+
+ @classmethod
+ def create_dataset_function(cls, path, args, **kwargs):
+ return cls(path, **kwargs)
+
+
+class SFTDataset(Dataset):
+ def __init__(self, data_dir, video_size, fps, max_num_frames, skip_frms_num=3):
+ """
+ skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
+ """
+ super(SFTDataset, self).__init__()
+
+ self.videos_list = []
+ self.captions_list = []
+ self.num_frames_list = []
+ self.fps_list = []
+
+ decord.bridge.set_bridge("torch")
+ for root, dirnames, filenames in os.walk(data_dir):
+ for filename in filenames:
+ if filename.endswith(".mp4"):
+ video_path = os.path.join(root, filename)
+ vr = VideoReader(uri=video_path, height=-1, width=-1)
+ actual_fps = vr.get_avg_fps()
+ ori_vlen = len(vr)
+
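+                    # If the clip is long enough at the target fps, sample max_num_frames;
+                    # otherwise fall back to the shorter-clip handling below.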
+ if ori_vlen / actual_fps * fps > max_num_frames:
+ num_frames = max_num_frames
+ start = int(skip_frms_num)
+ end = int(start + num_frames / fps * actual_fps)
+ indices = np.arange(start, end, (end - start) / num_frames).astype(int)
+ temp_frms = vr.get_batch(np.arange(start, end))
+ assert temp_frms is not None
+ tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+ tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
+ else:
+ if ori_vlen > max_num_frames:
+ num_frames = max_num_frames
+ start = int(skip_frms_num)
+ end = int(ori_vlen - skip_frms_num)
+ indices = np.arange(start, end, (end - start) / num_frames).astype(int)
+ temp_frms = vr.get_batch(np.arange(start, end))
+ assert temp_frms is not None
+ tensor_frms = (
+ torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+ )
+ tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
+ else:
+
+ def nearest_smaller_4k_plus_1(n):
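+                            # largest m <= n with m % 4 == 1, since the 3D VAE expects 4k+1 frames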
+ remainder = n % 4
+ if remainder == 0:
+ return n - 3
+ else:
+ return n - remainder + 1
+
+ start = int(skip_frms_num)
+ end = int(ori_vlen - skip_frms_num)
+ num_frames = nearest_smaller_4k_plus_1(
+ end - start
+ ) # 3D VAE requires the number of frames to be 4k+1
+ end = int(start + num_frames)
+ temp_frms = vr.get_batch(np.arange(start, end))
+ assert temp_frms is not None
+ tensor_frms = (
+ torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+ )
+
+ tensor_frms = pad_last_frame(
+ tensor_frms, num_frames
+                    ) # the number of sampled frames may be fewer than num_frames due to rounding error
+ tensor_frms = tensor_frms.permute(0, 3, 1, 2) # [T, H, W, C] -> [T, C, H, W]
+ tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")
+ tensor_frms = (tensor_frms - 127.5) / 127.5
+ self.videos_list.append(tensor_frms)
+
+ # caption
+ caption_path = os.path.join(root, filename.replace("videos", "labels").replace(".mp4", ".txt"))
+ if os.path.exists(caption_path):
+                        with open(caption_path, "r") as caption_file:
+                            caption = caption_file.read().splitlines()[0]
+ else:
+ caption = ""
+ self.captions_list.append(caption)
+ self.num_frames_list.append(num_frames)
+ self.fps_list.append(fps)
+
+ def __getitem__(self, index):
+ item = {
+ "mp4": self.videos_list[index],
+ "txt": self.captions_list[index],
+ "num_frames": self.num_frames_list[index],
+ "fps": self.fps_list[index],
+ }
+ return item
+
+ def __len__(self):
+ return len(self.fps_list)
+
+ @classmethod
+ def create_dataset_function(cls, path, args, **kwargs):
+ return cls(data_dir=path, **kwargs)
diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..951e93e2c754bc9353da10eb43f7d8b258e8a970
--- /dev/null
+++ b/sat/diffusion_video.py
@@ -0,0 +1,318 @@
+import math
+from contextlib import contextmanager
+from typing import Any, Dict, List, Tuple, Union, Optional
+from omegaconf import ListConfig, OmegaConf
+from copy import deepcopy
+import torch.nn.functional as F
+
+from sat.helpers import print_rank0
+import torch
+from torch import nn
+
+from sgm.modules import UNCONDITIONAL_CONFIG
+from sgm.modules.autoencoding.temporal_ae import VideoDecoder
+from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
+from sgm.util import (
+ default,
+ disabled_train,
+ get_obj_from_str,
+ instantiate_from_config,
+ log_txt_as_img,
+)
+import gc
+from sat import mpu
+import random
+
+
+class SATVideoDiffusionEngine(nn.Module):
+ def __init__(self, args, **kwargs):
+ super().__init__()
+
+ model_config = args.model_config
+ # model args preprocess
+ log_keys = model_config.get("log_keys", None)
+ input_key = model_config.get("input_key", "mp4")
+ network_config = model_config.get("network_config", None)
+ network_wrapper = model_config.get("network_wrapper", None)
+ denoiser_config = model_config.get("denoiser_config", None)
+ sampler_config = model_config.get("sampler_config", None)
+ conditioner_config = model_config.get("conditioner_config", None)
+ first_stage_config = model_config.get("first_stage_config", None)
+ loss_fn_config = model_config.get("loss_fn_config", None)
+ scale_factor = model_config.get("scale_factor", 1.0)
+ latent_input = model_config.get("latent_input", False)
+ disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
+        no_cond_log = model_config.get("no_cond_log", False)
+ not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"])
+ compile_model = model_config.get("compile_model", False)
+ en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
+ lr_scale = model_config.get("lr_scale", None)
+ lora_train = model_config.get("lora_train", False)
+ self.use_pd = model_config.get("use_pd", False) # progressive distillation
+
+ self.log_keys = log_keys
+ self.input_key = input_key
+ self.not_trainable_prefixes = not_trainable_prefixes
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
+ self.lr_scale = lr_scale
+ self.lora_train = lora_train
+ self.noised_image_input = model_config.get("noised_image_input", False)
+ self.noised_image_all_concat = model_config.get("noised_image_all_concat", False)
+ self.noised_image_dropout = model_config.get("noised_image_dropout", 0.0)
+ if args.fp16:
+ dtype = torch.float16
+ dtype_str = "fp16"
+ elif args.bf16:
+ dtype = torch.bfloat16
+ dtype_str = "bf16"
+ else:
+ dtype = torch.float32
+ dtype_str = "fp32"
+ self.dtype = dtype
+ self.dtype_str = dtype_str
+
+ network_config["params"]["dtype"] = dtype_str
+ model = instantiate_from_config(network_config)
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
+ model, compile_model=compile_model, dtype=dtype
+ )
+
+ self.denoiser = instantiate_from_config(denoiser_config)
+ self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
+ self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
+
+ self._init_first_stage(first_stage_config)
+
+ self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
+
+ self.latent_input = latent_input
+ self.scale_factor = scale_factor
+ self.disable_first_stage_autocast = disable_first_stage_autocast
+ self.no_cond_log = no_cond_log
+ self.device = args.device
+
+ def disable_untrainable_params(self):
+ total_trainable = 0
+ for n, p in self.named_parameters():
+            if not p.requires_grad:
+ continue
+ flag = False
+ for prefix in self.not_trainable_prefixes:
+ if n.startswith(prefix) or prefix == "all":
+ flag = True
+ break
+
+ lora_prefix = ["matrix_A", "matrix_B"]
+ for prefix in lora_prefix:
+ if prefix in n:
+ flag = False
+ break
+
+ if flag:
+ p.requires_grad_(False)
+ else:
+ total_trainable += p.numel()
+
+ print_rank0("***** Total trainable parameters: " + str(total_trainable) + " *****")
+
+ def reinit(self, parent_model=None):
+        # reload the initial params from previously trained modules
+ # you can also get access to other mixins through parent_model.get_mixin().
+ pass
+
+ def _init_first_stage(self, config):
+ model = instantiate_from_config(config).eval()
+ model.train = disabled_train
+ for param in model.parameters():
+ param.requires_grad = False
+ self.first_stage_model = model
+
+ def forward(self, x, batch):
+ loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
+ loss_mean = loss.mean()
+ loss_dict = {"loss": loss_mean}
+ return loss_mean, loss_dict
+
+ def shared_step(self, batch: Dict) -> Any:
+ x = self.get_input(batch)
+ if self.lr_scale is not None:
+ lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
+ lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
+ lr_z = self.encode_first_stage(lr_x, batch)
+ batch["lr_input"] = lr_z
+
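+        # permute (b, t, c, h, w) -> (b, c, t, h, w) for the first-stage encoder, then back for the diffusion model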
+ x = x.permute(0, 2, 1, 3, 4).contiguous()
+ x = self.encode_first_stage(x, batch)
+ x = x.permute(0, 2, 1, 3, 4).contiguous()
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ loss, loss_dict = self(x, batch)
+ return loss, loss_dict
+
+ def get_input(self, batch):
+ return batch[self.input_key].to(self.dtype)
+
+ @torch.no_grad()
+ def decode_first_stage(self, z):
+ z = 1.0 / self.scale_factor * z
+ n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
+ n_rounds = math.ceil(z.shape[0] / n_samples)
+ all_out = []
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ for n in range(n_rounds):
+ if isinstance(self.first_stage_model.decoder, VideoDecoder):
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
+ else:
+ kwargs = {}
+ use_cp = False
+ out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)
+ all_out.append(out)
+ out = torch.cat(all_out, dim=0)
+ return out
+
+ @torch.no_grad()
+ def encode_first_stage(self, x, batch):
+ frame = x.shape[2]
+
+ if frame > 1 and self.latent_input:
+ x = x.permute(0, 2, 1, 3, 4).contiguous()
+ return x * self.scale_factor # already encoded
+
+ use_cp = False
+
+ n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
+ n_rounds = math.ceil(x.shape[0] / n_samples)
+ all_out = []
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ for n in range(n_rounds):
+ out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples])
+ all_out.append(out)
+ z = torch.cat(all_out, dim=0)
+ z = self.scale_factor * z
+ return z
+
+ @torch.no_grad()
+ def sample(
+ self,
+ cond: Dict,
+ uc: Union[Dict, None] = None,
+ batch_size: int = 16,
+ shape: Union[None, Tuple, List] = None,
+ prefix=None,
+ concat_images=None,
+ **kwargs,
+ ):
+ randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
+ if hasattr(self, "seeded_noise"):
+ randn = self.seeded_noise(randn)
+
+ if prefix is not None:
+ randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1)
+
+ # broadcast noise
+ mp_size = mpu.get_model_parallel_world_size()
+ if mp_size > 1:
+ global_rank = torch.distributed.get_rank() // mp_size
+ src = global_rank * mp_size
+ torch.distributed.broadcast(randn, src=src, group=mpu.get_model_parallel_group())
+
+ scale = None
+ scale_emb = None
+
+        denoiser = lambda input, sigma, c, **additional_model_inputs: self.denoiser(
+            self.model, input, sigma, c, concat_images=concat_images, **additional_model_inputs
+        )
+
+ samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb)
+ samples = samples.to(self.dtype)
+ return samples
+
+ @torch.no_grad()
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
+ """
+ Defines heuristics to log different conditionings.
+ These can be lists of strings (text-to-image), tensors, ints, ...
+ """
+ image_h, image_w = batch[self.input_key].shape[3:]
+ log = dict()
+
+ for embedder in self.conditioner.embedders:
+ if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log:
+ x = batch[embedder.input_key][:n]
+ if isinstance(x, torch.Tensor):
+ if x.dim() == 1:
+ # class-conditional, convert integer to string
+ x = [str(x[i].item()) for i in range(x.shape[0])]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
+ elif x.dim() == 2:
+ # size and crop cond and the like
+ x = ["x".join([str(xx) for xx in x[i].tolist()]) for i in range(x.shape[0])]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ else:
+ raise NotImplementedError()
+ elif isinstance(x, (List, ListConfig)):
+ if isinstance(x[0], str):
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ else:
+ raise NotImplementedError()
+ else:
+ raise NotImplementedError()
+ log[embedder.input_key] = xc
+ return log
+
+ @torch.no_grad()
+ def log_video(
+ self,
+ batch: Dict,
+ N: int = 8,
+ ucg_keys: List[str] = None,
+ only_log_video_latents=False,
+ **kwargs,
+ ) -> Dict:
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
+ if ucg_keys:
+            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
+                "Each defined ucg key for sampling must be in the provided conditioner input keys, "
+                f"but we have {ucg_keys} vs. {conditioner_input_keys}"
+ )
+ else:
+ ucg_keys = conditioner_input_keys
+ log = dict()
+
+ x = self.get_input(batch)
+
+ c, uc = self.conditioner.get_unconditional_conditioning(
+ batch,
+ force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [],
+ )
+
+ sampling_kwargs = {}
+
+ N = min(x.shape[0], N)
+ x = x.to(self.device)[:N]
+ if not self.latent_input:
+ log["inputs"] = x.to(torch.float32)
+ x = x.permute(0, 2, 1, 3, 4).contiguous()
+ z = self.encode_first_stage(x, batch)
+ if not only_log_video_latents:
+ log["reconstructions"] = self.decode_first_stage(z).to(torch.float32)
+ log["reconstructions"] = log["reconstructions"].permute(0, 2, 1, 3, 4).contiguous()
+ z = z.permute(0, 2, 1, 3, 4).contiguous()
+
+ log.update(self.log_conditionings(batch, N))
+
+ for k in c:
+ if isinstance(c[k], torch.Tensor):
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
+
+ samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
+ samples = samples.permute(0, 2, 1, 3, 4).contiguous()
+ if only_log_video_latents:
+ latents = 1.0 / self.scale_factor * samples
+ log["latents"] = latents
+ else:
+ samples = self.decode_first_stage(samples).to(torch.float32)
+ samples = samples.permute(0, 2, 1, 3, 4).contiguous()
+ log["samples"] = samples
+ return log
diff --git a/sat/dit_video_concat.py b/sat/dit_video_concat.py
new file mode 100644
index 0000000000000000000000000000000000000000..d42a92e6693e8b8cc48d19bb28a990bba8848dff
--- /dev/null
+++ b/sat/dit_video_concat.py
@@ -0,0 +1,858 @@
+from functools import partial
+from einops import rearrange, repeat
+import numpy as np
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from sat.model.base_model import BaseModel, non_conflict
+from sat.model.mixins import BaseMixin
+from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
+from sat.mpu.layers import ColumnParallelLinear
+from sgm.util import instantiate_from_config
+
+from sgm.modules.diffusionmodules.openaimodel import Timestep
+from sgm.modules.diffusionmodules.util import (
+ linear,
+ timestep_embedding,
+)
+from sat.ops.layernorm import LayerNorm, RMSNorm
+
+
+class ImagePatchEmbeddingMixin(BaseMixin):
+ def __init__(
+ self,
+ in_channels,
+ hidden_size,
+ patch_size,
+ bias=True,
+ text_hidden_size=None,
+ ):
+ super().__init__()
+ self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
+ if text_hidden_size is not None:
+ self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+ else:
+ self.text_proj = None
+
+ def word_embedding_forward(self, input_ids, **kwargs):
+        # 2D patch embedding applied frame by frame; frames are then flattened into one token sequence
+ images = kwargs["images"] # (b,t,c,h,w)
+ B, T = images.shape[:2]
+ emb = images.view(-1, *images.shape[2:])
+ emb = self.proj(emb) # ((b t),d,h/2,w/2)
+ emb = emb.view(B, T, *emb.shape[1:])
+ emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d)
+ emb = rearrange(emb, "b t n d -> b (t n) d")
+
+ if self.text_proj is not None:
+ text_emb = self.text_proj(kwargs["encoder_outputs"])
+ emb = torch.cat((text_emb, emb), dim=1) # (b,n_t+t*n_i,d)
+
+ emb = emb.contiguous()
+ return emb # (b,n_t+t*n_i,d)
+
+ def reinit(self, parent_model=None):
+ w = self.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.proj.bias, 0)
+ del self.transformer.word_embeddings
+
+
+def get_3d_sincos_pos_embed(
+ embed_dim,
+ grid_height,
+ grid_width,
+ t_size,
+ cls_token=False,
+ height_interpolation=1.0,
+ width_interpolation=1.0,
+ time_interpolation=1.0,
+):
+ """
+    grid_height, grid_width: int sizes of the spatial grid
+    t_size: int of the temporal size
+    return:
+    pos_embed: [t_size, grid_height*grid_width, embed_dim]
+ """
+ assert embed_dim % 4 == 0
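+    # 3/4 of the channels encode the spatial (h, w) position, the remaining 1/4 encode the frame index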
+ embed_dim_spatial = embed_dim // 4 * 3
+ embed_dim_temporal = embed_dim // 4
+
+ # spatial
+ grid_h = np.arange(grid_height, dtype=np.float32) / height_interpolation
+ grid_w = np.arange(grid_width, dtype=np.float32) / width_interpolation
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_height, grid_width])
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+
+ # temporal
+ grid_t = np.arange(t_size, dtype=np.float32) / time_interpolation
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+ # concate: [T, H, W] order
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+ pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1) # [T, H*W, D // 4]
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+ pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3]
+
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
+ # pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
+
+ return pos_embed # [T, H*W, D]
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0):
+ """
+    grid_height, grid_width: int sizes of the spatial grid
+    return:
+    pos_embed: [grid_height*grid_width, embed_dim] or [extra_tokens+grid_height*grid_width, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_height, dtype=np.float32)
+ grid_w = np.arange(grid_width, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_height, grid_width])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
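+    # standard sinusoidal frequencies: omega_i = 1 / 10000^(2i / embed_dim)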
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+class Basic3DPositionEmbeddingMixin(BaseMixin):
+ def __init__(
+ self,
+ height,
+ width,
+ compressed_num_frames,
+ hidden_size,
+ text_length=0,
+ height_interpolation=1.0,
+ width_interpolation=1.0,
+ time_interpolation=1.0,
+ ):
+ super().__init__()
+ self.height = height
+ self.width = width
+ self.text_length = text_length
+ self.compressed_num_frames = compressed_num_frames
+ self.spatial_length = height * width
+ self.num_patches = height * width * compressed_num_frames
+ self.pos_embedding = nn.Parameter(
+ torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False
+ )
+ self.height_interpolation = height_interpolation
+ self.width_interpolation = width_interpolation
+ self.time_interpolation = time_interpolation
+
+ def position_embedding_forward(self, position_ids, **kwargs):
+ if kwargs["images"].shape[1] == 1:
+ return self.pos_embedding[:, : self.text_length + self.spatial_length]
+
+ return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
+
+ def reinit(self, parent_model=None):
+ del self.transformer.position_embeddings
+ pos_embed = get_3d_sincos_pos_embed(
+ self.pos_embedding.shape[-1],
+ self.height,
+ self.width,
+ self.compressed_num_frames,
+ height_interpolation=self.height_interpolation,
+ width_interpolation=self.width_interpolation,
+ time_interpolation=self.time_interpolation,
+ )
+ pos_embed = torch.from_numpy(pos_embed).float()
+ pos_embed = rearrange(pos_embed, "t n d -> (t n) d")
+ self.pos_embedding.data[:, -self.num_patches :].copy_(pos_embed)
+
+
+def broadcat(tensors, dim=-1):
+ num_tensors = len(tensors)
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+ shape_len = list(shape_lens)[0]
+ dim = (dim + shape_len) if dim < 0 else dim
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+ assert all(
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
+    ), "invalid dimensions for broadcastable concatenation"
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
+ expanded_dims.insert(dim, (dim, dims[dim]))
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
+ return torch.cat(tensors, dim=dim)
+
+
+def rotate_half(x):
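+    # pairwise rotation used by RoPE: (x1, x2) -> (-x2, x1) on interleaved channel pairs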
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return rearrange(x, "... d r -> ... (d r)")
+
+
+class Rotary3DPositionEmbeddingMixin(BaseMixin):
+ def __init__(
+ self,
+ height,
+ width,
+ compressed_num_frames,
+ hidden_size,
+ hidden_size_head,
+ text_length,
+ theta=10000,
+ rot_v=False,
+ pnp=False,
+ learnable_pos_embed=False,
+ ):
+ super().__init__()
+ self.rot_v = rot_v
+
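+        # per-head rotary channels are split 1/4 for time and 3/8 each for height and width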
+ dim_t = hidden_size_head // 4
+ dim_h = hidden_size_head // 8 * 3
+ dim_w = hidden_size_head // 8 * 3
+
+ # 'lang':
+ freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
+ freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
+ freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
+
+ grid_t = torch.arange(compressed_num_frames, dtype=torch.float32)
+ grid_h = torch.arange(height, dtype=torch.float32)
+ grid_w = torch.arange(width, dtype=torch.float32)
+
+ freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
+ freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
+ freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
+
+ freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
+ freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
+ freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
+
+ freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
+ # (T H W D)
+
+ self.pnp = pnp
+
+ if not self.pnp:
+ freqs = rearrange(freqs, "t h w d -> (t h w) d")
+
+ freqs = freqs.contiguous()
+ freqs_sin = freqs.sin()
+ freqs_cos = freqs.cos()
+ self.register_buffer("freqs_sin", freqs_sin)
+ self.register_buffer("freqs_cos", freqs_cos)
+
+ self.text_length = text_length
+ if learnable_pos_embed:
+ num_patches = height * width * compressed_num_frames + text_length
+ self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
+ else:
+ self.pos_embedding = None
+
+ def rotary(self, t, **kwargs):
+ if self.pnp:
+ t_coords = kwargs["rope_position_ids"][:, :, 0]
+ x_coords = kwargs["rope_position_ids"][:, :, 1]
+ y_coords = kwargs["rope_position_ids"][:, :, 2]
+ mask = (x_coords != -1) & (y_coords != -1) & (t_coords != -1)
+ freqs = torch.zeros([t.shape[0], t.shape[2], t.shape[3]], dtype=t.dtype, device=t.device)
+ freqs[mask] = self.freqs[t_coords[mask], x_coords[mask], y_coords[mask]]
+
+ else:
+
+ def reshape_freq(freqs):
+ frame = t.shape[2]
+ freqs = freqs[:frame].contiguous()
+ freqs = freqs.unsqueeze(0).unsqueeze(0)
+ return freqs
+
+ freqs_cos = reshape_freq(self.freqs_cos)
+ freqs_sin = reshape_freq(self.freqs_sin)
+
+ return t * freqs_cos + rotate_half(t) * freqs_sin
+
+ def position_embedding_forward(self, position_ids, **kwargs):
+ if self.pos_embedding is not None:
+ return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
+ else:
+ return None
+
+ def attention_fn(
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ attention_dropout=None,
+ log_attention_weights=None,
+ scaling_attention_score=True,
+ **kwargs,
+ ):
+ attention_fn_default = HOOKS_DEFAULT["attention_fn"]
+
+ if self.pnp:
+ query_layer = self.rotary(query_layer, **kwargs)
+ key_layer = self.rotary(key_layer, **kwargs)
+ if self.rot_v:
+ value_layer = self.rotary(value_layer)
+ else:
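+            # rotary embedding is applied only to the vision tokens; the leading text tokens are left unrotated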
+ query_layer = torch.cat(
+ (
+ query_layer[
+ :,
+ :,
+ : kwargs["text_length"],
+ ],
+ self.rotary(
+ query_layer[
+ :,
+ :,
+ kwargs["text_length"] :,
+ ]
+ ),
+ ),
+ dim=2,
+ )
+ key_layer = torch.cat(
+ (
+ key_layer[
+ :,
+ :,
+ : kwargs["text_length"],
+ ],
+ self.rotary(
+ key_layer[
+ :,
+ :,
+ kwargs["text_length"] :,
+ ]
+ ),
+ ),
+ dim=2,
+ )
+ if self.rot_v:
+ value_layer = torch.cat(
+ (
+ value_layer[
+ :,
+ :,
+ : kwargs["text_length"],
+ ],
+ self.rotary(
+ value_layer[
+ :,
+ :,
+ kwargs["text_length"] :,
+ ]
+ ),
+ ),
+ dim=2,
+ )
+
+ return attention_fn_default(
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ attention_dropout=attention_dropout,
+ log_attention_weights=log_attention_weights,
+ scaling_attention_score=scaling_attention_score,
+ **kwargs,
+ )
+
+
+def modulate(x, shift, scale):
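+    # DiT-style adaLN: broadcast per-sample shift and scale over the sequence dimension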
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs):
+ """
+    x: (N, T * S, patch_size**2 * C)
+    imgs: (N, T, C, H, W)
+ """
+ if rope_position_ids is not None:
+ assert NotImplementedError
+ # do pix2struct unpatchify
+ L = x.shape[1]
+ x = x.reshape(shape=(x.shape[0], L, p, p, c))
+ x = torch.einsum("nlpqc->ncplq", x)
+ imgs = x.reshape(shape=(x.shape[0], c, p, L * p))
+ else:
+ b = x.shape[0]
+ imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p)
+
+ return imgs
+
+
+class FinalLayerMixin(BaseMixin):
+ def __init__(
+ self,
+ hidden_size,
+ time_embed_dim,
+ patch_size,
+ out_channels,
+ latent_width,
+ latent_height,
+ elementwise_affine,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.patch_size = patch_size
+ self.out_channels = out_channels
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))
+
+ self.spatial_length = latent_width * latent_height // patch_size**2
+ self.latent_width = latent_width
+ self.latent_height = latent_height
+
+ def final_forward(self, logits, **kwargs):
+ x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d)
+
+ shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+
+ return unpatchify(
+ x,
+ c=self.out_channels,
+ p=self.patch_size,
+ w=self.latent_width // self.patch_size,
+ h=self.latent_height // self.patch_size,
+ rope_position_ids=kwargs.get("rope_position_ids", None),
+ **kwargs,
+ )
+
+ def reinit(self, parent_model=None):
+ nn.init.xavier_uniform_(self.linear.weight)
+ nn.init.constant_(self.linear.bias, 0)
+
+
+class SwiGLUMixin(BaseMixin):
+ def __init__(self, num_layers, in_features, hidden_features, bias=False):
+ super().__init__()
+ self.w2 = nn.ModuleList(
+ [
+ ColumnParallelLinear(
+ in_features,
+ hidden_features,
+ gather_output=False,
+ bias=bias,
+ module=self,
+ name="dense_h_to_4h_gate",
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ def mlp_forward(self, hidden_states, **kw_args):
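+        # SwiGLU MLP: the gate branch (w2) goes through the activation and multiplies the base dense_h_to_4h branch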
+ x = hidden_states
+ origin = self.transformer.layers[kw_args["layer_id"]].mlp
+ x1 = origin.dense_h_to_4h(x)
+ x2 = self.w2[kw_args["layer_id"]](x)
+ hidden = origin.activation_func(x2) * x1
+ x = origin.dense_4h_to_h(hidden)
+ return x
+
+
+class AdaLNMixin(BaseMixin):
+ def __init__(
+ self,
+ width,
+ height,
+ hidden_size,
+ num_layers,
+ time_embed_dim,
+ compressed_num_frames,
+ qk_ln=True,
+ hidden_size_head=None,
+ elementwise_affine=True,
+ ):
+ super().__init__()
+ self.num_layers = num_layers
+ self.width = width
+ self.height = height
+ self.compressed_num_frames = compressed_num_frames
+
+ self.adaLN_modulations = nn.ModuleList(
+ [nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
+ )
+
+ self.qk_ln = qk_ln
+ if qk_ln:
+ self.query_layernorm_list = nn.ModuleList(
+ [
+ LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
+ for _ in range(num_layers)
+ ]
+ )
+ self.key_layernorm_list = nn.ModuleList(
+ [
+ LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
+ for _ in range(num_layers)
+ ]
+ )
+
+ def layer_forward(
+ self,
+ hidden_states,
+ mask,
+ *args,
+ **kwargs,
+ ):
+ text_length = kwargs["text_length"]
+ # hidden_states (b,(n_t+t*n_i),d)
+ text_hidden_states = hidden_states[:, :text_length] # (b,n,d)
+ img_hidden_states = hidden_states[:, text_length:] # (b,(t n),d)
+ layer = self.transformer.layers[kwargs["layer_id"]]
+ adaLN_modulation = self.adaLN_modulations[kwargs["layer_id"]]
+
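+        # 12 modulation vectors: shift / scale / gate for attention and MLP, separately for the vision and text streams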
+ (
+ shift_msa,
+ scale_msa,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ text_shift_msa,
+ text_scale_msa,
+ text_gate_msa,
+ text_shift_mlp,
+ text_scale_mlp,
+ text_gate_mlp,
+ ) = adaLN_modulation(kwargs["emb"]).chunk(12, dim=1)
+ gate_msa, gate_mlp, text_gate_msa, text_gate_mlp = (
+ gate_msa.unsqueeze(1),
+ gate_mlp.unsqueeze(1),
+ text_gate_msa.unsqueeze(1),
+ text_gate_mlp.unsqueeze(1),
+ )
+
+ # self full attention (b,(t n),d)
+ img_attention_input = layer.input_layernorm(img_hidden_states)
+ text_attention_input = layer.input_layernorm(text_hidden_states)
+ img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
+ text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)
+
+ attention_input = torch.cat((text_attention_input, img_attention_input), dim=1) # (b,n_t+t*n_i,d)
+ attention_output = layer.attention(attention_input, mask, **kwargs)
+ text_attention_output = attention_output[:, :text_length] # (b,n,d)
+ img_attention_output = attention_output[:, text_length:] # (b,(t n),d)
+
+ if self.transformer.layernorm_order == "sandwich":
+ text_attention_output = layer.third_layernorm(text_attention_output)
+ img_attention_output = layer.third_layernorm(img_attention_output)
+ img_hidden_states = img_hidden_states + gate_msa * img_attention_output # (b,(t n),d)
+ text_hidden_states = text_hidden_states + text_gate_msa * text_attention_output # (b,n,d)
+
+ # mlp (b,(t n),d)
+ img_mlp_input = layer.post_attention_layernorm(img_hidden_states) # vision (b,(t n),d)
+ text_mlp_input = layer.post_attention_layernorm(text_hidden_states) # language (b,n,d)
+ img_mlp_input = modulate(img_mlp_input, shift_mlp, scale_mlp)
+ text_mlp_input = modulate(text_mlp_input, text_shift_mlp, text_scale_mlp)
+ mlp_input = torch.cat((text_mlp_input, img_mlp_input), dim=1) # (b,(n_t+t*n_i),d
+ mlp_output = layer.mlp(mlp_input, **kwargs)
+ img_mlp_output = mlp_output[:, text_length:] # vision (b,(t n),d)
+ text_mlp_output = mlp_output[:, :text_length] # language (b,n,d)
+ if self.transformer.layernorm_order == "sandwich":
+ text_mlp_output = layer.fourth_layernorm(text_mlp_output)
+ img_mlp_output = layer.fourth_layernorm(img_mlp_output)
+
+ img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d)
+ text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d)
+
+ hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d)
+ return hidden_states
+
+ def reinit(self, parent_model=None):
+ for layer in self.adaLN_modulations:
+ nn.init.constant_(layer[-1].weight, 0)
+ nn.init.constant_(layer[-1].bias, 0)
+
+ @non_conflict
+ def attention_fn(
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ attention_dropout=None,
+ log_attention_weights=None,
+ scaling_attention_score=True,
+ old_impl=attention_fn_default,
+ **kwargs,
+ ):
+ if self.qk_ln:
+ query_layernorm = self.query_layernorm_list[kwargs["layer_id"]]
+ key_layernorm = self.key_layernorm_list[kwargs["layer_id"]]
+ query_layer = query_layernorm(query_layer)
+ key_layer = key_layernorm(key_layer)
+
+ return old_impl(
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ attention_dropout=attention_dropout,
+ log_attention_weights=log_attention_weights,
+ scaling_attention_score=scaling_attention_score,
+ **kwargs,
+ )
+
+
+str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
+
+
+class DiffusionTransformer(BaseModel):
+ def __init__(
+ self,
+ transformer_args,
+ num_frames,
+ time_compressed_rate,
+ latent_width,
+ latent_height,
+ patch_size,
+ in_channels,
+ out_channels,
+ hidden_size,
+ num_layers,
+ num_attention_heads,
+ elementwise_affine,
+ time_embed_dim=None,
+ num_classes=None,
+ modules={},
+ input_time="adaln",
+ adm_in_channels=None,
+ parallel_output=True,
+ height_interpolation=1.0,
+ width_interpolation=1.0,
+ time_interpolation=1.0,
+ use_SwiGLU=False,
+ use_RMSNorm=False,
+ zero_init_y_embed=False,
+ **kwargs,
+ ):
+ self.latent_width = latent_width
+ self.latent_height = latent_height
+ self.patch_size = patch_size
+ self.num_frames = num_frames
+ self.time_compressed_rate = time_compressed_rate
+ self.spatial_length = latent_width * latent_height // patch_size**2
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.hidden_size = hidden_size
+ self.model_channels = hidden_size
+ self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
+ self.num_classes = num_classes
+ self.adm_in_channels = adm_in_channels
+ self.input_time = input_time
+ self.num_layers = num_layers
+ self.num_attention_heads = num_attention_heads
+ self.is_decoder = transformer_args.is_decoder
+ self.elementwise_affine = elementwise_affine
+ self.height_interpolation = height_interpolation
+ self.width_interpolation = width_interpolation
+ self.time_interpolation = time_interpolation
+ self.inner_hidden_size = hidden_size * 4
+ self.zero_init_y_embed = zero_init_y_embed
+ try:
+ self.dtype = str_to_dtype[kwargs.pop("dtype")]
+        except KeyError:
+ self.dtype = torch.float32
+
+ if use_SwiGLU:
+ kwargs["activation_func"] = F.silu
+ elif "activation_func" not in kwargs:
+ approx_gelu = nn.GELU(approximate="tanh")
+ kwargs["activation_func"] = approx_gelu
+
+ if use_RMSNorm:
+ kwargs["layernorm"] = RMSNorm
+ else:
+ kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6)
+
+ transformer_args.num_layers = num_layers
+ transformer_args.hidden_size = hidden_size
+ transformer_args.num_attention_heads = num_attention_heads
+ transformer_args.parallel_output = parallel_output
+ super().__init__(args=transformer_args, transformer=None, **kwargs)
+
+ module_configs = modules
+ self._build_modules(module_configs)
+
+ if use_SwiGLU:
+ self.add_mixin(
+ "swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True
+ )
+
+ def _build_modules(self, module_configs):
+ model_channels = self.hidden_size
+ # time_embed_dim = model_channels * 4
+ time_embed_dim = self.time_embed_dim
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "timestep":
+ self.label_emb = nn.Sequential(
+ Timestep(model_channels),
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ ),
+ )
+ elif self.num_classes == "sequential":
+ assert self.adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ linear(self.adm_in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+ if self.zero_init_y_embed:
+ nn.init.constant_(self.label_emb[0][2].weight, 0)
+ nn.init.constant_(self.label_emb[0][2].bias, 0)
+ else:
+ raise ValueError()
+
+ pos_embed_config = module_configs["pos_embed_config"]
+ self.add_mixin(
+ "pos_embed",
+ instantiate_from_config(
+ pos_embed_config,
+ height=self.latent_height // self.patch_size,
+ width=self.latent_width // self.patch_size,
+ compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
+ hidden_size=self.hidden_size,
+ ),
+ reinit=True,
+ )
+
+ patch_embed_config = module_configs["patch_embed_config"]
+ self.add_mixin(
+ "patch_embed",
+ instantiate_from_config(
+ patch_embed_config,
+ patch_size=self.patch_size,
+ hidden_size=self.hidden_size,
+ in_channels=self.in_channels,
+ ),
+ reinit=True,
+ )
+ if self.input_time == "adaln":
+ adaln_layer_config = module_configs["adaln_layer_config"]
+ self.add_mixin(
+ "adaln_layer",
+ instantiate_from_config(
+ adaln_layer_config,
+ height=self.latent_height // self.patch_size,
+ width=self.latent_width // self.patch_size,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
+ hidden_size_head=self.hidden_size // self.num_attention_heads,
+ time_embed_dim=self.time_embed_dim,
+ elementwise_affine=self.elementwise_affine,
+ ),
+ )
+ else:
+ raise NotImplementedError
+
+ final_layer_config = module_configs["final_layer_config"]
+ self.add_mixin(
+ "final_layer",
+ instantiate_from_config(
+ final_layer_config,
+ hidden_size=self.hidden_size,
+ patch_size=self.patch_size,
+ out_channels=self.out_channels,
+ time_embed_dim=self.time_embed_dim,
+ latent_width=self.latent_width,
+ latent_height=self.latent_height,
+ elementwise_affine=self.elementwise_affine,
+ ),
+ reinit=True,
+ )
+
+ if "lora_config" in module_configs:
+ lora_config = module_configs["lora_config"]
+ self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
+
+ return
+
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ b, t, d, h, w = x.shape
+ if x.dtype != self.dtype:
+ x = x.to(self.dtype)
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ # assert y.shape[0] == x.shape[0]
+ assert x.shape[0] % y.shape[0] == 0
+ y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
+ emb = emb + self.label_emb(y)
+
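+        # one vision token per spatial patch per (temporally compressed) latent frame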
+ kwargs["seq_length"] = t * h * w // (self.patch_size**2)
+ kwargs["images"] = x
+ kwargs["emb"] = emb
+ kwargs["encoder_outputs"] = context
+ kwargs["text_length"] = context.shape[1]
+
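+        # placeholder ids / mask to satisfy the SAT BaseModel interface; the embedding mixins build the real inputs from kwargs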
+ kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
+ output = super().forward(**kwargs)[0]
+
+ return output
diff --git a/sat/finetune.sh b/sat/finetune.sh
new file mode 100644
index 0000000000000000000000000000000000000000..da3124786045974f6943513b6906a2fcd2c3869b
--- /dev/null
+++ b/sat/finetune.sh
@@ -0,0 +1,12 @@
+#! /bin/bash
+
+echo "RUN on `hostname`, CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+
+environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
+
+run_cmd="$environs python train_video.py --base configs/cogvideox_2b_sft.yaml --seed $RANDOM"
+
+echo ${run_cmd}
+eval ${run_cmd}
+
+echo "DONE on `hostname`"
\ No newline at end of file
diff --git a/sat/inference.sh b/sat/inference.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8b446eee613d30ddaff45f0317af80caf2981944
--- /dev/null
+++ b/sat/inference.sh
@@ -0,0 +1,12 @@
+#! /bin/bash
+
+echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+
+environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
+
+run_cmd="$environs python sample_video.py --base configs/cogvideox_2b_infer.yaml"
+
+echo ${run_cmd}
+eval ${run_cmd}
+
+echo "DONE on `hostname`"
\ No newline at end of file
diff --git a/sat/requirements.txt b/sat/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e209a7ffa17c2690e66daed88e1c490166774cc5
--- /dev/null
+++ b/sat/requirements.txt
@@ -0,0 +1,17 @@
+git+https://github.com/spacegoing/SwissArmyTransformer.git
+diffusers>=0.29.2
+omegaconf>=2.3.0
+torch>=2.3.1
+torchvision>=0.19.0
+pytorch_lightning>=2.3.3
+kornia>=0.7.3
+beartype>=0.18.5
+numpy>=2.0.1
+fsspec>=2024.5.0
+safetensors>=0.4.3
+imageio-ffmpeg>=0.5.1
+imageio>=2.34.2
+scipy>=1.14.0
+decord>=0.6.0
+wandb>=0.17.5
+deepspeed>=0.14.4
diff --git a/sat/sample_video.py b/sat/sample_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad1940cc5237e39ca80b941a10302ef348bad811
--- /dev/null
+++ b/sat/sample_video.py
@@ -0,0 +1,236 @@
+import os
+import math
+import argparse
+from typing import List, Union
+from tqdm import tqdm
+from omegaconf import ListConfig
+import imageio
+
+import torch
+import numpy as np
+from einops import rearrange
+import torchvision.transforms as TT
+
+from sat.model.base_model import get_model
+from sat.training.model_io import load_checkpoint
+from sat import mpu
+
+from diffusion_video import SATVideoDiffusionEngine
+from arguments import get_args
+from torchvision.transforms.functional import center_crop, resize
+from torchvision.transforms import InterpolationMode
+
+
+def read_from_cli():
+ cnt = 0
+ try:
+ while True:
+ x = input("Please input English text (Ctrl-D quit): ")
+ yield x.strip(), cnt
+ cnt += 1
+ except EOFError as e:
+ pass
+
+
+def read_from_file(p, rank=0, world_size=1):
+ with open(p, "r") as fin:
+ cnt = -1
+ for l in fin:
+ cnt += 1
+ if cnt % world_size != rank:
+ continue
+ yield l.strip(), cnt
+
+
+def get_unique_embedder_keys_from_conditioner(conditioner):
+ return list(set([x.input_key for x in conditioner.embedders]))
+
+
+def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
+ batch = {}
+ batch_uc = {}
+
+ for key in keys:
+ if key == "txt":
+ batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
+ batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
+ else:
+ batch[key] = value_dict[key]
+
+ if T is not None:
+ batch["num_video_frames"] = T
+
+ for key in batch.keys():
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
+ batch_uc[key] = torch.clone(batch[key])
+ return batch, batch_uc
+
+
+def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
+ os.makedirs(save_path, exist_ok=True)
+
+ for i, vid in enumerate(video_batch):
+ gif_frames = []
+ for frame in vid:
+ frame = rearrange(frame, "c h w -> h w c")
+ frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
+ gif_frames.append(frame)
+ now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
+ with imageio.get_writer(now_save_path, fps=fps) as writer:
+ for frame in gif_frames:
+ writer.append_data(frame)
+
+
+def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
+ if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+ arr = resize(
+ arr,
+ size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+ interpolation=InterpolationMode.BICUBIC,
+ )
+ else:
+ arr = resize(
+ arr,
+ size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+ interpolation=InterpolationMode.BICUBIC,
+ )
+
+ h, w = arr.shape[2], arr.shape[3]
+ arr = arr.squeeze(0)
+
+ delta_h = h - image_size[0]
+ delta_w = w - image_size[1]
+
+ if reshape_mode == "random" or reshape_mode == "none":
+ top = np.random.randint(0, delta_h + 1)
+ left = np.random.randint(0, delta_w + 1)
+ elif reshape_mode == "center":
+ top, left = delta_h // 2, delta_w // 2
+ else:
+ raise NotImplementedError
+ arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+ return arr
+
+
+def sampling_main(args, model_cls):
+ if isinstance(model_cls, type):
+ model = get_model(args, model_cls)
+ else:
+ model = model_cls
+
+ load_checkpoint(model, args)
+ model.eval()
+
+    rank, world_size = 0, 1  # defaults so later logging also works for CLI input
+    if args.input_type == "cli":
+        data_iter = read_from_cli()
+    elif args.input_type == "txt":
+        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
+        print("rank and world_size", rank, world_size)
+        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
+ else:
+ raise NotImplementedError
+
+ image_size = [480, 720]
+
+ sample_func = model.sample
+ T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
+ num_samples = [1]
+ force_uc_zero_embeddings = ["txt"]
+ device = model.device
+ with torch.no_grad():
+ for text, cnt in tqdm(data_iter):
+ # reload model on GPU
+ model.to(device)
+ print("rank:", rank, "start to process", text, cnt)
+ # TODO: broadcast image2video
+ value_dict = {
+ "prompt": text,
+ "negative_prompt": "",
+ "num_frames": torch.tensor(T).unsqueeze(0),
+ }
+
+ batch, batch_uc = get_batch(
+ get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
+ )
+ for key in batch:
+ if isinstance(batch[key], torch.Tensor):
+ print(key, batch[key].shape)
+ elif isinstance(batch[key], list):
+ print(key, [len(l) for l in batch[key]])
+ else:
+ print(key, batch[key])
+ c, uc = model.conditioner.get_unconditional_conditioning(
+ batch,
+ batch_uc=batch_uc,
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
+ )
+
+ for k in c:
+                if k != "crossattn":
+ c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
+ for index in range(args.batch_size):
+ # reload model on GPU
+ model.to(device)
+ samples_z = sample_func(
+ c,
+ uc=uc,
+ batch_size=1,
+ shape=(T, C, H // F, W // F),
+ )
+ samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
+
+ # Unload the model from GPU to save GPU memory
+ model.to("cpu")
+ torch.cuda.empty_cache()
+ first_stage_model = model.first_stage_model
+ first_stage_model = first_stage_model.to(device)
+
+ latent = 1.0 / model.scale_factor * samples_z
+
+                # Decode the latents serially to save GPU memory
+ recons = []
+ loop_num = (T - 1) // 2
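+                # the first chunk decodes 3 latent frames, later chunks 2, keeping peak decoder memory bounded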
+ for i in range(loop_num):
+ if i == 0:
+ start_frame, end_frame = 0, 3
+ else:
+ start_frame, end_frame = i * 2 + 1, i * 2 + 3
+ if i == loop_num - 1:
+ clear_fake_cp_cache = True
+ else:
+ clear_fake_cp_cache = False
+ with torch.no_grad():
+ recon = first_stage_model.decode(
+ latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
+ )
+
+ recons.append(recon)
+
+ recon = torch.cat(recons, dim=2).to(torch.float32)
+ samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
+
+ save_path = os.path.join(
+ args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
+ )
+ if mpu.get_model_parallel_rank() == 0:
+ save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
+
+
+if __name__ == "__main__":
+ if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
+ os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
+ os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
+ os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
+ py_parser = argparse.ArgumentParser(add_help=False)
+ known, args_list = py_parser.parse_known_args()
+
+ args = get_args(args_list)
+ args = argparse.Namespace(**vars(args), **vars(known))
+ del args.deepspeed_config
+ args.model_config.first_stage_config.params.cp_size = 1
+ args.model_config.network_config.params.transformer_args.model_parallel_size = 1
+ args.model_config.network_config.params.transformer_args.checkpoint_activations = False
+ args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
+
+ sampling_main(args, model_cls=SATVideoDiffusionEngine)
diff --git a/sat/sgm/__init__.py b/sat/sgm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c4482364f4f054d67e5ec9ef57862976a2c6aa7
--- /dev/null
+++ b/sat/sgm/__init__.py
@@ -0,0 +1,4 @@
+from .models import AutoencodingEngine
+from .util import get_configs_path, instantiate_from_config
+
+__version__ = "0.1.0"
diff --git a/sat/sgm/lr_scheduler.py b/sat/sgm/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45db6983b731819de0eea23723bf83ea141f685
--- /dev/null
+++ b/sat/sgm/lr_scheduler.py
@@ -0,0 +1,110 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+
+ def __init__(
+ self,
+ warm_up_steps,
+ lr_min,
+ lr_max,
+ lr_start,
+ max_decay_steps,
+ verbosity_interval=0,
+ ):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.0
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n, **kwargs):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
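+        # linear warmup from lr_start to lr_max, then cosine decay from lr_max to lr_min over max_decay_steps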
+ if n < self.lr_warm_up_steps:
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaWarmUpCosineScheduler2:
+ """
+ supports repeated iterations, configurable via lists
+ note: use with a base_lr of 1.0.
+ """
+
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
+ self.lr_warm_up_steps = warm_up_steps
+ self.f_start = f_start
+ self.f_min = f_min
+ self.f_max = f_max
+ self.cycle_lengths = cycle_lengths
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+ self.last_f = 0.0
+ self.verbosity_interval = verbosity_interval
+
+ def find_in_interval(self, n):
+ interval = 0
+ for cl in self.cum_cycles[1:]:
+ if n <= cl:
+ return interval
+ interval += 1
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
+ t = min(t, 1.0)
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
+ self.last_f = f
+ return f
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
+
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ f = (
+ self.f_min[cycle]
+ + (self.f_max[cycle] - self.f_min[cycle])
+ * (self.cycle_lengths[cycle] - n)
+ / (self.cycle_lengths[cycle])
+ )
+ self.last_f = f
+ return f
diff --git a/sat/sgm/models/__init__.py b/sat/sgm/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e72b865963a7d7dfbe526f6b7aba63c5aa00a1e4
--- /dev/null
+++ b/sat/sgm/models/__init__.py
@@ -0,0 +1 @@
+from .autoencoder import AutoencodingEngine
diff --git a/sat/sgm/models/autoencoder.py b/sat/sgm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ae44d055b4bbcf3543988bb7c74f7697e84ecfb
--- /dev/null
+++ b/sat/sgm/models/autoencoder.py
@@ -0,0 +1,630 @@
+import logging
+import math
+import re
+import random
+from abc import abstractmethod
+from contextlib import contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch.distributed
+import torch.nn as nn
+from einops import rearrange
+from packaging import version
+
+from ..modules.autoencoding.regularizers import AbstractRegularizer
+from ..modules.ema import LitEma
+from ..util import (
+ default,
+ get_nested_attribute,
+ get_obj_from_str,
+ instantiate_from_config,
+ initialize_context_parallel,
+ get_context_parallel_group,
+ get_context_parallel_group_rank,
+ is_context_parallel_initialized,
+)
+from ..modules.cp_enc_dec import _conv_split, _conv_gather
+
+logpy = logging.getLogger(__name__)
+
+
+class AbstractAutoencoder(pl.LightningModule):
+ """
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
+ unCLIP models, etc. Hence, it is fairly general, and specific features
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
+ """
+
+ def __init__(
+ self,
+ ema_decay: Union[None, float] = None,
+ monitor: Union[None, str] = None,
+ input_key: str = "jpg",
+ ):
+ super().__init__()
+
+ self.input_key = input_key
+ self.use_ema = ema_decay is not None
+ if monitor is not None:
+ self.monitor = monitor
+
+ if self.use_ema:
+ self.model_ema = LitEma(self, decay=ema_decay)
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ self.automatic_optimization = False
+
+ def apply_ckpt(self, ckpt: Union[None, str, dict]):
+ if ckpt is None:
+ return
+ if isinstance(ckpt, str):
+ ckpt = {
+ "target": "sgm.modules.checkpoint.CheckpointEngine",
+ "params": {"ckpt_path": ckpt},
+ }
+ engine = instantiate_from_config(ckpt)
+ engine(self)
+
+ @abstractmethod
+ def get_input(self, batch) -> Any:
+ raise NotImplementedError()
+
+ def on_train_batch_end(self, *args, **kwargs):
+ # for EMA computation
+ if self.use_ema:
+ self.model_ema(self)
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ logpy.info(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ logpy.info(f"{context}: Restored training weights")
+
+ @abstractmethod
+ def encode(self, *args, **kwargs) -> torch.Tensor:
+ raise NotImplementedError("encode()-method of abstract base class called")
+
+ @abstractmethod
+ def decode(self, *args, **kwargs) -> torch.Tensor:
+ raise NotImplementedError("decode()-method of abstract base class called")
+
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
+ return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
+
+ def configure_optimizers(self) -> Any:
+ raise NotImplementedError()
+
+
+class AutoencodingEngine(AbstractAutoencoder):
+ """
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
+ (we also restore them explicitly as special cases for legacy reasons).
+ Regularizations such as KL or VQ are moved to the regularizer class.
+ """
+
+ def __init__(
+ self,
+ *args,
+ encoder_config: Dict,
+ decoder_config: Dict,
+ loss_config: Dict,
+ regularizer_config: Dict,
+ optimizer_config: Union[Dict, None] = None,
+ lr_g_factor: float = 1.0,
+ trainable_ae_params: Optional[List[List[str]]] = None,
+ ae_optimizer_args: Optional[List[dict]] = None,
+ trainable_disc_params: Optional[List[List[str]]] = None,
+ disc_optimizer_args: Optional[List[dict]] = None,
+ disc_start_iter: int = 0,
+ diff_boost_factor: float = 3.0,
+ ckpt_engine: Union[None, str, dict] = None,
+ ckpt_path: Optional[str] = None,
+ additional_decode_keys: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.automatic_optimization = False # pytorch lightning
+
+ self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
+ self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
+ self.loss: torch.nn.Module = instantiate_from_config(loss_config)
+ self.regularization: AbstractRegularizer = instantiate_from_config(regularizer_config)
+ self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"})
+ self.diff_boost_factor = diff_boost_factor
+ self.disc_start_iter = disc_start_iter
+ self.lr_g_factor = lr_g_factor
+ self.trainable_ae_params = trainable_ae_params
+ if self.trainable_ae_params is not None:
+ self.ae_optimizer_args = default(
+ ae_optimizer_args,
+ [{} for _ in range(len(self.trainable_ae_params))],
+ )
+ assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
+ else:
+ self.ae_optimizer_args = [{}] # makes type consistent
+
+ self.trainable_disc_params = trainable_disc_params
+ if self.trainable_disc_params is not None:
+ self.disc_optimizer_args = default(
+ disc_optimizer_args,
+ [{} for _ in range(len(self.trainable_disc_params))],
+ )
+ assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
+ else:
+ self.disc_optimizer_args = [{}] # makes type consistent
+
+ if ckpt_path is not None:
+ assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
+ logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
+ self.additional_decode_keys = set(default(additional_decode_keys, []))
+
+ def get_input(self, batch: Dict) -> torch.Tensor:
+ # assuming unified data format, dataloader returns a dict.
+ # image tensors should be scaled to -1 ... 1 and in channels-first
+ # format (e.g., bchw instead of bhwc)
+ return batch[self.input_key]
+
+ def get_autoencoder_params(self) -> list:
+ params = []
+ if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
+ params += list(self.loss.get_trainable_autoencoder_parameters())
+ if hasattr(self.regularization, "get_trainable_parameters"):
+ params += list(self.regularization.get_trainable_parameters())
+ params = params + list(self.encoder.parameters())
+ params = params + list(self.decoder.parameters())
+ return params
+
+ def get_discriminator_params(self) -> list:
+ if hasattr(self.loss, "get_trainable_parameters"):
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
+ else:
+ params = []
+ return params
+
+ def get_last_layer(self):
+ return self.decoder.get_last_layer()
+
+ def encode(
+ self,
+ x: torch.Tensor,
+ return_reg_log: bool = False,
+ unregularized: bool = False,
+ **kwargs,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+ z = self.encoder(x, **kwargs)
+ if unregularized:
+ return z, dict()
+ z, reg_log = self.regularization(z)
+ if return_reg_log:
+ return z, reg_log
+ return z
+
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
+ x = self.decoder(z, **kwargs)
+ return x
+
+ def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+ z, reg_log = self.encode(x, return_reg_log=True)
+ dec = self.decode(z, **additional_decode_kwargs)
+ return z, dec, reg_log
+
+ def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
+ x = self.get_input(batch)
+ additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
+ z, xrec, regularization_log = self(x, **additional_decode_kwargs)
+ if hasattr(self.loss, "forward_keys"):
+ extra_info = {
+ "z": z,
+ "optimizer_idx": optimizer_idx,
+ "global_step": self.global_step,
+ "last_layer": self.get_last_layer(),
+ "split": "train",
+ "regularization_log": regularization_log,
+ "autoencoder": self,
+ }
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
+ else:
+ extra_info = dict()
+
+ if optimizer_idx == 0:
+ # autoencode
+ out_loss = self.loss(x, xrec, **extra_info)
+ if isinstance(out_loss, tuple):
+ aeloss, log_dict_ae = out_loss
+ else:
+ # simple loss function
+ aeloss = out_loss
+ log_dict_ae = {"train/loss/rec": aeloss.detach()}
+
+ self.log_dict(
+ log_dict_ae,
+ prog_bar=False,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ sync_dist=False,
+ )
+ self.log(
+ "loss",
+ aeloss.mean().detach(),
+ prog_bar=True,
+ logger=False,
+ on_epoch=False,
+ on_step=True,
+ )
+ return aeloss
+ elif optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
+ # -> discriminator always needs to return a tuple
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+ else:
+ raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
+
+ def training_step(self, batch: dict, batch_idx: int):
+ opts = self.optimizers()
+ if not isinstance(opts, list):
+ # Non-adversarial case
+ opts = [opts]
+ optimizer_idx = batch_idx % len(opts)
+ if self.global_step < self.disc_start_iter:
+ optimizer_idx = 0
+ opt = opts[optimizer_idx]
+ opt.zero_grad()
+ with opt.toggle_model():
+ loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx)
+ self.manual_backward(loss)
+ opt.step()
+
+ def validation_step(self, batch: dict, batch_idx: int) -> Dict:
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+ log_dict.update(log_dict_ema)
+ return log_dict
+
+ def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
+ x = self.get_input(batch)
+
+ z, xrec, regularization_log = self(x)
+ if hasattr(self.loss, "forward_keys"):
+ extra_info = {
+ "z": z,
+ "optimizer_idx": 0,
+ "global_step": self.global_step,
+ "last_layer": self.get_last_layer(),
+ "split": "val" + postfix,
+ "regularization_log": regularization_log,
+ "autoencoder": self,
+ }
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
+ else:
+ extra_info = dict()
+ out_loss = self.loss(x, xrec, **extra_info)
+ if isinstance(out_loss, tuple):
+ aeloss, log_dict_ae = out_loss
+ else:
+ # simple loss function
+ aeloss = out_loss
+ log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
+ full_log_dict = log_dict_ae
+
+ if "optimizer_idx" in extra_info:
+ extra_info["optimizer_idx"] = 1
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
+ full_log_dict.update(log_dict_disc)
+ self.log(
+ f"val{postfix}/loss/rec",
+ log_dict_ae[f"val{postfix}/loss/rec"],
+ sync_dist=True,
+ )
+ self.log_dict(full_log_dict, sync_dist=True)
+ return full_log_dict
+
+ def get_param_groups(
+ self, parameter_names: List[List[str]], optimizer_args: List[dict]
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ groups = []
+ num_params = 0
+ for names, args in zip(parameter_names, optimizer_args):
+ params = []
+ for pattern_ in names:
+ pattern_params = []
+ pattern = re.compile(pattern_)
+ for p_name, param in self.named_parameters():
+ if re.match(pattern, p_name):
+ pattern_params.append(param)
+ num_params += param.numel()
+ if len(pattern_params) == 0:
+ logpy.warn(f"Did not find parameters for pattern {pattern_}")
+ params.extend(pattern_params)
+ groups.append({"params": params, **args})
+ return groups, num_params
+
+ def configure_optimizers(self) -> List[torch.optim.Optimizer]:
+ if self.trainable_ae_params is None:
+ ae_params = self.get_autoencoder_params()
+ else:
+ ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
+ logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
+ if self.trainable_disc_params is None:
+ disc_params = self.get_discriminator_params()
+ else:
+ disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
+ logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
+ opt_ae = self.instantiate_optimizer_from_config(
+ ae_params,
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
+ self.optimizer_config,
+ )
+ opts = [opt_ae]
+ if len(disc_params) > 0:
+ opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
+ opts.append(opt_disc)
+
+ return opts
+
+ @torch.no_grad()
+ def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
+ log = dict()
+ additional_decode_kwargs = {}
+ x = self.get_input(batch)
+ additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
+
+ _, xrec, _ = self(x, **additional_decode_kwargs)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
+ diff.clamp_(0, 1.0)
+ log["diff"] = 2.0 * diff - 1.0
+ # diff_boost shows location of small errors, by boosting their
+ # brightness.
+ log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
+ if hasattr(self.loss, "log_images"):
+ log.update(self.loss.log_images(x, xrec))
+ with self.ema_scope():
+ _, xrec_ema, _ = self(x, **additional_decode_kwargs)
+ log["reconstructions_ema"] = xrec_ema
+ diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
+ diff_ema.clamp_(0, 1.0)
+ log["diff_ema"] = 2.0 * diff_ema - 1.0
+ log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
+ if additional_log_kwargs:
+ additional_decode_kwargs.update(additional_log_kwargs)
+ _, xrec_add, _ = self(x, **additional_decode_kwargs)
+ log_str = "reconstructions-" + "-".join(
+ [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
+ )
+ log[log_str] = xrec_add
+ return log
+
+
+class AutoencodingEngineLegacy(AutoencodingEngine):
+ def __init__(self, embed_dim: int, **kwargs):
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
+ ddconfig = kwargs.pop("ddconfig")
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ckpt_engine = kwargs.pop("ckpt_engine", None)
+ super().__init__(
+ encoder_config={
+ "target": "sgm.modules.diffusionmodules.model.Encoder",
+ "params": ddconfig,
+ },
+ decoder_config={
+ "target": "sgm.modules.diffusionmodules.model.Decoder",
+ "params": ddconfig,
+ },
+ **kwargs,
+ )
+ self.quant_conv = torch.nn.Conv2d(
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
+ (1 + ddconfig["double_z"]) * embed_dim,
+ 1,
+ )
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
+
+ def get_autoencoder_params(self) -> list:
+ params = super().get_autoencoder_params()
+ return params
+
+ def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+ if self.max_batch_size is None:
+ z = self.encoder(x)
+ z = self.quant_conv(z)
+ else:
+ N = x.shape[0]
+ bs = self.max_batch_size
+ n_batches = int(math.ceil(N / bs))
+ z = list()
+ for i_batch in range(n_batches):
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
+ z_batch = self.quant_conv(z_batch)
+ z.append(z_batch)
+ z = torch.cat(z, 0)
+
+ z, reg_log = self.regularization(z)
+ if return_reg_log:
+ return z, reg_log
+ return z
+
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
+ if self.max_batch_size is None:
+ dec = self.post_quant_conv(z)
+ dec = self.decoder(dec, **decoder_kwargs)
+ else:
+ N = z.shape[0]
+ bs = self.max_batch_size
+ n_batches = int(math.ceil(N / bs))
+ dec = list()
+ for i_batch in range(n_batches):
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
+ dec.append(dec_batch)
+ dec = torch.cat(dec, 0)
+
+ return dec
+
+
+class IdentityFirstStage(AbstractAutoencoder):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def get_input(self, x: Any) -> Any:
+ return x
+
+ def encode(self, x: Any, *args, **kwargs) -> Any:
+ return x
+
+ def decode(self, x: Any, *args, **kwargs) -> Any:
+ return x
+
+
+class VideoAutoencodingEngine(AutoencodingEngine):
+ def __init__(
+ self,
+ ckpt_path: Union[None, str] = None,
+ ignore_keys: Union[Tuple, list] = (),
+ image_video_weights=[1, 1],
+ only_train_decoder=False,
+ context_parallel_size=0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.context_parallel_size = context_parallel_size
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
+ return self.log_images(batch, additional_log_kwargs, **kwargs)
+
+ def get_input(self, batch: dict) -> torch.Tensor:
+ if self.context_parallel_size > 0:
+ if not is_context_parallel_initialized():
+ initialize_context_parallel(self.context_parallel_size)
+
+ batch = batch[self.input_key]
+
+ global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
+ torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
+
+ batch = _conv_split(batch, dim=2, kernel_size=1)
+ return batch
+
+ return batch[self.input_key]
+
+ def apply_ckpt(self, ckpt: Union[None, str, dict]):
+ if ckpt is None:
+ return
+ self.init_from_ckpt(ckpt)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ del sd[k]
+ missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
+ print("Missing keys: ", missing_keys)
+ print("Unexpected keys: ", unexpected_keys)
+ print(f"Restored from {path}")
+
+
+class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
+ def __init__(
+ self,
+ cp_size=0,
+ *args,
+ **kwargs,
+ ):
+ self.cp_size = cp_size
+ super().__init__(*args, **kwargs)
+
+ def encode(
+ self,
+ x: torch.Tensor,
+ return_reg_log: bool = False,
+ unregularized: bool = False,
+ input_cp: bool = False,
+ output_cp: bool = False,
+ use_cp: bool = True,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+ if self.cp_size <= 1:
+ use_cp = False
+ if self.cp_size > 0 and use_cp and not input_cp:
+ if not is_context_parallel_initialized():
+ initialize_context_parallel(self.cp_size)
+
+ global_src_rank = get_context_parallel_group_rank() * self.cp_size
+ torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
+
+ x = _conv_split(x, dim=2, kernel_size=1)
+
+ if return_reg_log:
+ z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
+ else:
+ z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
+
+ if self.cp_size > 0 and use_cp and not output_cp:
+ z = _conv_gather(z, dim=2, kernel_size=1)
+
+ if return_reg_log:
+ return z, reg_log
+ return z
+
+ def decode(
+ self,
+ z: torch.Tensor,
+ input_cp: bool = False,
+ output_cp: bool = False,
+ use_cp: bool = True,
+ **kwargs,
+ ):
+ if self.cp_size <= 1:
+ use_cp = False
+ if self.cp_size > 0 and use_cp and not input_cp:
+ if not is_context_parallel_initialized():
+ initialize_context_parallel(self.cp_size)
+
+ global_src_rank = get_context_parallel_group_rank() * self.cp_size
+ torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
+
+ z = _conv_split(z, dim=2, kernel_size=1)
+
+ x = super().decode(z, use_cp=use_cp, **kwargs)
+
+ if self.cp_size > 0 and use_cp and not output_cp:
+ x = _conv_gather(x, dim=2, kernel_size=1)
+
+ return x
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ input_cp: bool = False,
+ latent_cp: bool = False,
+ output_cp: bool = False,
+ **additional_decode_kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+ z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
+ dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
+ return z, dec, reg_log
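For reference, the regex-driven parameter grouping in `get_param_groups` above can be exercised on its own. The following is a minimal, standalone sketch of that mechanism; the tiny `nn.Sequential` stand-in, the patterns, and the learning rates are hypothetical and only illustrate how `trainable_ae_params` / `ae_optimizer_args` map onto optimizer parameter groups.

```python
import re

import torch
from torch import nn

# Stand-in for an encoder/decoder pair; real patterns come from trainable_ae_params.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))

parameter_names = [[r"0\..*"], [r"1\.weight"]]   # one list of regex patterns per group
optimizer_args = [{"lr": 1e-4}, {"lr": 1e-5}]    # per-group optimizer kwargs

groups, num_params = [], 0
for names, args in zip(parameter_names, optimizer_args):
    params = []
    for pattern_ in names:
        pattern = re.compile(pattern_)
        for p_name, param in model.named_parameters():
            if re.match(pattern, p_name):
                params.append(param)
                num_params += param.numel()
    groups.append({"params": params, **args})

opt = torch.optim.Adam(groups)  # one parameter group per pattern list
print(num_params, [len(g["params"]) for g in groups])
```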
diff --git a/sat/sgm/modules/__init__.py b/sat/sgm/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0db1d7716a6e48f77b86a4b59c9289d6fb76b50b
--- /dev/null
+++ b/sat/sgm/modules/__init__.py
@@ -0,0 +1,6 @@
+from .encoders.modules import GeneralConditioner
+
+UNCONDITIONAL_CONFIG = {
+ "target": "sgm.modules.GeneralConditioner",
+ "params": {"emb_models": []},
+}
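The `UNCONDITIONAL_CONFIG` above follows the repository's usual `target`/`params` convention and is resolved with `instantiate_from_config` from `sgm.util`. A hedged usage sketch (it assumes the `sat/` package and its dependencies are importable; it is not a definitive API reference):

```python
from sgm.util import instantiate_from_config

config = {
    "target": "sgm.modules.GeneralConditioner",
    "params": {"emb_models": []},  # no conditioning embedders -> unconditional
}
conditioner = instantiate_from_config(config)  # roughly GeneralConditioner(emb_models=[])
print(type(conditioner).__name__)
```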
diff --git a/sat/sgm/modules/attention.py b/sat/sgm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..1304eaae75617db70692b1cde4219f86c9898a9e
--- /dev/null
+++ b/sat/sgm/modules/attention.py
@@ -0,0 +1,572 @@
+import math
+from inspect import isfunction
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from packaging import version
+from torch import nn
+
+if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ SDP_IS_AVAILABLE = True
+ from torch.backends.cuda import SDPBackend, sdp_kernel
+
+ BACKEND_MAP = {
+ SDPBackend.MATH: {
+ "enable_math": True,
+ "enable_flash": False,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.FLASH_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": True,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.EFFICIENT_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": False,
+ "enable_mem_efficient": True,
+ },
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
+ }
+else:
+ from contextlib import nullcontext
+
+ SDP_IS_AVAILABLE = False
+ sdp_kernel = nullcontext
+ BACKEND_MAP = {}
+ print(
+ f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
+ f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
+ )
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILABLE = True
+except ImportError:
+ XFORMERS_IS_AVAILABLE = False
+ print("No module 'xformers'. Processing without it...")
+
+from .diffusionmodules.util import checkpoint
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b (h w) c")
+ k = rearrange(k, "b c h w -> b c (h w)")
+ w_ = torch.einsum("bij,bjk->bik", q, k)
+
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, "b c h w -> b c (h w)")
+ w_ = rearrange(w_, "b i j -> b j i")
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ backend=None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.backend = backend
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ ):
+ h = self.heads
+
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
+ k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
+ v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+ ## old
+ """
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ del q, k
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', sim, v)
+ """
+ ## new
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
+
+ del q, k, v
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
+
+ if additional_tokens is not None:
+ # remove additional token
+ out = out[:, n_tokens_to_mask:]
+ return self.to_out(out)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
+ super().__init__()
+ print(
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ f"{heads} heads with a dimension of {dim_head}."
+ )
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.attention_op: Optional[Any] = None
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ ):
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
+ k = repeat(
+ k[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+ v = repeat(
+ v[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ # TODO: Use this directly in the attention operation, as a bias
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ if additional_tokens is not None:
+ # remove additional token
+ out = out[:, n_tokens_to_mask:]
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ disable_self_attn=False,
+ attn_mode="softmax",
+ sdp_backend=None,
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
+ print(
+ f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
+ f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
+ )
+ attn_mode = "softmax"
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
+ print("We do not support vanilla attention anymore, as it is too expensive. Sorry.")
+ if not XFORMERS_IS_AVAILABLE:
+ assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+ else:
+ print("Falling back to xformers efficient attention.")
+ attn_mode = "softmax-xformers"
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
+ else:
+ assert sdp_backend is None
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None,
+ backend=sdp_backend,
+ ) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ backend=sdp_backend,
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+ if self.checkpoint:
+ print(f"{self.__class__.__name__} is using checkpointing")
+
+ def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
+ kwargs = {"x": x}
+
+ if context is not None:
+ kwargs.update({"context": context})
+
+ if additional_tokens is not None:
+ kwargs.update({"additional_tokens": additional_tokens})
+
+ if n_times_crossframe_attn_in_self:
+ kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self})
+
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
+ x = (
+ self.attn1(
+ self.norm1(x),
+ context=context if self.disable_self_attn else None,
+ additional_tokens=additional_tokens,
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
+ )
+ + x
+ )
+ x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class BasicTransformerSingleLayerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ attn_mode="softmax",
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ )
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context) + x
+ x = self.ff(self.norm2(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape back to an image.
+ NEW: use_linear for more efficiency instead of the 1x1 convs.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None,
+ disable_self_attn=False,
+ use_linear=False,
+ attn_type="softmax",
+ use_checkpoint=True,
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
+ sdp_backend=None,
+ ):
+ super().__init__()
+ print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
+ from omegaconf import ListConfig
+
+ if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
+ context_dim = [context_dim]
+ if exists(context_dim) and isinstance(context_dim, list):
+ if depth != len(context_dim):
+ print(
+ f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+ f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
+ )
+ # depth does not match context dims.
+ assert all(
+ map(lambda x: x == context_dim[0], context_dim)
+ ), "need homogenous context_dim to match depth automatically"
+ context_dim = depth * [context_dim[0]]
+ elif context_dim is None:
+ context_dim = [None] * depth
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn,
+ attn_mode=attn_type,
+ checkpoint=use_checkpoint,
+ sdp_backend=sdp_backend,
+ )
+ for d in range(depth)
+ ]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+ else:
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ if i > 0 and len(context) == 1:
+ i = 0 # use same context for each block
+ x = block(x, context=context[i])
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
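The "new" path in `CrossAttention.forward` reduces to splitting heads, calling `F.scaled_dot_product_attention`, and merging heads again. A self-contained shape sketch of that layout (dimensions are illustrative; requires PyTorch >= 2.0 and einops, both already used by this module):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

b, n, h, d = 2, 16, 8, 64                 # batch, tokens, heads, head dim
q = torch.randn(b, n, h * d)
k = torch.randn(b, n, h * d)
v = torch.randn(b, n, h * d)

# (b, n, h*d) -> (b, h, n, d), as in CrossAttention
q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v))

out = F.scaled_dot_product_attention(q, k, v)   # softmax(q k^T / sqrt(d)) v
out = rearrange(out, "b h n d -> b n (h d)")    # merge heads back
print(out.shape)  # torch.Size([2, 16, 512])
```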
diff --git a/sat/sgm/modules/autoencoding/__init__.py b/sat/sgm/modules/autoencoding/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sat/sgm/modules/autoencoding/losses/__init__.py b/sat/sgm/modules/autoencoding/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3bb81d91cd91637bef2e04f8b9dcda5af4c7c2a
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/losses/__init__.py
@@ -0,0 +1,8 @@
+__all__ = [
+ "GeneralLPIPSWithDiscriminator",
+ "LatentLPIPS",
+]
+
+from .discriminator_loss import GeneralLPIPSWithDiscriminator
+from .lpips import LatentLPIPS
+from .video_loss import VideoAutoencoderLoss
diff --git a/sat/sgm/modules/autoencoding/losses/discriminator_loss.py b/sat/sgm/modules/autoencoding/losses/discriminator_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5b144a043c7eccfaaef56adc5d2d7896a1849ae
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/losses/discriminator_loss.py
@@ -0,0 +1,301 @@
+from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision
+from einops import rearrange
+from matplotlib import colormaps
+from matplotlib import pyplot as plt
+
+from ....util import default, instantiate_from_config
+from ..lpips.loss.lpips import LPIPS
+from ..lpips.model.model import weights_init
+from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
+
+
+class GeneralLPIPSWithDiscriminator(nn.Module):
+ def __init__(
+ self,
+ disc_start: int,
+ logvar_init: float = 0.0,
+ disc_num_layers: int = 3,
+ disc_in_channels: int = 3,
+ disc_factor: float = 1.0,
+ disc_weight: float = 1.0,
+ perceptual_weight: float = 1.0,
+ disc_loss: str = "hinge",
+ scale_input_to_tgt_size: bool = False,
+ dims: int = 2,
+ learn_logvar: bool = False,
+ regularization_weights: Union[None, Dict[str, float]] = None,
+ additional_log_keys: Optional[List[str]] = None,
+ discriminator_config: Optional[Dict] = None,
+ ):
+ super().__init__()
+ self.dims = dims
+ if self.dims > 2:
+ print(
+ f"running with dims={dims}. This means that for perceptual loss "
+ f"calculation, the LPIPS loss will be applied to each frame "
+ f"independently."
+ )
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
+ assert disc_loss in ["hinge", "vanilla"]
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ # output log variance
+ self.logvar = nn.Parameter(torch.full((), logvar_init), requires_grad=learn_logvar)
+ self.learn_logvar = learn_logvar
+
+ discriminator_config = default(
+ discriminator_config,
+ {
+ "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
+ "params": {
+ "input_nc": disc_in_channels,
+ "n_layers": disc_num_layers,
+ "use_actnorm": False,
+ },
+ },
+ )
+
+ self.discriminator = instantiate_from_config(discriminator_config).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.regularization_weights = default(regularization_weights, {})
+
+ self.forward_keys = [
+ "optimizer_idx",
+ "global_step",
+ "last_layer",
+ "split",
+ "regularization_log",
+ ]
+
+ self.additional_log_keys = set(default(additional_log_keys, []))
+ self.additional_log_keys.update(set(self.regularization_weights.keys()))
+
+ def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
+ return self.discriminator.parameters()
+
+ def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
+ if self.learn_logvar:
+ yield self.logvar
+ yield from ()
+
+ @torch.no_grad()
+ def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]:
+ # calc logits of real/fake
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ if len(logits_real.shape) < 4:
+ # Non patch-discriminator
+ return dict()
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ # -> (b, 1, h, w)
+
+ # parameters for colormapping
+ high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
+ cmap = colormaps["PiYG"] # diverging colormap
+
+ def to_colormap(logits: torch.Tensor) -> torch.Tensor:
+ """(b, 1, ...) -> (b, 3, ...)"""
+ logits = (logits + high) / (2 * high)
+ logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
+ # -> (b, 1, ..., 3)
+ logits = torch.from_numpy(logits_np).to(logits.device)
+ return rearrange(logits, "b 1 ... c -> b c ...")
+
+ logits_real = torch.nn.functional.interpolate(
+ logits_real,
+ size=inputs.shape[-2:],
+ mode="nearest",
+ antialias=False,
+ )
+ logits_fake = torch.nn.functional.interpolate(
+ logits_fake,
+ size=reconstructions.shape[-2:],
+ mode="nearest",
+ antialias=False,
+ )
+
+ # alpha value of logits for overlay
+ alpha_real = torch.abs(logits_real) / high
+ alpha_fake = torch.abs(logits_fake) / high
+ # -> (b, 1, h, w) in range [0, 0.5]
+ # the alpha values of the lines don't really matter, since they are the same
+ # for both images and logits anyway
+ grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
+ grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
+ grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
+ # -> (1, h, w)
+ # blend logits and images together
+
+ # prepare logits for plotting
+ logits_real = to_colormap(logits_real)
+ logits_fake = to_colormap(logits_fake)
+ # resize logits
+ # -> (b, 3, h, w)
+
+ # make some grids
+ # add all logits to one plot
+ logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
+ logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
+ # I just love how torchvision calls the number of columns `nrow`
+ grid_logits = torch.cat((logits_real, logits_fake), dim=1)
+ # -> (3, h, w)
+
+ grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
+ grid_images_fake = torchvision.utils.make_grid(0.5 * reconstructions + 0.5, nrow=4)
+ grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
+ # -> (3, h, w) in range [0, 1]
+
+ grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
+
+ # Create labeled colorbar
+ dpi = 100
+ height = 128 / dpi
+ width = grid_logits.shape[2] / dpi
+ fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
+ img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
+ plt.colorbar(
+ img,
+ cax=ax,
+ orientation="horizontal",
+ fraction=0.9,
+ aspect=width / height,
+ pad=0.0,
+ )
+ img.set_visible(False)
+ fig.tight_layout()
+ fig.canvas.draw()
+ # manually convert figure to numpy
+ cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
+ cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
+
+ # Add colorbar to plot
+ annotated_grid = torch.cat((grid_logits, cbar), dim=1)
+ blended_grid = torch.cat((grid_blend, cbar), dim=1)
+ return {
+ "vis_logits": 2 * annotated_grid[None, ...] - 1,
+ "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
+ }
+
+ def calculate_adaptive_weight(
+ self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
+ ) -> torch.Tensor:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(
+ self,
+ inputs: torch.Tensor,
+ reconstructions: torch.Tensor,
+ *, # added because I changed the order here
+ regularization_log: Dict[str, torch.Tensor],
+ optimizer_idx: int,
+ global_step: int,
+ last_layer: torch.Tensor,
+ split: str = "train",
+ weights: Union[None, float, torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, dict]:
+ if self.scale_input_to_tgt_size:
+ inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode="bicubic", antialias=True)
+
+ if self.dims > 2:
+ inputs, reconstructions = map(
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
+ (inputs, reconstructions),
+ )
+
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ p_loss = torch.tensor(0.0, device=inputs.device) # keeps the percep log below defined when the perceptual term is disabled
+ if self.perceptual_weight > 0:
+ frame_indices = torch.randn((inputs.shape[0], inputs.shape[2])).topk(1, dim=-1).indices
+
+ from sgm.modules.autoencoding.losses.video_loss import pick_video_frame
+
+ input_frames = pick_video_frame(inputs, frame_indices)
+ recon_frames = pick_video_frame(reconstructions, frame_indices)
+
+ p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean()
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+ nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if global_step >= self.discriminator_iter_start or not self.training:
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ g_loss = -torch.mean(logits_fake)
+ if self.training:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ else:
+ d_weight = torch.tensor(1.0)
+ else:
+ d_weight = torch.tensor(0.0)
+ g_loss = torch.tensor(0.0, requires_grad=True)
+
+ loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
+ log = dict()
+ for k in regularization_log:
+ if k in self.regularization_weights:
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
+ if k in self.additional_log_keys:
+ log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
+
+ log.update(
+ {
+ f"{split}/loss/total": loss.clone().detach().mean(),
+ f"{split}/loss/nll": nll_loss.detach().mean(),
+ f"{split}/loss/rec": rec_loss.detach().mean(),
+ f"{split}/loss/percep": p_loss.detach().mean(),
+ f"{split}/loss/rec": rec_loss.detach().mean(),
+ f"{split}/loss/g": g_loss.detach().mean(),
+ f"{split}/scalars/logvar": self.logvar.detach(),
+ f"{split}/scalars/d_weight": d_weight.detach(),
+ }
+ )
+
+ return loss, log
+ elif optimizer_idx == 1:
+ # second pass for discriminator update
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+
+ if global_step >= self.discriminator_iter_start or not self.training:
+ d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
+ else:
+ d_loss = torch.tensor(0.0, requires_grad=True)
+
+ log = {
+ f"{split}/loss/disc": d_loss.clone().detach().mean(),
+ f"{split}/logits/real": logits_real.detach().mean(),
+ f"{split}/logits/fake": logits_fake.detach().mean(),
+ }
+ return d_loss, log
+ else:
+ raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
+
+ def get_nll_loss(
+ self,
+ rec_loss: torch.Tensor,
+ weights: Optional[Union[float, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights * nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+
+ return nll_loss, weighted_nll_loss
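The adaptive discriminator weight used in `calculate_adaptive_weight` balances the GAN gradient against the reconstruction gradient at the decoder's last layer. A standalone sketch under toy assumptions (the one-layer "decoder" and both losses below are hypothetical stand-ins, not the losses defined above):

```python
import torch
from torch import nn

decoder = nn.Linear(4, 4)                  # stand-in for the decoder's last layer
x = torch.randn(8, 4)
rec = decoder(x)

nll_loss = (rec - x).abs().mean()          # stand-in reconstruction/NLL term
g_loss = -rec.mean()                       # stand-in for -E[D(reconstruction)]

last_layer = decoder.weight
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
print(float(d_weight))
```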
diff --git a/sat/sgm/modules/autoencoding/losses/lpips.py b/sat/sgm/modules/autoencoding/losses/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..fed01d64fbc696af237c267ed6b9cb4ed790ab70
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/losses/lpips.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+
+from ....util import default, instantiate_from_config
+from ..lpips.loss.lpips import LPIPS
+
+
+class LatentLPIPS(nn.Module):
+ def __init__(
+ self,
+ decoder_config,
+ perceptual_weight=1.0,
+ latent_weight=1.0,
+ scale_input_to_tgt_size=False,
+ scale_tgt_to_input_size=False,
+ perceptual_weight_on_inputs=0.0,
+ ):
+ super().__init__()
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
+ self.scale_tgt_to_input_size = scale_tgt_to_input_size
+ self.init_decoder(decoder_config)
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ self.latent_weight = latent_weight
+ self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
+
+ def init_decoder(self, config):
+ self.decoder = instantiate_from_config(config)
+ if hasattr(self.decoder, "encoder"):
+ del self.decoder.encoder
+
+ def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
+ log = dict()
+ loss = (latent_inputs - latent_predictions) ** 2
+ log[f"{split}/latent_l2_loss"] = loss.mean().detach()
+ image_reconstructions = None
+ if self.perceptual_weight > 0.0:
+ image_reconstructions = self.decoder.decode(latent_predictions)
+ image_targets = self.decoder.decode(latent_inputs)
+ perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous())
+ loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
+ log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
+
+ if self.perceptual_weight_on_inputs > 0.0:
+ image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions))
+ if self.scale_input_to_tgt_size:
+ image_inputs = torch.nn.functional.interpolate(
+ image_inputs,
+ image_reconstructions.shape[2:],
+ mode="bicubic",
+ antialias=True,
+ )
+ elif self.scale_tgt_to_input_size:
+ image_reconstructions = torch.nn.functional.interpolate(
+ image_reconstructions,
+ image_inputs.shape[2:],
+ mode="bicubic",
+ antialias=True,
+ )
+
+ perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous())
+ loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
+ log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
+ return loss, log
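`LatentLPIPS.forward` combines an L2 penalty in latent space with an optional LPIPS term on decoded images. A minimal numeric sketch of that weighting, using hypothetical stand-ins for the decoder and for `LPIPS()` so it stays self-contained:

```python
import torch

latent_weight, perceptual_weight = 1.0, 1.0

latent_inputs = torch.randn(2, 4, 8, 8)
latent_predictions = latent_inputs + 0.1 * torch.randn_like(latent_inputs)

latent_l2 = ((latent_inputs - latent_predictions) ** 2).mean()

def perceptual_fn(a, b):                       # hypothetical stand-in for LPIPS().eval()
    return (a - b).abs().mean()

image_targets = torch.tanh(latent_inputs)          # stand-in for decoder.decode(latent_inputs)
image_reconstructions = torch.tanh(latent_predictions)
perceptual = perceptual_fn(image_targets, image_reconstructions)

loss = latent_weight * latent_l2 + perceptual_weight * perceptual
print(float(loss))
```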
diff --git a/sat/sgm/modules/autoencoding/losses/video_loss.py b/sat/sgm/modules/autoencoding/losses/video_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..01c4c60b6275eef46042a0c471554bb32f7eff7e
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/losses/video_loss.py
@@ -0,0 +1,712 @@
+from typing import Any, Union
+from math import log2
+from beartype import beartype
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.autograd import grad as torch_grad
+from torch.cuda.amp import autocast
+
+import torchvision
+from torchvision.models import VGG16_Weights
+from einops import rearrange, einsum, repeat
+from einops.layers.torch import Rearrange
+from kornia.filters import filter3d
+
+from ..magvit2_pytorch import Residual, FeedForward, LinearSpaceAttention
+from .lpips import LPIPS
+
+from sgm.modules.autoencoding.vqvae.movq_enc_3d import CausalConv3d, DownSample3D
+from sgm.util import instantiate_from_config
+
+
+def exists(v):
+ return v is not None
+
+
+def pair(t):
+ return t if isinstance(t, tuple) else (t, t)
+
+
+def leaky_relu(p=0.1):
+ return nn.LeakyReLU(p)
+
+
+def hinge_discr_loss(fake, real):
+ return (F.relu(1 + fake) + F.relu(1 - real)).mean()
+
+
+def hinge_gen_loss(fake):
+ return -fake.mean()
+
+
+@autocast(enabled=False)
+@beartype
+def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
+ return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach()
+
+
+def pick_video_frame(video, frame_indices):
+ batch, device = video.shape[0], video.device
+ video = rearrange(video, "b c f ... -> b f c ...")
+ batch_indices = torch.arange(batch, device=device)
+ batch_indices = rearrange(batch_indices, "b -> b 1")
+ images = video[batch_indices, frame_indices]
+ images = rearrange(images, "b 1 c ... -> b c ...")
+ return images
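`pick_video_frame` is used by `discriminator_loss.py` above to sample one random frame per clip for the image-level LPIPS term. A hedged usage sketch (the import assumes the `sat/` package and its dependencies, e.g. kornia, are importable; tensor sizes are illustrative):

```python
import torch
from sgm.modules.autoencoding.losses.video_loss import pick_video_frame

video = torch.randn(2, 3, 8, 64, 64)     # (b, c, f, h, w)
# one random frame index per clip, as in GeneralLPIPSWithDiscriminator.forward
frame_indices = torch.randn(video.shape[0], video.shape[2]).topk(1, dim=-1).indices
frames = pick_video_frame(video, frame_indices)
print(frames.shape)                      # torch.Size([2, 3, 64, 64])
```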
+
+
+def gradient_penalty(images, output):
+ batch_size = images.shape[0]
+
+ gradients = torch_grad(
+ outputs=output,
+ inputs=images,
+ grad_outputs=torch.ones(output.size(), device=images.device),
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True,
+ )[0]
+
+ gradients = rearrange(gradients, "b ... -> b (...)")
+ return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+
+
+# discriminator with anti-aliased downsampling (blurpool Zhang et al.)
+
+
+class Blur(nn.Module):
+ def __init__(self):
+ super().__init__()
+ f = torch.Tensor([1, 2, 1])
+ self.register_buffer("f", f)
+
+ def forward(self, x, space_only=False, time_only=False):
+ assert not (space_only and time_only)
+
+ f = self.f
+
+ if space_only:
+ f = einsum("i, j -> i j", f, f)
+ f = rearrange(f, "... -> 1 1 ...")
+ elif time_only:
+ f = rearrange(f, "f -> 1 f 1 1")
+ else:
+ f = einsum("i, j, k -> i j k", f, f, f)
+ f = rearrange(f, "... -> 1 ...")
+
+ is_images = x.ndim == 4
+
+ if is_images:
+ x = rearrange(x, "b c h w -> b c 1 h w")
+
+ out = filter3d(x, f, normalized=True)
+
+ if is_images:
+ out = rearrange(out, "b c 1 h w -> b c h w")
+
+ return out
+
+
+class DiscriminatorBlock(nn.Module):
+ def __init__(self, input_channels, filters, downsample=True, antialiased_downsample=True):
+ super().__init__()
+ self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
+
+ self.net = nn.Sequential(
+ nn.Conv2d(input_channels, filters, 3, padding=1),
+ leaky_relu(),
+ nn.Conv2d(filters, filters, 3, padding=1),
+ leaky_relu(),
+ )
+
+ self.maybe_blur = Blur() if antialiased_downsample else None
+
+ self.downsample = (
+ nn.Sequential(
+ Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1)
+ )
+ if downsample
+ else None
+ )
+
+ def forward(self, x):
+ res = self.conv_res(x)
+
+ x = self.net(x)
+
+ if exists(self.downsample):
+ if exists(self.maybe_blur):
+ x = self.maybe_blur(x, space_only=True)
+
+ x = self.downsample(x)
+
+ x = (x + res) * (2**-0.5)
+ return x
+
+
+class Discriminator(nn.Module):
+ @beartype
+ def __init__(
+ self,
+ *,
+ dim,
+ image_size,
+ channels=3,
+ max_dim=512,
+ attn_heads=8,
+ attn_dim_head=32,
+ linear_attn_dim_head=8,
+ linear_attn_heads=16,
+ ff_mult=4,
+ antialiased_downsample=False,
+ ):
+ super().__init__()
+ image_size = pair(image_size)
+ min_image_resolution = min(image_size)
+
+ num_layers = int(log2(min_image_resolution) - 2)
+
+ blocks = []
+
+ layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
+ layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
+ layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
+
+ blocks = []
+ attn_blocks = []
+
+ image_resolution = min_image_resolution
+
+ for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
+ num_layer = ind + 1
+ is_not_last = ind != (len(layer_dims_in_out) - 1)
+
+ block = DiscriminatorBlock(
+ in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample
+ )
+
+ attn_block = nn.Sequential(
+ Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)),
+ Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
+ )
+
+ blocks.append(nn.ModuleList([block, attn_block]))
+
+ image_resolution //= 2
+
+ self.blocks = nn.ModuleList(blocks)
+
+ dim_last = layer_dims[-1]
+
+ downsample_factor = 2**num_layers
+ last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
+
+ latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
+
+ self.to_logits = nn.Sequential(
+ nn.Conv2d(dim_last, dim_last, 3, padding=1),
+ leaky_relu(),
+ Rearrange("b ... -> b (...)"),
+ nn.Linear(latent_dim, 1),
+ Rearrange("b 1 -> b"),
+ )
+
+ def forward(self, x):
+ for block, attn_block in self.blocks:
+ x = block(x)
+ x = attn_block(x)
+
+ return self.to_logits(x)
+
+
+class DiscriminatorBlock3D(nn.Module):
+ def __init__(
+ self,
+ input_channels,
+ filters,
+ antialiased_downsample=True,
+ ):
+ super().__init__()
+ self.conv_res = nn.Conv3d(input_channels, filters, 1, stride=2)
+
+ self.net = nn.Sequential(
+ nn.Conv3d(input_channels, filters, 3, padding=1),
+ leaky_relu(),
+ nn.Conv3d(filters, filters, 3, padding=1),
+ leaky_relu(),
+ )
+
+ self.maybe_blur = Blur() if antialiased_downsample else None
+
+ self.downsample = nn.Sequential(
+ Rearrange("b c (f p1) (h p2) (w p3) -> b (c p1 p2 p3) f h w", p1=2, p2=2, p3=2),
+ nn.Conv3d(filters * 8, filters, 1),
+ )
+
+ def forward(self, x):
+ res = self.conv_res(x)
+
+ x = self.net(x)
+
+ if exists(self.downsample):
+ if exists(self.maybe_blur):
+ x = self.maybe_blur(x, space_only=True)
+
+ x = self.downsample(x)
+
+ x = (x + res) * (2**-0.5)
+ return x
+
+
+class DiscriminatorBlock3DWithfirstframe(nn.Module):
+ def __init__(
+ self,
+ input_channels,
+ filters,
+ antialiased_downsample=True,
+ pad_mode="first",
+ ):
+ super().__init__()
+ self.downsample_res = DownSample3D(
+ in_channels=input_channels,
+ out_channels=filters,
+ with_conv=True,
+ compress_time=True,
+ )
+
+ self.net = nn.Sequential(
+ CausalConv3d(input_channels, filters, kernel_size=3, pad_mode=pad_mode),
+ leaky_relu(),
+ CausalConv3d(filters, filters, kernel_size=3, pad_mode=pad_mode),
+ leaky_relu(),
+ )
+
+ self.maybe_blur = Blur() if antialiased_downsample else None
+
+ self.downsample = DownSample3D(
+ in_channels=filters,
+ out_channels=filters,
+ with_conv=True,
+ compress_time=True,
+ )
+
+ def forward(self, x):
+ res = self.downsample_res(x)
+
+ x = self.net(x)
+
+ if exists(self.downsample):
+ if exists(self.maybe_blur):
+ x = self.maybe_blur(x, space_only=True)
+
+ x = self.downsample(x)
+
+ x = (x + res) * (2**-0.5)
+ return x
+
+
+class Discriminator3D(nn.Module):
+ @beartype
+ def __init__(
+ self,
+ *,
+ dim,
+ image_size,
+ frame_num,
+ channels=3,
+ max_dim=512,
+ linear_attn_dim_head=8,
+ linear_attn_heads=16,
+ ff_mult=4,
+ antialiased_downsample=False,
+ ):
+ super().__init__()
+ image_size = pair(image_size)
+ min_image_resolution = min(image_size)
+
+ num_layers = int(log2(min_image_resolution) - 2)
+ temporal_num_layers = int(log2(frame_num))
+ self.temporal_num_layers = temporal_num_layers
+
+ layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
+ layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
+ layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
+
+ blocks = []
+
+ image_resolution = min_image_resolution
+ frame_resolution = frame_num
+
+ for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
+ num_layer = ind + 1
+ is_not_last = ind != (len(layer_dims_in_out) - 1)
+
+ if ind < temporal_num_layers:
+ block = DiscriminatorBlock3D(
+ in_chan,
+ out_chan,
+ antialiased_downsample=antialiased_downsample,
+ )
+
+ blocks.append(block)
+
+ frame_resolution //= 2
+ else:
+ block = DiscriminatorBlock(
+ in_chan,
+ out_chan,
+ downsample=is_not_last,
+ antialiased_downsample=antialiased_downsample,
+ )
+ attn_block = nn.Sequential(
+ Residual(
+ LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
+ ),
+ Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
+ )
+
+ blocks.append(nn.ModuleList([block, attn_block]))
+
+ image_resolution //= 2
+
+ self.blocks = nn.ModuleList(blocks)
+
+ dim_last = layer_dims[-1]
+
+ downsample_factor = 2**num_layers
+ last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
+
+ latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
+
+ self.to_logits = nn.Sequential(
+ nn.Conv2d(dim_last, dim_last, 3, padding=1),
+ leaky_relu(),
+ Rearrange("b ... -> b (...)"),
+ nn.Linear(latent_dim, 1),
+ Rearrange("b 1 -> b"),
+ )
+
+ def forward(self, x):
+ for i, layer in enumerate(self.blocks):
+ if i < self.temporal_num_layers:
+ x = layer(x)
+ if i == self.temporal_num_layers - 1:
+ x = rearrange(x, "b c f h w -> (b f) c h w")
+ else:
+ block, attn_block = layer
+ x = block(x)
+ x = attn_block(x)
+
+ return self.to_logits(x)
+
+
+class Discriminator3DWithfirstframe(nn.Module):
+ @beartype
+ def __init__(
+ self,
+ *,
+ dim,
+ image_size,
+ frame_num,
+ channels=3,
+ max_dim=512,
+ linear_attn_dim_head=8,
+ linear_attn_heads=16,
+ ff_mult=4,
+ antialiased_downsample=False,
+ ):
+ super().__init__()
+ image_size = pair(image_size)
+ min_image_resolution = min(image_size)
+
+ num_layers = int(log2(min_image_resolution) - 2)
+ temporal_num_layers = int(log2(frame_num))
+ self.temporal_num_layers = temporal_num_layers
+
+ layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
+ layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
+ layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
+
+ blocks = []
+
+ image_resolution = min_image_resolution
+ frame_resolution = frame_num
+
+ for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
+ num_layer = ind + 1
+ is_not_last = ind != (len(layer_dims_in_out) - 1)
+
+ if ind < temporal_num_layers:
+ block = DiscriminatorBlock3DWithfirstframe(
+ in_chan,
+ out_chan,
+ antialiased_downsample=antialiased_downsample,
+ )
+
+ blocks.append(block)
+
+ frame_resolution //= 2
+ else:
+ block = DiscriminatorBlock(
+ in_chan,
+ out_chan,
+ downsample=is_not_last,
+ antialiased_downsample=antialiased_downsample,
+ )
+ attn_block = nn.Sequential(
+ Residual(
+ LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
+ ),
+ Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
+ )
+
+ blocks.append(nn.ModuleList([block, attn_block]))
+
+ image_resolution //= 2
+
+ self.blocks = nn.ModuleList(blocks)
+
+ dim_last = layer_dims[-1]
+
+ downsample_factor = 2**num_layers
+ last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
+
+ latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
+
+ self.to_logits = nn.Sequential(
+ nn.Conv2d(dim_last, dim_last, 3, padding=1),
+ leaky_relu(),
+ Rearrange("b ... -> b (...)"),
+ nn.Linear(latent_dim, 1),
+ Rearrange("b 1 -> b"),
+ )
+
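+    # same layout as Discriminator3D, but after the 3D stage the time axis is averaged out
+    # instead of folding frames into the batch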
+ def forward(self, x):
+ for i, layer in enumerate(self.blocks):
+ if i < self.temporal_num_layers:
+ x = layer(x)
+ if i == self.temporal_num_layers - 1:
+ x = x.mean(dim=2)
+ # x = rearrange(x, "b c f h w -> (b f) c h w")
+ else:
+ block, attn_block = layer
+ x = block(x)
+ x = attn_block(x)
+
+ return self.to_logits(x)
+
+
+class VideoAutoencoderLoss(nn.Module):
+ def __init__(
+ self,
+ disc_start,
+ perceptual_weight=1,
+ adversarial_loss_weight=0,
+ multiscale_adversarial_loss_weight=0,
+ grad_penalty_loss_weight=0,
+ quantizer_aux_loss_weight=0,
+ vgg_weights=VGG16_Weights.DEFAULT,
+ discr_kwargs=None,
+ discr_3d_kwargs=None,
+ ):
+ super().__init__()
+
+ self.disc_start = disc_start
+ self.perceptual_weight = perceptual_weight
+ self.adversarial_loss_weight = adversarial_loss_weight
+ self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight
+ self.grad_penalty_loss_weight = grad_penalty_loss_weight
+ self.quantizer_aux_loss_weight = quantizer_aux_loss_weight
+
+ if self.perceptual_weight > 0:
+ self.perceptual_model = LPIPS().eval()
+ # self.vgg = torchvision.models.vgg16(pretrained = True)
+ # self.vgg.requires_grad_(False)
+ # if self.adversarial_loss_weight > 0:
+ # self.discr = Discriminator(**discr_kwargs)
+ # else:
+ # self.discr = None
+ # if self.multiscale_adversarial_loss_weight > 0:
+ # self.multiscale_discrs = nn.ModuleList([*multiscale_discrs])
+ # else:
+ # self.multiscale_discrs = None
+ if discr_kwargs is not None:
+ self.discr = Discriminator(**discr_kwargs)
+ else:
+ self.discr = None
+ if discr_3d_kwargs is not None:
+ # self.discr_3d = Discriminator3D(**discr_3d_kwargs)
+ self.discr_3d = instantiate_from_config(discr_3d_kwargs)
+ else:
+ self.discr_3d = None
+ # self.multiscale_discrs = nn.ModuleList([*multiscale_discrs])
+
+ self.register_buffer("zero", torch.tensor(0.0), persistent=False)
+
+ def get_trainable_params(self) -> Any:
+ params = []
+ if self.discr is not None:
+ params += list(self.discr.parameters())
+ if self.discr_3d is not None:
+ params += list(self.discr_3d.parameters())
+ # if self.multiscale_discrs is not None:
+ # for discr in self.multiscale_discrs:
+ # params += list(discr.parameters())
+ return params
+
+ def get_trainable_parameters(self) -> Any:
+ return self.get_trainable_params()
+
+ def forward(
+ self,
+ inputs,
+ reconstructions,
+ optimizer_idx,
+ global_step,
+ aux_losses=None,
+ last_layer=None,
+ split="train",
+ ):
+ batch, channels, frames = inputs.shape[:3]
+
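+        # optimizer_idx == 0: autoencoder update (reconstruction + perceptual + optional adversarial terms)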
+ if optimizer_idx == 0:
+ recon_loss = F.mse_loss(inputs, reconstructions)
+
+ if self.perceptual_weight > 0:
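+                # topk over random noise picks one random frame index per sample for the LPIPS term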
+ frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices
+
+ input_frames = pick_video_frame(inputs, frame_indices)
+ recon_frames = pick_video_frame(reconstructions, frame_indices)
+
+ perceptual_loss = self.perceptual_model(input_frames.contiguous(), recon_frames.contiguous()).mean()
+ else:
+ perceptual_loss = self.zero
+
+            # the adversarial term is skipped before disc_start, outside training, or when its weight is zero
+            if global_step < self.disc_start or not self.training or self.adversarial_loss_weight == 0:
+ gen_loss = self.zero
+ adaptive_weight = 0
+ else:
+ # frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
+ # recon_video_frames = pick_video_frame(reconstructions, frame_indices)
+
+ # fake_logits = self.discr(recon_video_frames)
+ fake_logits = self.discr_3d(reconstructions)
+ gen_loss = hinge_gen_loss(fake_logits)
+
+ adaptive_weight = 1
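+                # VQGAN-style adaptive weight: ratio of the gradient norms of the perceptual and
+                # generator losses at the decoder's last layer, clamped for stability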
+ if self.perceptual_weight > 0 and last_layer is not None:
+ norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_layer).norm(p=2)
+ norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_layer).norm(p=2)
+ adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3)
+ adaptive_weight.clamp_(max=1e3)
+
+ if torch.isnan(adaptive_weight).any():
+ adaptive_weight = 1
+
+ # multiscale discriminator losses
+
+ # multiscale_gen_losses = []
+ # multiscale_gen_adaptive_weights = []
+ # if self.multiscale_adversarial_loss_weight > 0:
+ # if not exists(recon_video_frames):
+ # frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
+ # recon_video_frames = pick_video_frame(reconstructions, frame_indices)
+ # for discr in self.multiscale_discrs:
+ # fake_logits = recon_video_frames
+
+ # multiscale_gen_loss = hinge_gen_loss(fake_logits)
+ # multiscale_gen_losses.append(multiscale_gen_loss)
+
+ # multiscale_adaptive_weight = 1.
+
+ # if exists(norm_grad_wrt_perceptual_loss):
+ # norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_layer).norm(p = 2)
+ # multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min = 1e-5)
+ # multiscale_adaptive_weight.clamp_(max = 1e3)
+
+ # multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)
+ # weighted_multiscale_gen_losses = sum(loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights))
+ # else:
+ # weighted_multiscale_gen_losses = self.zero
+
+ if aux_losses is None:
+ aux_losses = self.zero
+
+ total_loss = (
+ recon_loss
+ + aux_losses * self.quantizer_aux_loss_weight
+ + perceptual_loss * self.perceptual_weight
+ + gen_loss * self.adversarial_loss_weight
+ )
+ # gen_loss * adaptive_weight * self.adversarial_loss_weight + \
+ # weighted_multiscale_gen_losses * self.multiscale_adversarial_loss_weight
+
+ log = {
+ "{}/total_loss".format(split): total_loss.detach(),
+ "{}/recon_loss".format(split): recon_loss.detach(),
+ "{}/perceptual_loss".format(split): perceptual_loss.detach(),
+ "{}/gen_loss".format(split): gen_loss.detach(),
+ "{}/aux_losses".format(split): aux_losses.detach(),
+ # "{}/weighted_multiscale_gen_losses".format(split): weighted_multiscale_gen_losses.detach(),
+ "{}/adaptive_weight".format(split): adaptive_weight,
+ # "{}/multiscale_adaptive_weights".format(split): sum(multiscale_gen_adaptive_weights),
+ }
+
+ return total_loss, log
+
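+        # optimizer_idx == 1: discriminator update (hinge loss on real inputs vs. detached
+        # reconstructions, plus an optional gradient penalty on the real logits)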
+ if optimizer_idx == 1:
+ # frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
+
+ # real = pick_video_frame(inputs, frame_indices)
+ # fake = pick_video_frame(reconstructions, frame_indices)
+
+ # apply_gradient_penalty = self.grad_penalty_loss_weight > 0
+ # if apply_gradient_penalty:
+ # real = real.requires_grad_()
+
+ # real_logits = self.discr(real)
+ # fake_logits = self.discr(fake.detach())
+
+ apply_gradient_penalty = self.grad_penalty_loss_weight > 0
+ if apply_gradient_penalty:
+ inputs = inputs.requires_grad_()
+ real_logits = self.discr_3d(inputs)
+ fake_logits = self.discr_3d(reconstructions.detach())
+
+ discr_loss = hinge_discr_loss(fake_logits, real_logits)
+
+ # # multiscale discriminators
+ # multiscale_discr_losses = []
+ # if self.multiscale_adversarial_loss_weight > 0:
+ # for discr in self.multiscale_discrs:
+ # multiscale_real_logits = discr(inputs)
+ # multiscale_fake_logits = discr(reconstructions.detach())
+
+ # multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits)
+ # multiscale_discr_losses.append(multiscale_discr_loss)
+ # else:
+ # multiscale_discr_losses.append(self.zero)
+
+ # gradient penalty
+ if apply_gradient_penalty:
+ # gradient_penalty_loss = gradient_penalty(real, real_logits)
+ gradient_penalty_loss = gradient_penalty(inputs, real_logits)
+ else:
+ gradient_penalty_loss = self.zero
+
+ total_loss = discr_loss + self.grad_penalty_loss_weight * gradient_penalty_loss
+ # self.grad_penalty_loss_weight * gradient_penalty_loss + \
+ # sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight
+
+ log = {
+ "{}/total_disc_loss".format(split): total_loss.detach(),
+ "{}/discr_loss".format(split): discr_loss.detach(),
+ "{}/grad_penalty_loss".format(split): gradient_penalty_loss.detach(),
+ # "{}/multiscale_discr_loss".format(split): sum(multiscale_discr_losses).detach(),
+ "{}/logits_real".format(split): real_logits.detach().mean(),
+ "{}/logits_fake".format(split): fake_logits.detach().mean(),
+ }
+ return total_loss, log
diff --git a/sat/sgm/modules/autoencoding/lpips/__init__.py b/sat/sgm/modules/autoencoding/lpips/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sat/sgm/modules/autoencoding/lpips/loss/.gitignore b/sat/sgm/modules/autoencoding/lpips/loss/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..a92958a1cd4ffe005e1f5448ab3e6fd9c795a43a
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/lpips/loss/.gitignore
@@ -0,0 +1 @@
+vgg.pth
\ No newline at end of file
diff --git a/sat/sgm/modules/autoencoding/lpips/loss/LICENSE b/sat/sgm/modules/autoencoding/lpips/loss/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..924cfc85b8d63ef538f5676f830a2a8497932108
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/lpips/loss/LICENSE
@@ -0,0 +1,23 @@
+Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/sat/sgm/modules/autoencoding/lpips/loss/__init__.py b/sat/sgm/modules/autoencoding/lpips/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sat/sgm/modules/autoencoding/lpips/loss/lpips.py b/sat/sgm/modules/autoencoding/lpips/loss/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0249cf74ca8b2c7fb51cb3b51ec61e02107a970
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/lpips/loss/lpips.py
@@ -0,0 +1,132 @@
+"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
+
+from collections import namedtuple
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from ..util import get_ckpt_path
+
+
+class LPIPS(nn.Module):
+ # Learned perceptual metric
+ def __init__(self, use_dropout=True):
+ super().__init__()
+ self.scaling_layer = ScalingLayer()
+        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
+ self.net = vgg16(pretrained=True, requires_grad=False)
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+ self.load_from_pretrained()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def load_from_pretrained(self, name="vgg_lpips"):
+ ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
+
+ @classmethod
+ def from_pretrained(cls, name="vgg_lpips"):
+ if name != "vgg_lpips":
+ raise NotImplementedError
+ model = cls()
+ ckpt = get_ckpt_path(name)
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ return model
+
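+    # LPIPS distance: unit-normalize the VGG features of both images per layer, square the
+    # differences, weight them with the learned 1x1 convs, average spatially and sum over layers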
+ def forward(self, input, target):
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
+ feats0, feats1, diffs = {}, {}, {}
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+ for kk in range(len(self.chns)):
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
+ val = res[0]
+ for l in range(1, len(self.chns)):
+ val += res[l]
+ return val
+
+
+class ScalingLayer(nn.Module):
+ def __init__(self):
+ super(ScalingLayer, self).__init__()
+ self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None])
+ self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None])
+
+ def forward(self, inp):
+ return (inp - self.shift) / self.scale
+
+
+class NetLinLayer(nn.Module):
+ """A single linear layer which does a 1x1 conv"""
+
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
+ super(NetLinLayer, self).__init__()
+ layers = (
+ [
+ nn.Dropout(),
+ ]
+ if (use_dropout)
+ else []
+ )
+ layers += [
+ nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
+ ]
+ self.model = nn.Sequential(*layers)
+
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(vgg16, self).__init__()
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+ return out
+
+
+def normalize_tensor(x, eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
+ return x / (norm_factor + eps)
+
+
+def spatial_average(x, keepdim=True):
+ return x.mean([2, 3], keepdim=keepdim)
diff --git a/sat/sgm/modules/autoencoding/lpips/model/LICENSE b/sat/sgm/modules/autoencoding/lpips/model/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..4b356e66b5aa689b339f1a80a9f1b5ba378003bb
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/lpips/model/LICENSE
@@ -0,0 +1,58 @@
+Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+--------------------------- LICENSE FOR pix2pix --------------------------------
+BSD License
+
+For pix2pix software
+Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+----------------------------- LICENSE FOR DCGAN --------------------------------
+BSD License
+
+For dcgan.torch software
+
+Copyright (c) 2015, Facebook, Inc. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/sat/sgm/modules/autoencoding/lpips/model/__init__.py b/sat/sgm/modules/autoencoding/lpips/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sat/sgm/modules/autoencoding/lpips/model/model.py b/sat/sgm/modules/autoencoding/lpips/model/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee13babd77864bb81456a7c9634ba7e9e597983f
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/lpips/model/model.py
@@ -0,0 +1,89 @@
+import functools
+
+import torch.nn as nn
+
+from ..util import ActNorm
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ try:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ except:
+        except AttributeError:  # fall back to modules that wrap the conv as m.conv
+ elif classname.find("BatchNorm") != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator as in Pix2Pix
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+ """
+
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+ """Construct a PatchGAN discriminator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm_layer -- normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if not use_actnorm:
+ norm_layer = nn.BatchNorm2d
+ else:
+ norm_layer = ActNorm
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func != nn.BatchNorm2d
+ else:
+ use_bias = norm_layer != nn.BatchNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
+ nn.LeakyReLU(0.2, True),
+ ]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n, 8)
+ sequence += [
+ nn.Conv2d(
+ ndf * nf_mult_prev,
+ ndf * nf_mult,
+ kernel_size=kw,
+ stride=2,
+ padding=padw,
+ bias=use_bias,
+ ),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n_layers, 8)
+ sequence += [
+ nn.Conv2d(
+ ndf * nf_mult_prev,
+ ndf * nf_mult,
+ kernel_size=kw,
+ stride=1,
+ padding=padw,
+ bias=use_bias,
+ ),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ sequence += [
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
+ ] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
diff --git a/sat/sgm/modules/autoencoding/lpips/util.py b/sat/sgm/modules/autoencoding/lpips/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb4a03624437b1a2498026a2669e57cb66409e6d
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/lpips/util.py
@@ -0,0 +1,114 @@
+import hashlib
+import os
+
+import requests
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
+
+CKPT_MAP = {"vgg_lpips": "vgg.pth"}
+
+MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
+
+
+def download(url, local_path, chunk_size=1024):
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+ with requests.get(url, stream=True) as r:
+ total_size = int(r.headers.get("content-length", 0))
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+ with open(local_path, "wb") as f:
+ for data in r.iter_content(chunk_size=chunk_size):
+ if data:
+ f.write(data)
+ pbar.update(chunk_size)
+
+
+def md5_hash(path):
+ with open(path, "rb") as f:
+ content = f.read()
+ return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+ assert name in URL_MAP
+ path = os.path.join(root, CKPT_MAP[name])
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+ download(URL_MAP[name], path)
+ md5 = md5_hash(path)
+ assert md5 == MD5_MAP[name], md5
+ return path
+
+
+class ActNorm(nn.Module):
+ def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
+ assert affine
+ super().__init__()
+ self.logdet = logdet
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+ self.allow_reverse_init = allow_reverse_init
+
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
+
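+    # data-dependent init: the first training batch sets loc/scale so that the per-channel
+    # output starts out roughly zero-mean and unit-variance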
+ def initialize(self, input):
+ with torch.no_grad():
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+ mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
+ std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
+
+ self.loc.data.copy_(-mean)
+ self.scale.data.copy_(1 / (std + 1e-6))
+
+ def forward(self, input, reverse=False):
+ if reverse:
+ return self.reverse(input)
+ if len(input.shape) == 2:
+ input = input[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ _, _, height, width = input.shape
+
+ if self.training and self.initialized.item() == 0:
+ self.initialize(input)
+ self.initialized.fill_(1)
+
+ h = self.scale * (input + self.loc)
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+
+ if self.logdet:
+ log_abs = torch.log(torch.abs(self.scale))
+ logdet = height * width * torch.sum(log_abs)
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
+ return h, logdet
+
+ return h
+
+ def reverse(self, output):
+ if self.training and self.initialized.item() == 0:
+ if not self.allow_reverse_init:
+ raise RuntimeError(
+ "Initializing ActNorm in reverse direction is "
+ "disabled by default. Use allow_reverse_init=True to enable."
+ )
+ else:
+ self.initialize(output)
+ self.initialized.fill_(1)
+
+ if len(output.shape) == 2:
+ output = output[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ h = output / self.scale - self.loc
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+ return h
diff --git a/sat/sgm/modules/autoencoding/lpips/vqperceptual.py b/sat/sgm/modules/autoencoding/lpips/vqperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e4944bd6c287fa0c74bf1c5f1cd8289a27c01b6
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/lpips/vqperceptual.py
@@ -0,0 +1,16 @@
+import torch
+import torch.nn.functional as F
+
+
+def hinge_d_loss(logits_real, logits_fake):
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def vanilla_d_loss(logits_real, logits_fake):
+ d_loss = 0.5 * (
+ torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake))
+ )
+ return d_loss
diff --git a/sat/sgm/modules/autoencoding/magvit2_pytorch.py b/sat/sgm/modules/autoencoding/magvit2_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..58889526b30e76fd2715d153940f0059a64e3fb4
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/magvit2_pytorch.py
@@ -0,0 +1,1762 @@
+import copy
+from pathlib import Path
+from math import log2, ceil, sqrt
+from functools import wraps, partial
+
+import torch
+import torch.nn.functional as F
+from torch.cuda.amp import autocast
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList
+from torch.autograd import grad as torch_grad
+
+import torchvision
+from torchvision.models import VGG16_Weights
+
+from collections import namedtuple
+
+# from vector_quantize_pytorch import LFQ, FSQ
+from .regularizers.finite_scalar_quantization import FSQ
+from .regularizers.lookup_free_quantization import LFQ
+
+from einops import rearrange, repeat, reduce, pack, unpack
+from einops.layers.torch import Rearrange
+
+from beartype import beartype
+from beartype.typing import Union, Tuple, Optional, List
+
+from magvit2_pytorch.attend import Attend
+from magvit2_pytorch.version import __version__
+
+from gateloop_transformer import SimpleGateLoopLayer
+
+from taylor_series_linear_attention import TaylorSeriesLinearAttn
+
+from kornia.filters import filter3d
+
+import pickle
+
+# helpers
+
+
+def exists(v):
+ return v is not None
+
+
+def default(v, d):
+ return v if exists(v) else d
+
+
+def safe_get_index(it, ind, default=None):
+ if ind < len(it):
+ return it[ind]
+ return default
+
+
+def pair(t):
+ return t if isinstance(t, tuple) else (t, t)
+
+
+def identity(t, *args, **kwargs):
+ return t
+
+
+def divisible_by(num, den):
+ return (num % den) == 0
+
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+
+def append_dims(t, ndims: int):
+ return t.reshape(*t.shape, *((1,) * ndims))
+
+
+def is_odd(n):
+ return not divisible_by(n, 2)
+
+
+def maybe_del_attr_(o, attr):
+ if hasattr(o, attr):
+ delattr(o, attr)
+
+
+def cast_tuple(t, length=1):
+ return t if isinstance(t, tuple) else ((t,) * length)
+
+
+# tensor helpers
+
+
+def l2norm(t):
+ return F.normalize(t, dim=-1, p=2)
+
+
+def pad_at_dim(t, pad, dim=-1, value=0.0):
+ dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
+ zeros = (0, 0) * dims_from_right
+ return F.pad(t, (*zeros, *pad), value=value)
+
+
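+# gathers one frame per batch element; frame_indices is expected to have shape (batch, 1)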
+def pick_video_frame(video, frame_indices):
+ batch, device = video.shape[0], video.device
+ video = rearrange(video, "b c f ... -> b f c ...")
+ batch_indices = torch.arange(batch, device=device)
+ batch_indices = rearrange(batch_indices, "b -> b 1")
+ images = video[batch_indices, frame_indices]
+ images = rearrange(images, "b 1 c ... -> b c ...")
+ return images
+
+
+# gan related
+
+
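+# WGAN-GP style penalty: pushes the norm of d(output)/d(images) toward 1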
+def gradient_penalty(images, output):
+ batch_size = images.shape[0]
+
+ gradients = torch_grad(
+ outputs=output,
+ inputs=images,
+ grad_outputs=torch.ones(output.size(), device=images.device),
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True,
+ )[0]
+
+ gradients = rearrange(gradients, "b ... -> b (...)")
+ return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+
+
+def leaky_relu(p=0.1):
+ return nn.LeakyReLU(p)
+
+
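+# hinge GAN losses: the discriminator is only penalized while real logits fall below +1 or fake
+# logits rise above -1; the generator simply maximizes the fake logits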
+def hinge_discr_loss(fake, real):
+ return (F.relu(1 + fake) + F.relu(1 - real)).mean()
+
+
+def hinge_gen_loss(fake):
+ return -fake.mean()
+
+
+@autocast(enabled=False)
+@beartype
+def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
+ return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach()
+
+
+# helper decorators
+
+
+def remove_vgg(fn):
+ @wraps(fn)
+ def inner(self, *args, **kwargs):
+ has_vgg = hasattr(self, "vgg")
+ if has_vgg:
+ vgg = self.vgg
+ delattr(self, "vgg")
+
+ out = fn(self, *args, **kwargs)
+
+ if has_vgg:
+ self.vgg = vgg
+
+ return out
+
+ return inner
+
+
+# helper classes
+
+
+def Sequential(*modules):
+ modules = [*filter(exists, modules)]
+
+ if len(modules) == 0:
+ return nn.Identity()
+
+ return nn.Sequential(*modules)
+
+
+class Residual(Module):
+ @beartype
+ def __init__(self, fn: Module):
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ return self.fn(x, **kwargs) + x
+
+
+# wraps a module so video tensors are reshaped to (batch x space, time, channels) for the inner fn, then restored
+
+
+class ToTimeSequence(Module):
+ @beartype
+ def __init__(self, fn: Module):
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ x = rearrange(x, "b c f ... -> b ... f c")
+ x, ps = pack_one(x, "* n c")
+
+ o = self.fn(x, **kwargs)
+
+ o = unpack_one(o, ps, "* n c")
+ return rearrange(o, "b ... f c -> b c f ...")
+
+
+class SqueezeExcite(Module):
+ # global context network - attention-esque squeeze-excite variant (https://arxiv.org/abs/2012.13375)
+
+ def __init__(self, dim, *, dim_out=None, dim_hidden_min=16, init_bias=-10):
+ super().__init__()
+ dim_out = default(dim_out, dim)
+
+ self.to_k = nn.Conv2d(dim, 1, 1)
+ dim_hidden = max(dim_hidden_min, dim_out // 2)
+
+ self.net = nn.Sequential(
+ nn.Conv2d(dim, dim_hidden, 1), nn.LeakyReLU(0.1), nn.Conv2d(dim_hidden, dim_out, 1), nn.Sigmoid()
+ )
+
+ nn.init.zeros_(self.net[-2].weight)
+ nn.init.constant_(self.net[-2].bias, init_bias)
+
+ def forward(self, x):
+ orig_input, batch = x, x.shape[0]
+ is_video = x.ndim == 5
+
+ if is_video:
+ x = rearrange(x, "b c f h w -> (b f) c h w")
+
+ context = self.to_k(x)
+
+ context = rearrange(context, "b c h w -> b c (h w)").softmax(dim=-1)
+ spatial_flattened_input = rearrange(x, "b c h w -> b c (h w)")
+
+ out = einsum("b i n, b c n -> b c i", context, spatial_flattened_input)
+ out = rearrange(out, "... -> ... 1")
+ gates = self.net(out)
+
+ if is_video:
+ gates = rearrange(gates, "(b f) c h w -> b c f h w", b=batch)
+
+ return gates * orig_input
+
+
+# token shifting
+
+
+class TokenShift(Module):
+ @beartype
+ def __init__(self, fn: Module):
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ x, x_shift = x.chunk(2, dim=1)
+ x_shift = pad_at_dim(x_shift, (1, -1), dim=2) # shift time dimension
+ x = torch.cat((x, x_shift), dim=1)
+ return self.fn(x, **kwargs)
+
+
+# rmsnorm
+
+
+class RMSNorm(Module):
+ def __init__(self, dim, channel_first=False, images=False, bias=False):
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+ def forward(self, x):
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+
+
+class AdaptiveRMSNorm(Module):
+ def __init__(self, dim, *, dim_cond, channel_first=False, images=False, bias=False):
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.dim_cond = dim_cond
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+
+ self.to_gamma = nn.Linear(dim_cond, dim)
+ self.to_bias = nn.Linear(dim_cond, dim) if bias else None
+
+ nn.init.zeros_(self.to_gamma.weight)
+ nn.init.ones_(self.to_gamma.bias)
+
+ if bias:
+ nn.init.zeros_(self.to_bias.weight)
+ nn.init.zeros_(self.to_bias.bias)
+
+ @beartype
+ def forward(self, x: Tensor, *, cond: Tensor):
+ batch = x.shape[0]
+ assert cond.shape == (batch, self.dim_cond)
+
+ gamma = self.to_gamma(cond)
+
+ bias = 0.0
+ if exists(self.to_bias):
+ bias = self.to_bias(cond)
+
+ if self.channel_first:
+ gamma = append_dims(gamma, x.ndim - 2)
+
+ if exists(self.to_bias):
+ bias = append_dims(bias, x.ndim - 2)
+
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * gamma + bias
+
+
+# attention
+
+
+class Attention(Module):
+ @beartype
+ def __init__(
+ self,
+ *,
+ dim,
+ dim_cond: Optional[int] = None,
+ causal=False,
+ dim_head=32,
+ heads=8,
+ flash=False,
+ dropout=0.0,
+ num_memory_kv=4,
+ ):
+ super().__init__()
+ dim_inner = dim_head * heads
+
+ self.need_cond = exists(dim_cond)
+
+ if self.need_cond:
+ self.norm = AdaptiveRMSNorm(dim, dim_cond=dim_cond)
+ else:
+ self.norm = RMSNorm(dim)
+
+ self.to_qkv = nn.Sequential(
+ nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads)
+ )
+
+ assert num_memory_kv > 0
+ self.mem_kv = nn.Parameter(torch.randn(2, heads, num_memory_kv, dim_head))
+
+ self.attend = Attend(causal=causal, dropout=dropout, flash=flash)
+
+ self.to_out = nn.Sequential(Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))
+
+ @beartype
+ def forward(self, x, mask: Optional[Tensor] = None, cond: Optional[Tensor] = None):
+ maybe_cond_kwargs = dict(cond=cond) if self.need_cond else dict()
+
+ x = self.norm(x, **maybe_cond_kwargs)
+
+ q, k, v = self.to_qkv(x)
+
+ mk, mv = map(lambda t: repeat(t, "h n d -> b h n d", b=q.shape[0]), self.mem_kv)
+ k = torch.cat((mk, k), dim=-2)
+ v = torch.cat((mv, v), dim=-2)
+
+ out = self.attend(q, k, v, mask=mask)
+ return self.to_out(out)
+
+
+class LinearAttention(Module):
+ """
+ using the specific linear attention proposed in https://arxiv.org/abs/2106.09681
+ """
+
+ @beartype
+ def __init__(self, *, dim, dim_cond: Optional[int] = None, dim_head=8, heads=8, dropout=0.0):
+ super().__init__()
+ dim_inner = dim_head * heads
+
+ self.need_cond = exists(dim_cond)
+
+ if self.need_cond:
+ self.norm = AdaptiveRMSNorm(dim, dim_cond=dim_cond)
+ else:
+ self.norm = RMSNorm(dim)
+
+ self.attn = TaylorSeriesLinearAttn(dim=dim, dim_head=dim_head, heads=heads)
+
+ def forward(self, x, cond: Optional[Tensor] = None):
+ maybe_cond_kwargs = dict(cond=cond) if self.need_cond else dict()
+
+ x = self.norm(x, **maybe_cond_kwargs)
+
+ return self.attn(x)
+
+
+class LinearSpaceAttention(LinearAttention):
+ def forward(self, x, *args, **kwargs):
+ x = rearrange(x, "b c ... h w -> b ... h w c")
+ x, batch_ps = pack_one(x, "* h w c")
+ x, seq_ps = pack_one(x, "b * c")
+
+ x = super().forward(x, *args, **kwargs)
+
+ x = unpack_one(x, seq_ps, "b * c")
+ x = unpack_one(x, batch_ps, "* h w c")
+ return rearrange(x, "b ... h w c -> b c ... h w")
+
+
+class SpaceAttention(Attention):
+ def forward(self, x, *args, **kwargs):
+ x = rearrange(x, "b c t h w -> b t h w c")
+ x, batch_ps = pack_one(x, "* h w c")
+ x, seq_ps = pack_one(x, "b * c")
+
+ x = super().forward(x, *args, **kwargs)
+
+ x = unpack_one(x, seq_ps, "b * c")
+ x = unpack_one(x, batch_ps, "* h w c")
+ return rearrange(x, "b t h w c -> b c t h w")
+
+
+class TimeAttention(Attention):
+ def forward(self, x, *args, **kwargs):
+ x = rearrange(x, "b c t h w -> b h w t c")
+ x, batch_ps = pack_one(x, "* t c")
+
+ x = super().forward(x, *args, **kwargs)
+
+ x = unpack_one(x, batch_ps, "* t c")
+ return rearrange(x, "b h w t c -> b c t h w")
+
+
+class GEGLU(Module):
+ def forward(self, x):
+ x, gate = x.chunk(2, dim=1)
+ return F.gelu(gate) * x
+
+
+class FeedForward(Module):
+ @beartype
+ def __init__(self, dim, *, dim_cond: Optional[int] = None, mult=4, images=False):
+ super().__init__()
+ conv_klass = nn.Conv2d if images else nn.Conv3d
+
+ rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond)
+
+ maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first=True, images=images)
+
+ dim_inner = int(dim * mult * 2 / 3)
+
+ self.norm = maybe_adaptive_norm_klass(dim)
+
+ self.net = Sequential(conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1))
+
+ @beartype
+ def forward(self, x: Tensor, *, cond: Optional[Tensor] = None):
+ maybe_cond_kwargs = dict(cond=cond) if exists(cond) else dict()
+
+ x = self.norm(x, **maybe_cond_kwargs)
+ return self.net(x)
+
+
+# discriminator with anti-aliased downsampling (blurpool Zhang et al.)
+
+
+class Blur(Module):
+ def __init__(self):
+ super().__init__()
+ f = torch.Tensor([1, 2, 1])
+ self.register_buffer("f", f)
+
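+    # the [1, 2, 1] binomial kernel is expanded via outer products to 2D (space_only) or 3D and
+    # applied as a normalized blur, anti-aliasing the strided downsample that follows (blurpool)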
+ def forward(self, x, space_only=False, time_only=False):
+ assert not (space_only and time_only)
+
+ f = self.f
+
+ if space_only:
+ f = einsum("i, j -> i j", f, f)
+ f = rearrange(f, "... -> 1 1 ...")
+ elif time_only:
+ f = rearrange(f, "f -> 1 f 1 1")
+ else:
+ f = einsum("i, j, k -> i j k", f, f, f)
+ f = rearrange(f, "... -> 1 ...")
+
+ is_images = x.ndim == 4
+
+ if is_images:
+ x = rearrange(x, "b c h w -> b c 1 h w")
+
+ out = filter3d(x, f, normalized=True)
+
+ if is_images:
+ out = rearrange(out, "b c 1 h w -> b c h w")
+
+ return out
+
+
+class DiscriminatorBlock(Module):
+ def __init__(self, input_channels, filters, downsample=True, antialiased_downsample=True):
+ super().__init__()
+ self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
+
+ self.net = nn.Sequential(
+ nn.Conv2d(input_channels, filters, 3, padding=1),
+ leaky_relu(),
+ nn.Conv2d(filters, filters, 3, padding=1),
+ leaky_relu(),
+ )
+
+ self.maybe_blur = Blur() if antialiased_downsample else None
+
+ self.downsample = (
+ nn.Sequential(
+ Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1)
+ )
+ if downsample
+ else None
+ )
+
+ def forward(self, x):
+ res = self.conv_res(x)
+
+ x = self.net(x)
+
+ if exists(self.downsample):
+ if exists(self.maybe_blur):
+ x = self.maybe_blur(x, space_only=True)
+
+ x = self.downsample(x)
+
+ x = (x + res) * (2**-0.5)
+ return x
+
+
+class Discriminator(Module):
+ @beartype
+ def __init__(
+ self,
+ *,
+ dim,
+ image_size,
+ channels=3,
+ max_dim=512,
+ attn_heads=8,
+ attn_dim_head=32,
+ linear_attn_dim_head=8,
+ linear_attn_heads=16,
+ ff_mult=4,
+ antialiased_downsample=False,
+ ):
+ super().__init__()
+ image_size = pair(image_size)
+ min_image_resolution = min(image_size)
+
+ num_layers = int(log2(min_image_resolution) - 2)
+
+ blocks = []
+
+ layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
+ layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
+ layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
+
+ blocks = []
+ attn_blocks = []
+
+ image_resolution = min_image_resolution
+
+ for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
+ num_layer = ind + 1
+ is_not_last = ind != (len(layer_dims_in_out) - 1)
+
+ block = DiscriminatorBlock(
+ in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample
+ )
+
+ attn_block = Sequential(
+ Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)),
+ Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
+ )
+
+ blocks.append(ModuleList([block, attn_block]))
+
+ image_resolution //= 2
+
+ self.blocks = ModuleList(blocks)
+
+ dim_last = layer_dims[-1]
+
+ downsample_factor = 2**num_layers
+ last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
+
+ latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
+
+ self.to_logits = Sequential(
+ nn.Conv2d(dim_last, dim_last, 3, padding=1),
+ leaky_relu(),
+ Rearrange("b ... -> b (...)"),
+ nn.Linear(latent_dim, 1),
+ Rearrange("b 1 -> b"),
+ )
+
+ def forward(self, x):
+ for block, attn_block in self.blocks:
+ x = block(x)
+ x = attn_block(x)
+
+ return self.to_logits(x)
+
+
+# modulatable conv from Karras et al. Stylegan2
+# for conditioning on latents
+
+
+class Conv3DMod(Module):
+ @beartype
+ def __init__(
+ self, dim, *, spatial_kernel, time_kernel, causal=True, dim_out=None, demod=True, eps=1e-8, pad_mode="zeros"
+ ):
+ super().__init__()
+ dim_out = default(dim_out, dim)
+
+ self.eps = eps
+
+ assert is_odd(spatial_kernel) and is_odd(time_kernel)
+
+ self.spatial_kernel = spatial_kernel
+ self.time_kernel = time_kernel
+
+ time_padding = (time_kernel - 1, 0) if causal else ((time_kernel // 2,) * 2)
+
+ self.pad_mode = pad_mode
+ self.padding = (*((spatial_kernel // 2,) * 4), *time_padding)
+ self.weights = nn.Parameter(torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel)))
+
+ self.demod = demod
+
+ nn.init.kaiming_normal_(self.weights, a=0, mode="fan_in", nonlinearity="selu")
+
+ @beartype
+ def forward(self, fmap, cond: Tensor):
+ """
+ notation
+
+ b - batch
+ n - convs
+ o - output
+ i - input
+ k - kernel
+ """
+
+ b = fmap.shape[0]
+
+ # prepare weights for modulation
+
+ weights = self.weights
+
+ # do the modulation, demodulation, as done in stylegan2
+
+ cond = rearrange(cond, "b i -> b 1 i 1 1 1")
+
+ weights = weights * (cond + 1)
+
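+        # demodulation rescales each modulated output filter back to unit L2 norm, as in StyleGAN2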
+ if self.demod:
+ inv_norm = reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum").clamp(min=self.eps).rsqrt()
+ weights = weights * inv_norm
+
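+        # fold the batch into the channel dim and run a grouped conv so every sample gets its own
+        # modulated kernel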
+ fmap = rearrange(fmap, "b c t h w -> 1 (b c) t h w")
+
+ weights = rearrange(weights, "b o ... -> (b o) ...")
+
+ fmap = F.pad(fmap, self.padding, mode=self.pad_mode)
+ fmap = F.conv3d(fmap, weights, groups=b)
+
+ return rearrange(fmap, "1 (b o) ... -> b o ...", b=b)
+
+
+# strided conv downsamples
+
+
+class SpatialDownsample2x(Module):
+ def __init__(self, dim, dim_out=None, kernel_size=3, antialias=False):
+ super().__init__()
+ dim_out = default(dim_out, dim)
+ self.maybe_blur = Blur() if antialias else identity
+ self.conv = nn.Conv2d(dim, dim_out, kernel_size, stride=2, padding=kernel_size // 2)
+
+ def forward(self, x):
+ x = self.maybe_blur(x, space_only=True)
+
+ x = rearrange(x, "b c t h w -> b t c h w")
+ x, ps = pack_one(x, "* c h w")
+
+ out = self.conv(x)
+
+ out = unpack_one(out, ps, "* c h w")
+ out = rearrange(out, "b t c h w -> b c t h w")
+ return out
+
+
+class TimeDownsample2x(Module):
+ def __init__(self, dim, dim_out=None, kernel_size=3, antialias=False):
+ super().__init__()
+ dim_out = default(dim_out, dim)
+ self.maybe_blur = Blur() if antialias else identity
+ self.time_causal_padding = (kernel_size - 1, 0)
+ self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2)
+
+ def forward(self, x):
+ x = self.maybe_blur(x, time_only=True)
+
+ x = rearrange(x, "b c t h w -> b h w c t")
+ x, ps = pack_one(x, "* c t")
+
+ x = F.pad(x, self.time_causal_padding)
+ out = self.conv(x)
+
+ out = unpack_one(out, ps, "* c t")
+ out = rearrange(out, "b h w c t -> b c t h w")
+ return out
+
+
+# depth to space upsamples
+
+
+class SpatialUpsample2x(Module):
+ def __init__(self, dim, dim_out=None):
+ super().__init__()
+ dim_out = default(dim_out, dim)
+ conv = nn.Conv2d(dim, dim_out * 4, 1)
+
+ self.net = nn.Sequential(conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2))
+
+ self.init_conv_(conv)
+
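+    # ICNR-style init: the kernel is repeated across the four pixel-shuffle groups, so the
+    # depth-to-space upsample starts out as (approximately) nearest-neighbor duplication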
+ def init_conv_(self, conv):
+ o, i, h, w = conv.weight.shape
+ conv_weight = torch.empty(o // 4, i, h, w)
+ nn.init.kaiming_uniform_(conv_weight)
+ conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")
+
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+ def forward(self, x):
+ x = rearrange(x, "b c t h w -> b t c h w")
+ x, ps = pack_one(x, "* c h w")
+
+ out = self.net(x)
+
+ out = unpack_one(out, ps, "* c h w")
+ out = rearrange(out, "b t c h w -> b c t h w")
+ return out
+
+
+class TimeUpsample2x(Module):
+ def __init__(self, dim, dim_out=None):
+ super().__init__()
+ dim_out = default(dim_out, dim)
+ conv = nn.Conv1d(dim, dim_out * 2, 1)
+
+ self.net = nn.Sequential(conv, nn.SiLU(), Rearrange("b (c p) t -> b c (t p)", p=2))
+
+ self.init_conv_(conv)
+
+ def init_conv_(self, conv):
+ o, i, t = conv.weight.shape
+ conv_weight = torch.empty(o // 2, i, t)
+ nn.init.kaiming_uniform_(conv_weight)
+ conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")
+
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+ def forward(self, x):
+ x = rearrange(x, "b c t h w -> b h w c t")
+ x, ps = pack_one(x, "* c t")
+
+ out = self.net(x)
+
+ out = unpack_one(out, ps, "* c t")
+ out = rearrange(out, "b h w c t -> b c t h w")
+ return out
+
+
+# autoencoder - only the best variant is offered here, built on causal 3d convolutions
+
+
+def SameConv2d(dim_in, dim_out, kernel_size):
+ kernel_size = cast_tuple(kernel_size, 2)
+ padding = [k // 2 for k in kernel_size]
+ return nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, padding=padding)
+
+
+class CausalConv3d(Module):
+ @beartype
+ def __init__(
+ self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs
+ ):
+ super().__init__()
+ kernel_size = cast_tuple(kernel_size, 3)
+
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
+
+ assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
+
+ dilation = kwargs.pop("dilation", 1)
+ stride = kwargs.pop("stride", 1)
+
+ self.pad_mode = pad_mode
+ time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
+ height_pad = height_kernel_size // 2
+ width_pad = width_kernel_size // 2
+
+ self.time_pad = time_pad
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
+
+ stride = (stride, 1, 1)
+ dilation = (dilation, 1, 1)
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
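+    # the time padding sits entirely on the past side, so each output frame only sees the current
+    # and earlier frames; spatial padding stays symmetric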
+ def forward(self, x):
+ pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"
+
+ x = F.pad(x, self.time_causal_padding, mode=pad_mode)
+ return self.conv(x)
+
+
+@beartype
+def ResidualUnit(dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: str = "constant"):
+ net = Sequential(
+ CausalConv3d(dim, dim, kernel_size, pad_mode=pad_mode),
+ nn.ELU(),
+ nn.Conv3d(dim, dim, 1),
+ nn.ELU(),
+ SqueezeExcite(dim),
+ )
+
+ return Residual(net)
+
+
+@beartype
+class ResidualUnitMod(Module):
+ def __init__(
+ self, dim, kernel_size: Union[int, Tuple[int, int, int]], *, dim_cond, pad_mode: str = "constant", demod=True
+ ):
+ super().__init__()
+ kernel_size = cast_tuple(kernel_size, 3)
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
+ assert height_kernel_size == width_kernel_size
+
+ self.to_cond = nn.Linear(dim_cond, dim)
+
+ self.conv = Conv3DMod(
+ dim=dim,
+ spatial_kernel=height_kernel_size,
+ time_kernel=time_kernel_size,
+ causal=True,
+ demod=demod,
+ pad_mode=pad_mode,
+ )
+
+ self.conv_out = nn.Conv3d(dim, dim, 1)
+
+ @beartype
+ def forward(
+ self,
+ x,
+ cond: Tensor,
+ ):
+ res = x
+ cond = self.to_cond(cond)
+
+ x = self.conv(x, cond=cond)
+ x = F.elu(x)
+ x = self.conv_out(x)
+ x = F.elu(x)
+ return x + res
+
+
+class CausalConvTranspose3d(Module):
+ def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], *, time_stride, **kwargs):
+ super().__init__()
+ kernel_size = cast_tuple(kernel_size, 3)
+
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
+
+ assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
+
+ self.upsample_factor = time_stride
+
+ height_pad = height_kernel_size // 2
+ width_pad = width_kernel_size // 2
+
+ stride = (time_stride, 1, 1)
+ padding = (0, height_pad, width_pad)
+
+ self.conv = nn.ConvTranspose3d(chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs)
+
+ def forward(self, x):
+ assert x.ndim == 5
+ t = x.shape[2]
+
+ out = self.conv(x)
+
+ out = out[..., : (t * self.upsample_factor), :, :]
+ return out
+
+
+# video tokenizer class
+
+LossBreakdown = namedtuple(
+ "LossBreakdown",
+ [
+ "recon_loss",
+ "lfq_aux_loss",
+ "quantizer_loss_breakdown",
+ "perceptual_loss",
+ "adversarial_gen_loss",
+ "adaptive_adversarial_weight",
+ "multiscale_gen_losses",
+ "multiscale_gen_adaptive_weights",
+ ],
+)
+
+DiscrLossBreakdown = namedtuple("DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"])
+
+
+class VideoTokenizer(Module):
+ @beartype
+ def __init__(
+ self,
+ *,
+ image_size,
+ layers: Tuple[Union[str, Tuple[str, int]], ...] = ("residual", "residual", "residual"),
+ residual_conv_kernel_size=3,
+ num_codebooks=1,
+ codebook_size: Optional[int] = None,
+ channels=3,
+ init_dim=64,
+ max_dim=float("inf"),
+ dim_cond=None,
+ dim_cond_expansion_factor=4.0,
+ input_conv_kernel_size: Tuple[int, int, int] = (7, 7, 7),
+ output_conv_kernel_size: Tuple[int, int, int] = (3, 3, 3),
+ pad_mode: str = "constant",
+ lfq_entropy_loss_weight=0.1,
+ lfq_commitment_loss_weight=1.0,
+ lfq_diversity_gamma=2.5,
+ quantizer_aux_loss_weight=1.0,
+ lfq_activation=nn.Identity(),
+ use_fsq=False,
+ fsq_levels: Optional[List[int]] = None,
+ attn_dim_head=32,
+ attn_heads=8,
+ attn_dropout=0.0,
+ linear_attn_dim_head=8,
+ linear_attn_heads=16,
+ vgg: Optional[Module] = None,
+ vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
+ perceptual_loss_weight=1e-1,
+ discr_kwargs: Optional[dict] = None,
+ multiscale_discrs: Tuple[Module, ...] = tuple(),
+ use_gan=True,
+ adversarial_loss_weight=1.0,
+ grad_penalty_loss_weight=10.0,
+ multiscale_adversarial_loss_weight=1.0,
+ flash_attn=True,
+ separate_first_frame_encoding=False,
+ ):
+ super().__init__()
+
+ # for autosaving the config
+
+ _locals = locals()
+ _locals.pop("self", None)
+ _locals.pop("__class__", None)
+ self._configs = pickle.dumps(_locals)
+
+ # image size
+
+ self.channels = channels
+ self.image_size = image_size
+
+ # initial encoder
+
+ self.conv_in = CausalConv3d(channels, init_dim, input_conv_kernel_size, pad_mode=pad_mode)
+
+ # whether to encode the first frame separately or not
+
+ self.conv_in_first_frame = nn.Identity()
+ self.conv_out_first_frame = nn.Identity()
+
+ if separate_first_frame_encoding:
+ self.conv_in_first_frame = SameConv2d(channels, init_dim, input_conv_kernel_size[-2:])
+ self.conv_out_first_frame = SameConv2d(init_dim, channels, output_conv_kernel_size[-2:])
+
+ self.separate_first_frame_encoding = separate_first_frame_encoding
+
+ # encoder and decoder layers
+
+ self.encoder_layers = ModuleList([])
+ self.decoder_layers = ModuleList([])
+
+ self.conv_out = CausalConv3d(init_dim, channels, output_conv_kernel_size, pad_mode=pad_mode)
+
+ dim = init_dim
+ dim_out = dim
+
+ layer_fmap_size = image_size
+ time_downsample_factor = 1
+ has_cond_across_layers = []
+
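+        # each entry of `layers` is a name or a (name, arg) tuple; encoder layers are appended in
+        # order while the mirrored decoder layer is inserted at the front, so decoding reverses
+        # the encoder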
+ for layer_def in layers:
+ layer_type, *layer_params = cast_tuple(layer_def)
+
+ has_cond = False
+
+ if layer_type == "residual":
+ encoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
+ decoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
+
+ elif layer_type == "consecutive_residual":
+ (num_consecutive,) = layer_params
+ encoder_layer = Sequential(
+ *[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)]
+ )
+ decoder_layer = Sequential(
+ *[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)]
+ )
+
+ elif layer_type == "cond_residual":
+ assert exists(
+ dim_cond
+ ), "dim_cond must be passed into VideoTokenizer, if tokenizer is to be conditioned"
+
+ has_cond = True
+
+ encoder_layer = ResidualUnitMod(
+ dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor)
+ )
+ decoder_layer = ResidualUnitMod(
+ dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor)
+ )
+ dim_out = dim
+
+ elif layer_type == "compress_space":
+ dim_out = safe_get_index(layer_params, 0)
+ dim_out = default(dim_out, dim * 2)
+ dim_out = min(dim_out, max_dim)
+
+ encoder_layer = SpatialDownsample2x(dim, dim_out)
+ decoder_layer = SpatialUpsample2x(dim_out, dim)
+
+ assert layer_fmap_size > 1
+ layer_fmap_size //= 2
+
+ elif layer_type == "compress_time":
+ dim_out = safe_get_index(layer_params, 0)
+ dim_out = default(dim_out, dim * 2)
+ dim_out = min(dim_out, max_dim)
+
+ encoder_layer = TimeDownsample2x(dim, dim_out)
+ decoder_layer = TimeUpsample2x(dim_out, dim)
+
+ time_downsample_factor *= 2
+
+ elif layer_type == "attend_space":
+ attn_kwargs = dict(
+ dim=dim, dim_head=attn_dim_head, heads=attn_heads, dropout=attn_dropout, flash=flash_attn
+ )
+
+ encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
+
+ decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
+
+ elif layer_type == "linear_attend_space":
+ linear_attn_kwargs = dict(dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads)
+
+ encoder_layer = Sequential(
+ Residual(LinearSpaceAttention(**linear_attn_kwargs)), Residual(FeedForward(dim))
+ )
+
+ decoder_layer = Sequential(
+ Residual(LinearSpaceAttention(**linear_attn_kwargs)), Residual(FeedForward(dim))
+ )
+
+ elif layer_type == "gateloop_time":
+ gateloop_kwargs = dict(use_heinsen=False)
+
+ encoder_layer = ToTimeSequence(Residual(SimpleGateLoopLayer(dim=dim)))
+ decoder_layer = ToTimeSequence(Residual(SimpleGateLoopLayer(dim=dim)))
+
+ elif layer_type == "attend_time":
+ attn_kwargs = dict(
+ dim=dim,
+ dim_head=attn_dim_head,
+ heads=attn_heads,
+ dropout=attn_dropout,
+ causal=True,
+ flash=flash_attn,
+ )
+
+ encoder_layer = Sequential(
+ Residual(TokenShift(TimeAttention(**attn_kwargs))),
+ Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
+ )
+
+ decoder_layer = Sequential(
+ Residual(TokenShift(TimeAttention(**attn_kwargs))),
+ Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
+ )
+
+ elif layer_type == "cond_attend_space":
+ has_cond = True
+
+ attn_kwargs = dict(
+ dim=dim,
+ dim_cond=dim_cond,
+ dim_head=attn_dim_head,
+ heads=attn_heads,
+ dropout=attn_dropout,
+ flash=flash_attn,
+ )
+
+ encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
+
+ decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
+
+ elif layer_type == "cond_linear_attend_space":
+ has_cond = True
+
+ attn_kwargs = dict(
+ dim=dim,
+ dim_cond=dim_cond,
+ dim_head=attn_dim_head,
+ heads=attn_heads,
+ dropout=attn_dropout,
+ flash=flash_attn,
+ )
+
+ encoder_layer = Sequential(
+ Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond))
+ )
+
+ decoder_layer = Sequential(
+ Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond))
+ )
+
+ elif layer_type == "cond_attend_time":
+ has_cond = True
+
+ attn_kwargs = dict(
+ dim=dim,
+ dim_cond=dim_cond,
+ dim_head=attn_dim_head,
+ heads=attn_heads,
+ dropout=attn_dropout,
+ causal=True,
+ flash=flash_attn,
+ )
+
+ encoder_layer = Sequential(
+ Residual(TokenShift(TimeAttention(**attn_kwargs))),
+ Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
+ )
+
+ decoder_layer = Sequential(
+ Residual(TokenShift(TimeAttention(**attn_kwargs))),
+ Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
+ )
+
+ else:
+ raise ValueError(f"unknown layer type {layer_type}")
+
+ self.encoder_layers.append(encoder_layer)
+ self.decoder_layers.insert(0, decoder_layer)
+
+ dim = dim_out
+ has_cond_across_layers.append(has_cond)
+
+ # add a final norm just before quantization layer
+
+ self.encoder_layers.append(
+ Sequential(
+ Rearrange("b c ... -> b ... c"),
+ nn.LayerNorm(dim),
+ Rearrange("b ... c -> b c ..."),
+ )
+ )
+
+ self.time_downsample_factor = time_downsample_factor
+ self.time_padding = time_downsample_factor - 1
+
+ self.fmap_size = layer_fmap_size
+
+ # use a MLP stem for conditioning, if needed
+
+ self.has_cond_across_layers = has_cond_across_layers
+ self.has_cond = any(has_cond_across_layers)
+
+ self.encoder_cond_in = nn.Identity()
+ self.decoder_cond_in = nn.Identity()
+
+        if self.has_cond:
+ self.dim_cond = dim_cond
+
+ self.encoder_cond_in = Sequential(
+ nn.Linear(dim_cond, int(dim_cond * dim_cond_expansion_factor)), nn.SiLU()
+ )
+
+ self.decoder_cond_in = Sequential(
+ nn.Linear(dim_cond, int(dim_cond * dim_cond_expansion_factor)), nn.SiLU()
+ )
+
+ # quantizer related
+
+ self.use_fsq = use_fsq
+
+ if not use_fsq:
+ assert exists(codebook_size) and not exists(
+ fsq_levels
+ ), "if use_fsq is set to False, `codebook_size` must be set (and not `fsq_levels`)"
+
+ # lookup free quantizer(s) - multiple codebooks is possible
+ # each codebook will get its own entropy regularization
+
+ self.quantizers = LFQ(
+ dim=dim,
+ codebook_size=codebook_size,
+ num_codebooks=num_codebooks,
+ entropy_loss_weight=lfq_entropy_loss_weight,
+ commitment_loss_weight=lfq_commitment_loss_weight,
+ diversity_gamma=lfq_diversity_gamma,
+ )
+
+ else:
+ assert (
+ not exists(codebook_size) and exists(fsq_levels)
+ ), "if use_fsq is set to True, `fsq_levels` must be set (and not `codebook_size`). the effective codebook size is the cumulative product of all the FSQ levels"
+
+ self.quantizers = FSQ(fsq_levels, dim=dim, num_codebooks=num_codebooks)
+
+ self.quantizer_aux_loss_weight = quantizer_aux_loss_weight
+
+ # dummy loss
+
+ self.register_buffer("zero", torch.tensor(0.0), persistent=False)
+
+ # perceptual loss related
+
+ use_vgg = channels in {1, 3, 4} and perceptual_loss_weight > 0.0
+
+ self.vgg = None
+ self.perceptual_loss_weight = perceptual_loss_weight
+
+ if use_vgg:
+ if not exists(vgg):
+ vgg = torchvision.models.vgg16(weights=vgg_weights)
+
+ vgg.classifier = Sequential(*vgg.classifier[:-2])
+
+ self.vgg = vgg
+
+ self.use_vgg = use_vgg
+
+ # main flag for whether to use GAN at all
+
+ self.use_gan = use_gan
+
+ # discriminator
+
+ discr_kwargs = default(discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512))
+
+ self.discr = Discriminator(**discr_kwargs)
+
+ self.adversarial_loss_weight = adversarial_loss_weight
+ self.grad_penalty_loss_weight = grad_penalty_loss_weight
+
+ self.has_gan = use_gan and adversarial_loss_weight > 0.0
+
+ # multi-scale discriminators
+
+ self.has_multiscale_gan = use_gan and multiscale_adversarial_loss_weight > 0.0
+
+ self.multiscale_discrs = ModuleList([*multiscale_discrs])
+
+ self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight
+
+ self.has_multiscale_discrs = (
+ use_gan and multiscale_adversarial_loss_weight > 0.0 and len(multiscale_discrs) > 0
+ )
+
+ @property
+ def device(self):
+ return self.zero.device
+
+ @classmethod
+ def init_and_load_from(cls, path, strict=True):
+ path = Path(path)
+ assert path.exists()
+ pkg = torch.load(str(path), map_location="cpu")
+
+ assert "config" in pkg, "model configs were not found in this saved checkpoint"
+
+ config = pickle.loads(pkg["config"])
+ tokenizer = cls(**config)
+ tokenizer.load(path, strict=strict)
+ return tokenizer
+
+ def parameters(self):
+ return [
+ *self.conv_in.parameters(),
+ *self.conv_in_first_frame.parameters(),
+ *self.conv_out_first_frame.parameters(),
+ *self.conv_out.parameters(),
+ *self.encoder_layers.parameters(),
+ *self.decoder_layers.parameters(),
+ *self.encoder_cond_in.parameters(),
+ *self.decoder_cond_in.parameters(),
+ *self.quantizers.parameters(),
+ ]
+
+ def discr_parameters(self):
+ return self.discr.parameters()
+
+ def copy_for_eval(self):
+ device = self.device
+ vae_copy = copy.deepcopy(self.cpu())
+
+ maybe_del_attr_(vae_copy, "discr")
+ maybe_del_attr_(vae_copy, "vgg")
+ maybe_del_attr_(vae_copy, "multiscale_discrs")
+
+ vae_copy.eval()
+ return vae_copy.to(device)
+
+ @remove_vgg
+ def state_dict(self, *args, **kwargs):
+ return super().state_dict(*args, **kwargs)
+
+ @remove_vgg
+ def load_state_dict(self, *args, **kwargs):
+ return super().load_state_dict(*args, **kwargs)
+
+ def save(self, path, overwrite=True):
+ path = Path(path)
+ assert overwrite or not path.exists(), f"{str(path)} already exists"
+
+ pkg = dict(model_state_dict=self.state_dict(), version=__version__, config=self._configs)
+
+ torch.save(pkg, str(path))
+
+ def load(self, path, strict=True):
+ path = Path(path)
+ assert path.exists()
+
+ pkg = torch.load(str(path))
+ state_dict = pkg.get("model_state_dict")
+ version = pkg.get("version")
+
+ assert exists(state_dict)
+
+ if exists(version):
+ print(f"loading checkpointed tokenizer from version {version}")
+
+ self.load_state_dict(state_dict, strict=strict)
+
+ @beartype
+ def encode(self, video: Tensor, quantize=False, cond: Optional[Tensor] = None, video_contains_first_frame=True):
+ encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
+
+ # whether to pad video or not
+
+ if video_contains_first_frame:
+ video_len = video.shape[2]
+
+ video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2)
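+            # record how to split [time padding | first frame | remaining frames] apart again with `unpack` below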
+ video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])]
+
+ # conditioning, if needed
+
+ assert (not self.has_cond) or exists(
+ cond
+ ), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"
+
+ if exists(cond):
+ assert cond.shape == (video.shape[0], self.dim_cond)
+
+ cond = self.encoder_cond_in(cond)
+ cond_kwargs = dict(cond=cond)
+
+ # initial conv
+ # taking into account whether to encode first frame separately
+
+ if encode_first_frame_separately:
+ pad, first_frame, video = unpack(video, video_packed_shape, "b c * h w")
+ first_frame = self.conv_in_first_frame(first_frame)
+
+ video = self.conv_in(video)
+
+ if encode_first_frame_separately:
+ video, _ = pack([first_frame, video], "b c * h w")
+ video = pad_at_dim(video, (self.time_padding, 0), dim=2)
+
+ # encoder layers
+
+ for fn, has_cond in zip(self.encoder_layers, self.has_cond_across_layers):
+ layer_kwargs = dict()
+
+ if has_cond:
+ layer_kwargs = cond_kwargs
+
+ video = fn(video, **layer_kwargs)
+
+ maybe_quantize = identity if not quantize else self.quantizers
+
+ return maybe_quantize(video)
+
+ @beartype
+ def decode_from_code_indices(self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True):
+ assert codes.dtype in (torch.long, torch.int32)
+
+ if codes.ndim == 2:
+ video_code_len = codes.shape[-1]
+ assert divisible_by(
+ video_code_len, self.fmap_size**2
+ ), f"flattened video ids must have a length ({video_code_len}) that is divisible by the fmap size ({self.fmap_size}) squared ({self.fmap_size ** 2})"
+
+ codes = rearrange(codes, "b (f h w) -> b f h w", h=self.fmap_size, w=self.fmap_size)
+
+ quantized = self.quantizers.indices_to_codes(codes)
+
+ return self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame)
+
+ @beartype
+ def decode(self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True):
+ decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
+
+ batch = quantized.shape[0]
+
+ # conditioning, if needed
+
+ assert (not self.has_cond) or exists(
+ cond
+ ), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"
+
+ if exists(cond):
+ assert cond.shape == (batch, self.dim_cond)
+
+ cond = self.decoder_cond_in(cond)
+ cond_kwargs = dict(cond=cond)
+
+ # decoder layers
+
+ x = quantized
+
+ for fn, has_cond in zip(self.decoder_layers, reversed(self.has_cond_across_layers)):
+ layer_kwargs = dict()
+
+ if has_cond:
+ layer_kwargs = cond_kwargs
+
+ x = fn(x, **layer_kwargs)
+
+ # to pixels
+
+ if decode_first_frame_separately:
+ left_pad, xff, x = (
+ x[:, :, : self.time_padding],
+ x[:, :, self.time_padding],
+ x[:, :, (self.time_padding + 1) :],
+ )
+
+ out = self.conv_out(x)
+ outff = self.conv_out_first_frame(xff)
+
+ video, _ = pack([outff, out], "b c * h w")
+
+ else:
+ video = self.conv_out(x)
+
+            # if the video was padded along time, remove the padding
+
+            if video_contains_first_frame:
+                video = video[:, :, self.time_padding :]
+
+ return video
+
+ @torch.no_grad()
+ def tokenize(self, video):
+ self.eval()
+ return self.forward(video, return_codes=True)
+
+ @beartype
+ def forward(
+ self,
+ video_or_images: Tensor,
+ cond: Optional[Tensor] = None,
+ return_loss=False,
+ return_codes=False,
+ return_recon=False,
+ return_discr_loss=False,
+ return_recon_loss_only=False,
+ apply_gradient_penalty=True,
+ video_contains_first_frame=True,
+ adversarial_loss_weight=None,
+ multiscale_adversarial_loss_weight=None,
+ ):
+ adversarial_loss_weight = default(adversarial_loss_weight, self.adversarial_loss_weight)
+ multiscale_adversarial_loss_weight = default(
+ multiscale_adversarial_loss_weight, self.multiscale_adversarial_loss_weight
+ )
+
+ assert (return_loss + return_codes + return_discr_loss) <= 1
+ assert video_or_images.ndim in {4, 5}
+
+ assert video_or_images.shape[-2:] == (self.image_size, self.image_size)
+
+ # accept images for image pretraining (curriculum learning from images to video)
+
+ is_image = video_or_images.ndim == 4
+
+ if is_image:
+ video = rearrange(video_or_images, "b c ... -> b c 1 ...")
+ video_contains_first_frame = True
+ else:
+ video = video_or_images
+
+ batch, channels, frames = video.shape[:3]
+
+ assert divisible_by(
+ frames - int(video_contains_first_frame), self.time_downsample_factor
+ ), f"number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}"
+
+ # encoder
+
+ x = self.encode(video, cond=cond, video_contains_first_frame=video_contains_first_frame)
+
+ # lookup free quantization
+
+ if self.use_fsq:
+ quantized, codes = self.quantizers(x)
+
+ aux_losses = self.zero
+ quantizer_loss_breakdown = None
+ else:
+ (quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(x, return_loss_breakdown=True)
+
+ if return_codes and not return_recon:
+ return codes
+
+ # decoder
+
+ recon_video = self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame)
+
+ if return_codes:
+ return codes, recon_video
+
+ # reconstruction loss
+
+ if not (return_loss or return_discr_loss or return_recon_loss_only):
+ return recon_video
+
+ recon_loss = F.mse_loss(video, recon_video)
+
+ # for validation, only return recon loss
+
+ if return_recon_loss_only:
+ return recon_loss, recon_video
+
+ # gan discriminator loss
+
+ if return_discr_loss:
+ assert self.has_gan
+ assert exists(self.discr)
+
+ # pick a random frame for image discriminator
+
+ frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices
+
+ real = pick_video_frame(video, frame_indices)
+
+ if apply_gradient_penalty:
+ real = real.requires_grad_()
+
+ fake = pick_video_frame(recon_video, frame_indices)
+
+ real_logits = self.discr(real)
+ fake_logits = self.discr(fake.detach())
+
+ discr_loss = hinge_discr_loss(fake_logits, real_logits)
+
+ # multiscale discriminators
+
+ multiscale_discr_losses = []
+
+ if self.has_multiscale_discrs:
+ for discr in self.multiscale_discrs:
+ multiscale_real_logits = discr(video)
+ multiscale_fake_logits = discr(recon_video.detach())
+
+ multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits)
+
+ multiscale_discr_losses.append(multiscale_discr_loss)
+ else:
+ multiscale_discr_losses.append(self.zero)
+
+ # gradient penalty
+
+ if apply_gradient_penalty:
+ gradient_penalty_loss = gradient_penalty(real, real_logits)
+ else:
+ gradient_penalty_loss = self.zero
+
+ # total loss
+
+ total_loss = (
+ discr_loss
+ + gradient_penalty_loss * self.grad_penalty_loss_weight
+ + sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight
+ )
+
+ discr_loss_breakdown = DiscrLossBreakdown(discr_loss, multiscale_discr_losses, gradient_penalty_loss)
+
+ return total_loss, discr_loss_breakdown
+
+ # perceptual loss
+
+ if self.use_vgg:
+ frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices
+
+ input_vgg_input = pick_video_frame(video, frame_indices)
+ recon_vgg_input = pick_video_frame(recon_video, frame_indices)
+
+ if channels == 1:
+ input_vgg_input = repeat(input_vgg_input, "b 1 h w -> b c h w", c=3)
+ recon_vgg_input = repeat(recon_vgg_input, "b 1 h w -> b c h w", c=3)
+
+ elif channels == 4:
+ input_vgg_input = input_vgg_input[:, :3]
+ recon_vgg_input = recon_vgg_input[:, :3]
+
+ input_vgg_feats = self.vgg(input_vgg_input)
+ recon_vgg_feats = self.vgg(recon_vgg_input)
+
+ perceptual_loss = F.mse_loss(input_vgg_feats, recon_vgg_feats)
+ else:
+ perceptual_loss = self.zero
+
+ # get gradient with respect to perceptual loss for last decoder layer
+ # needed for adaptive weighting
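+        # (the generator losses below are rescaled so that their gradient norm at this layer
+        # matches the perceptual loss gradient norm, in the spirit of the VQGAN adaptive weight)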
+
+ last_dec_layer = self.conv_out.conv.weight
+
+ norm_grad_wrt_perceptual_loss = None
+
+ if self.training and self.use_vgg and (self.has_gan or self.has_multiscale_discrs):
+ norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2)
+
+ # per-frame image discriminator
+
+ recon_video_frames = None
+
+ if self.has_gan:
+ frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices
+ recon_video_frames = pick_video_frame(recon_video, frame_indices)
+
+ fake_logits = self.discr(recon_video_frames)
+ gen_loss = hinge_gen_loss(fake_logits)
+
+ adaptive_weight = 1.0
+
+ if exists(norm_grad_wrt_perceptual_loss):
+ norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
+ adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3)
+ adaptive_weight.clamp_(max=1e3)
+
+ if torch.isnan(adaptive_weight).any():
+ adaptive_weight = 1.0
+ else:
+ gen_loss = self.zero
+ adaptive_weight = 0.0
+
+ # multiscale discriminator losses
+
+ multiscale_gen_losses = []
+ multiscale_gen_adaptive_weights = []
+
+ if self.has_multiscale_gan and self.has_multiscale_discrs:
+ if not exists(recon_video_frames):
+ recon_video_frames = pick_video_frame(recon_video, frame_indices)
+
+ for discr in self.multiscale_discrs:
+                fake_logits = discr(recon_video_frames)
+ multiscale_gen_loss = hinge_gen_loss(fake_logits)
+
+ multiscale_gen_losses.append(multiscale_gen_loss)
+
+ multiscale_adaptive_weight = 1.0
+
+ if exists(norm_grad_wrt_perceptual_loss):
+ norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_dec_layer).norm(p=2)
+ multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5)
+ multiscale_adaptive_weight.clamp_(max=1e3)
+
+ multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)
+
+ # calculate total loss
+
+ total_loss = (
+ recon_loss
+ + aux_losses * self.quantizer_aux_loss_weight
+ + perceptual_loss * self.perceptual_loss_weight
+ + gen_loss * adaptive_weight * adversarial_loss_weight
+ )
+
+ if self.has_multiscale_discrs:
+ weighted_multiscale_gen_losses = sum(
+ loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights)
+ )
+
+ total_loss = total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight
+
+ # loss breakdown
+
+ loss_breakdown = LossBreakdown(
+ recon_loss,
+ aux_losses,
+ quantizer_loss_breakdown,
+ perceptual_loss,
+ gen_loss,
+ adaptive_weight,
+ multiscale_gen_losses,
+ multiscale_gen_adaptive_weights,
+ )
+
+ return total_loss, loss_breakdown
+
+
+# main class
+
+
+class MagViT2(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x
diff --git a/sat/sgm/modules/autoencoding/regularizers/__init__.py b/sat/sgm/modules/autoencoding/regularizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6065fb209b6cb6fb4e0cb601c895c2a35e0044e9
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/regularizers/__init__.py
@@ -0,0 +1,30 @@
+from abc import abstractmethod
+from typing import Any, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ....modules.distributions.distributions import DiagonalGaussianDistribution
+from .base import AbstractRegularizer
+
+
+class DiagonalGaussianRegularizer(AbstractRegularizer):
+ def __init__(self, sample: bool = True):
+ super().__init__()
+ self.sample = sample
+
+ def get_trainable_parameters(self) -> Any:
+ yield from ()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ log = dict()
+ posterior = DiagonalGaussianDistribution(z)
+ if self.sample:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ kl_loss = posterior.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+ log["kl_loss"] = kl_loss
+ return z, log
diff --git a/sat/sgm/modules/autoencoding/regularizers/base.py b/sat/sgm/modules/autoencoding/regularizers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f455be98a6f1b5d8647b423de6c3aaeb24d3e23
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/regularizers/base.py
@@ -0,0 +1,36 @@
+from abc import abstractmethod
+from typing import Any, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class AbstractRegularizer(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_trainable_parameters(self) -> Any:
+ raise NotImplementedError()
+
+
+class IdentityRegularizer(AbstractRegularizer):
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ return z, dict()
+
+ def get_trainable_parameters(self) -> Any:
+ yield from ()
+
+
+def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
+ encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
+ avg_probs = encodings.mean(0)
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+ cluster_use = torch.sum(avg_probs > 0)
+ return perplexity, cluster_use
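+
+
+# Hedged illustration (not part of the original module): with indices spread uniformly
+# over all centroids the perplexity approaches num_centroids and every cluster counts
+# as used, while a collapsed assignment gives a perplexity of roughly 1.
+if __name__ == "__main__":
+    uniform = torch.arange(8)
+    perplexity, cluster_use = measure_perplexity(uniform, num_centroids=8)
+    print(perplexity.item(), cluster_use.item())  # ~8.0, 8
+
+    collapsed = torch.zeros(8, dtype=torch.long)
+    perplexity, cluster_use = measure_perplexity(collapsed, num_centroids=8)
+    print(perplexity.item(), cluster_use.item())  # ~1.0, 1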
diff --git a/sat/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py b/sat/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a20dd63ef18259ac4242438f2c1a393e5ef938d
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py
@@ -0,0 +1,180 @@
+"""
+Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
+Code adapted from Jax version in Appendix A.1
+"""
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+from torch.nn import Module
+from torch import Tensor, int32
+from torch.cuda.amp import autocast
+
+from einops import rearrange, pack, unpack
+
+# helper functions
+
+
+def exists(v):
+ return v is not None
+
+
+def default(*args):
+ for arg in args:
+ if exists(arg):
+ return arg
+ return None
+
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+
+# tensor helpers
+
+
+def round_ste(z: Tensor) -> Tensor:
+ """Round with straight through gradients."""
+ zhat = z.round()
+ return z + (zhat - z).detach()
+
+
+# main class
+
+
+class FSQ(Module):
+ def __init__(
+ self,
+ levels: List[int],
+ dim: Optional[int] = None,
+ num_codebooks=1,
+ keep_num_codebooks_dim: Optional[bool] = None,
+ scale: Optional[float] = None,
+ ):
+ super().__init__()
+ _levels = torch.tensor(levels, dtype=int32)
+ self.register_buffer("_levels", _levels, persistent=False)
+
+ _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
+ self.register_buffer("_basis", _basis, persistent=False)
+
+ self.scale = scale
+
+ codebook_dim = len(levels)
+ self.codebook_dim = codebook_dim
+
+ effective_codebook_dim = codebook_dim * num_codebooks
+ self.num_codebooks = num_codebooks
+ self.effective_codebook_dim = effective_codebook_dim
+
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
+
+ self.dim = default(dim, len(_levels) * num_codebooks)
+
+ has_projections = self.dim != effective_codebook_dim
+ self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
+ self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
+ self.has_projections = has_projections
+
+ self.codebook_size = self._levels.prod().item()
+
+ implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
+ self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)
+
+ def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
+ """Bound `z`, an array of shape (..., d)."""
+ half_l = (self._levels - 1) * (1 + eps) / 2
+ offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
+ shift = (offset / half_l).atanh()
+ return (z + shift).tanh() * half_l - offset
+
+ def quantize(self, z: Tensor) -> Tensor:
+ """Quantizes z, returns quantized zhat, same shape as z."""
+ quantized = round_ste(self.bound(z))
+ half_width = self._levels // 2 # Renormalize to [-1, 1].
+ return quantized / half_width
+
+ def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
+ half_width = self._levels // 2
+ return (zhat_normalized * half_width) + half_width
+
+ def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
+ half_width = self._levels // 2
+ return (zhat - half_width) / half_width
+
+ def codes_to_indices(self, zhat: Tensor) -> Tensor:
+ """Converts a `code` to an index in the codebook."""
+ assert zhat.shape[-1] == self.codebook_dim
+ zhat = self._scale_and_shift(zhat)
+ return (zhat * self._basis).sum(dim=-1).to(int32)
+
+ def indices_to_codes(self, indices: Tensor, project_out=True) -> Tensor:
+ """Inverse of `codes_to_indices`."""
+
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
+
+ indices = rearrange(indices, "... -> ... 1")
+ codes_non_centered = (indices // self._basis) % self._levels
+ codes = self._scale_and_shift_inverse(codes_non_centered)
+
+ if self.keep_num_codebooks_dim:
+ codes = rearrange(codes, "... c d -> ... (c d)")
+
+ if project_out:
+ codes = self.project_out(codes)
+
+ if is_img_or_video:
+ codes = rearrange(codes, "b ... d -> b d ...")
+
+ return codes
+
+ @autocast(enabled=False)
+ def forward(self, z: Tensor) -> Tensor:
+ """
+ einstein notation
+ b - batch
+ n - sequence (or flattened spatial dimensions)
+ d - feature dimension
+ c - number of codebook dim
+ """
+
+ is_img_or_video = z.ndim >= 4
+
+ # standardize image or video into (batch, seq, dimension)
+
+ if is_img_or_video:
+ z = rearrange(z, "b d ... -> b ... d")
+ z, ps = pack_one(z, "b * d")
+
+ assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
+
+ z = self.project_in(z)
+
+ z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
+
+ codes = self.quantize(z)
+ indices = self.codes_to_indices(codes)
+
+ codes = rearrange(codes, "b n c d -> b n (c d)")
+
+ out = self.project_out(codes)
+
+ # reconstitute image or video dimensions
+
+ if is_img_or_video:
+ out = unpack_one(out, ps, "b * d")
+ out = rearrange(out, "b ... d -> b d ...")
+
+ indices = unpack_one(indices, ps, "b * c")
+
+ if not self.keep_num_codebooks_dim:
+ indices = rearrange(indices, "... 1 -> ...")
+
+ return out, indices
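+
+
+# A minimal, hedged usage sketch (illustrative only, not part of the original file):
+# with levels [8, 5, 5, 5] the effective codebook size is 8 * 5 * 5 * 5 = 1000, and a
+# (batch, dim, frames, height, width) feature map is quantized independently per position.
+if __name__ == "__main__":
+    fsq = FSQ(levels=[8, 5, 5, 5], dim=64)
+    feats = torch.randn(2, 64, 4, 16, 16)  # assumed toy shape: (b, d, t, h, w)
+    quantized, indices = fsq(feats)
+    assert fsq.codebook_size == 1000
+    assert quantized.shape == feats.shape
+    assert indices.shape == (2, 4, 16, 16)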
diff --git a/sat/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py b/sat/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..beca8884d284bd62bca9e6b4bfd137b07674362e
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py
@@ -0,0 +1,309 @@
+"""
+Lookup Free Quantization
+Proposed in https://arxiv.org/abs/2310.05737
+
+In the simplest setup, each dimension is quantized into {-1, 1}.
+An entropy penalty is used to encourage utilization.
+"""
+
+from math import log2, ceil
+from collections import namedtuple
+
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from torch.nn import Module
+from torch.cuda.amp import autocast
+
+from einops import rearrange, reduce, pack, unpack
+
+# constants
+
+Return = namedtuple("Return", ["quantized", "indices", "entropy_aux_loss"])
+
+LossBreakdown = namedtuple("LossBreakdown", ["per_sample_entropy", "batch_entropy", "commitment"])
+
+# helper functions
+
+
+def exists(v):
+ return v is not None
+
+
+def default(*args):
+ for arg in args:
+ if exists(arg):
+ return arg() if callable(arg) else arg
+ return None
+
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+
+# entropy
+
+
+def log(t, eps=1e-5):
+ return t.clamp(min=eps).log()
+
+
+def entropy(prob):
+ return (-prob * log(prob)).sum(dim=-1)
+
+
+# class
+
+
+class LFQ(Module):
+ def __init__(
+ self,
+ *,
+ dim=None,
+ codebook_size=None,
+ entropy_loss_weight=0.1,
+ commitment_loss_weight=0.25,
+ diversity_gamma=1.0,
+ straight_through_activation=nn.Identity(),
+ num_codebooks=1,
+ keep_num_codebooks_dim=None,
+ codebook_scale=1.0, # for residual LFQ, codebook scaled down by 2x at each layer
+        frac_per_sample_entropy=1.0,  # set below 1.0 to compute the per-sample entropy on only a random fraction of the probs
+ ):
+ super().__init__()
+
+ # some assert validations
+
+ assert exists(dim) or exists(codebook_size), "either dim or codebook_size must be specified for LFQ"
+ assert (
+ not exists(codebook_size) or log2(codebook_size).is_integer()
+ ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})"
+
+ codebook_size = default(codebook_size, lambda: 2**dim)
+ codebook_dim = int(log2(codebook_size))
+
+ codebook_dims = codebook_dim * num_codebooks
+ dim = default(dim, codebook_dims)
+
+ has_projections = dim != codebook_dims
+ self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity()
+ self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity()
+ self.has_projections = has_projections
+
+ self.dim = dim
+ self.codebook_dim = codebook_dim
+ self.num_codebooks = num_codebooks
+
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
+
+ # straight through activation
+
+ self.activation = straight_through_activation
+
+ # entropy aux loss related weights
+
+ assert 0 < frac_per_sample_entropy <= 1.0
+ self.frac_per_sample_entropy = frac_per_sample_entropy
+
+ self.diversity_gamma = diversity_gamma
+ self.entropy_loss_weight = entropy_loss_weight
+
+ # codebook scale
+
+ self.codebook_scale = codebook_scale
+
+ # commitment loss
+
+ self.commitment_loss_weight = commitment_loss_weight
+
+ # for no auxiliary loss, during inference
+
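+        # powers of two used to convert each code's sign pattern to an integer index (and back)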
+ self.register_buffer("mask", 2 ** torch.arange(codebook_dim - 1, -1, -1))
+ self.register_buffer("zero", torch.tensor(0.0), persistent=False)
+
+ # codes
+
+ all_codes = torch.arange(codebook_size)
+ bits = ((all_codes[..., None].int() & self.mask) != 0).float()
+ codebook = self.bits_to_codes(bits)
+
+ self.register_buffer("codebook", codebook, persistent=False)
+
+ def bits_to_codes(self, bits):
+ return bits * self.codebook_scale * 2 - self.codebook_scale
+
+ @property
+ def dtype(self):
+ return self.codebook.dtype
+
+ def indices_to_codes(self, indices, project_out=True):
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
+
+ if not self.keep_num_codebooks_dim:
+ indices = rearrange(indices, "... -> ... 1")
+
+ # indices to codes, which are bits of either -1 or 1
+
+ bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
+
+ codes = self.bits_to_codes(bits)
+
+ codes = rearrange(codes, "... c d -> ... (c d)")
+
+ # whether to project codes out to original dimensions
+ # if the input feature dimensions were not log2(codebook size)
+
+ if project_out:
+ codes = self.project_out(codes)
+
+ # rearrange codes back to original shape
+
+ if is_img_or_video:
+ codes = rearrange(codes, "b ... d -> b d ...")
+
+ return codes
+
+ @autocast(enabled=False)
+ def forward(
+ self,
+ x,
+ inv_temperature=100.0,
+ return_loss_breakdown=False,
+ mask=None,
+ ):
+ """
+ einstein notation
+ b - batch
+ n - sequence (or flattened spatial dimensions)
+ d - feature dimension, which is also log2(codebook size)
+ c - number of codebook dim
+ """
+
+ x = x.float()
+
+ is_img_or_video = x.ndim >= 4
+
+ # standardize image or video into (batch, seq, dimension)
+
+ if is_img_or_video:
+ x = rearrange(x, "b d ... -> b ... d")
+ x, ps = pack_one(x, "b * d")
+
+ assert x.shape[-1] == self.dim, f"expected dimension of {self.dim} but received {x.shape[-1]}"
+
+ x = self.project_in(x)
+
+ # split out number of codebooks
+
+ x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks)
+
+ # quantize by eq 3.
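+        # (each latent dimension keeps only its sign, scaled by the codebook scale)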
+
+ original_input = x
+
+ codebook_value = torch.ones_like(x) * self.codebook_scale
+ quantized = torch.where(x > 0, codebook_value, -codebook_value)
+
+ # use straight-through gradients (optionally with custom activation fn) if training
+
+ if self.training:
+ x = self.activation(x)
+ x = x + (quantized - x).detach()
+ else:
+ x = quantized
+
+ # calculate indices
+
+ indices = reduce((x > 0).int() * self.mask.int(), "b n c d -> b n c", "sum")
+
+ # entropy aux loss
+
+ if self.training:
+ # the same as euclidean distance up to a constant
+ distance = -2 * einsum("... i d, j d -> ... i j", original_input, self.codebook)
+
+ prob = (-distance * inv_temperature).softmax(dim=-1)
+
+ # account for mask
+
+ if exists(mask):
+ prob = prob[mask]
+ else:
+ prob = rearrange(prob, "b n ... -> (b n) ...")
+
+ # whether to only use a fraction of probs, for reducing memory
+
+ if self.frac_per_sample_entropy < 1.0:
+ num_tokens = prob.shape[0]
+ num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
+ rand_mask = torch.randn(num_tokens).argsort(dim=-1) < num_sampled_tokens
+ per_sample_probs = prob[rand_mask]
+ else:
+ per_sample_probs = prob
+
+ # calculate per sample entropy
+
+ per_sample_entropy = entropy(per_sample_probs).mean()
+
+ # distribution over all available tokens in the batch
+
+ avg_prob = reduce(per_sample_probs, "... c d -> c d", "mean")
+ codebook_entropy = entropy(avg_prob).mean()
+
+ # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
+ # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
+
+ entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
+ else:
+ # if not training, just return dummy 0
+ entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
+
+ # commit loss
+
+ if self.training:
+ commit_loss = F.mse_loss(original_input, quantized.detach(), reduction="none")
+
+ if exists(mask):
+ commit_loss = commit_loss[mask]
+
+ commit_loss = commit_loss.mean()
+ else:
+ commit_loss = self.zero
+
+ # merge back codebook dim
+
+ x = rearrange(x, "b n c d -> b n (c d)")
+
+ # project out to feature dimension if needed
+
+ x = self.project_out(x)
+
+ # reconstitute image or video dimensions
+
+ if is_img_or_video:
+ x = unpack_one(x, ps, "b * d")
+ x = rearrange(x, "b ... d -> b d ...")
+
+ indices = unpack_one(indices, ps, "b * c")
+
+ # whether to remove single codebook dim
+
+ if not self.keep_num_codebooks_dim:
+ indices = rearrange(indices, "... 1 -> ...")
+
+ # complete aux loss
+
+ aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
+
+ ret = Return(x, indices, aux_loss)
+
+ if not return_loss_breakdown:
+ return ret
+
+ return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
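+
+
+# A minimal, hedged usage sketch (illustrative only, not part of the original file):
+# with codebook_size=256 each position is described by log2(256) = 8 signs; in eval
+# mode the entropy and commitment terms are skipped, so the auxiliary loss is zero.
+if __name__ == "__main__":
+    lfq = LFQ(dim=32, codebook_size=256).eval()
+    feats = torch.randn(2, 32, 4, 8, 8)  # assumed toy shape: (b, d, t, h, w)
+    quantized, indices, aux_loss = lfq(feats)
+    assert quantized.shape == feats.shape
+    assert indices.shape == (2, 4, 8, 8)
+    assert float(aux_loss) == 0.0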
diff --git a/sat/sgm/modules/autoencoding/regularizers/quantize.py b/sat/sgm/modules/autoencoding/regularizers/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..583f488c25e2283352176f7443a3233b3d4f926f
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/regularizers/quantize.py
@@ -0,0 +1,453 @@
+import logging
+from abc import abstractmethod
+from typing import Dict, Iterator, Literal, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch import einsum
+
+from .base import AbstractRegularizer, measure_perplexity
+
+logpy = logging.getLogger(__name__)
+
+
+class AbstractQuantizer(AbstractRegularizer):
+ def __init__(self):
+ super().__init__()
+ # Define these in your init
+ # shape (N,)
+ self.used: Optional[torch.Tensor]
+ self.re_embed: int
+ self.unknown_index: Union[Literal["random"], int]
+
+ def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
+ assert self.used is not None, "You need to define used indices for remap"
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
+ assert self.used is not None, "You need to define used indices for remap"
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ @abstractmethod
+ def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
+ raise NotImplementedError()
+
+ def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
+ yield from self.parameters()
+
+
+class GumbelQuantizer(AbstractQuantizer):
+ """
+ credit to @karpathy:
+ https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
+ Gumbel Softmax trick quantizer
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
+ https://arxiv.org/abs/1611.01144
+ """
+
+ def __init__(
+ self,
+ num_hiddens: int,
+ embedding_dim: int,
+ n_embed: int,
+ straight_through: bool = True,
+ kl_weight: float = 5e-4,
+ temp_init: float = 1.0,
+ remap: Optional[str] = None,
+ unknown_index: str = "random",
+ loss_key: str = "loss/vq",
+ ) -> None:
+ super().__init__()
+
+ self.loss_key = loss_key
+ self.embedding_dim = embedding_dim
+ self.n_embed = n_embed
+
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
+ self.embed = nn.Embedding(n_embed, embedding_dim)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ else:
+ self.used = None
+ self.re_embed = n_embed
+ if unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ else:
+ assert unknown_index == "random" or isinstance(
+ unknown_index, int
+ ), "unknown index needs to be 'random', 'extra' or any integer"
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.remap is not None:
+ logpy.info(
+ f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+
+ def forward(
+ self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False
+ ) -> Tuple[torch.Tensor, Dict]:
+ # force hard = True when we are in eval mode, as we must quantize.
+ # actually, always true seems to work
+ hard = self.straight_through if self.training else True
+ temp = self.temperature if temp is None else temp
+ out_dict = {}
+ logits = self.proj(z)
+ if self.remap is not None:
+ # continue only with used logits
+ full_zeros = torch.zeros_like(logits)
+ logits = logits[:, self.used, ...]
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+ if self.remap is not None:
+ # go back to all entries but unused set to zero
+ full_zeros[:, self.used, ...] = soft_one_hot
+ soft_one_hot = full_zeros
+ z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
+ out_dict[self.loss_key] = diff
+
+ ind = soft_one_hot.argmax(dim=1)
+ out_dict["indices"] = ind
+ if self.remap is not None:
+ ind = self.remap_to_used(ind)
+
+ if return_logits:
+ out_dict["logits"] = logits
+
+ return z_q, out_dict
+
+ def get_codebook_entry(self, indices, shape):
+ # TODO: shape not yet optional
+ b, h, w, c = shape
+ assert b * h * w == indices.shape[0]
+ indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
+ if self.remap is not None:
+ indices = self.unmap_to_all(indices)
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
+ z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
+ return z_q
+
+
+class VectorQuantizer(AbstractQuantizer):
+ """
+    Discretization bottleneck part of the VQ-VAE.
+
+    Inputs:
+    - n_e : number of embeddings
+    - e_dim : dimension of embedding
+    - beta : commitment cost used in the loss term,
+      beta * ||z_e(x) - sg[e]||^2
+ """
+
+ def __init__(
+ self,
+ n_e: int,
+ e_dim: int,
+ beta: float = 0.25,
+ remap: Optional[str] = None,
+ unknown_index: str = "random",
+ sane_index_shape: bool = False,
+ log_perplexity: bool = False,
+ embedding_weight_norm: bool = False,
+ loss_key: str = "loss/vq",
+ ):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.loss_key = loss_key
+
+ if not embedding_weight_norm:
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+ else:
+ self.embedding = torch.nn.utils.weight_norm(nn.Embedding(self.n_e, self.e_dim), dim=1)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ else:
+ self.used = None
+ self.re_embed = n_e
+ if unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ else:
+ assert unknown_index == "random" or isinstance(
+ unknown_index, int
+ ), "unknown index needs to be 'random', 'extra' or any integer"
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.remap is not None:
+ logpy.info(
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+
+ self.sane_index_shape = sane_index_shape
+ self.log_perplexity = log_perplexity
+
+ def forward(
+ self,
+ z: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict]:
+ do_reshape = z.ndim == 4
+ if do_reshape:
+ # # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, "b c h w -> b h w c").contiguous()
+
+ else:
+ assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
+ z = z.contiguous()
+
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = (
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
+ + torch.sum(self.embedding.weight**2, dim=1)
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
+ )
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ loss_dict = {}
+ if self.log_perplexity:
+ perplexity, cluster_usage = measure_perplexity(min_encoding_indices.detach(), self.n_e)
+ loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})
+
+ # compute loss for embedding
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
+ loss_dict[self.loss_key] = loss
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ if do_reshape:
+ z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
+
+ if self.sane_index_shape:
+ if do_reshape:
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
+ else:
+ min_encoding_indices = rearrange(min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0])
+
+ loss_dict["min_encoding_indices"] = min_encoding_indices
+
+ return z_q, loss_dict
+
+ def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ assert shape is not None, "Need to give shape for remap"
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class EmbeddingEMA(nn.Module):
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+ super().__init__()
+ self.decay = decay
+ self.eps = eps
+ weight = torch.randn(num_tokens, codebook_dim)
+ self.weight = nn.Parameter(weight, requires_grad=False)
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
+ self.update = True
+
+ def forward(self, embed_id):
+ return F.embedding(embed_id, self.weight)
+
+ def cluster_size_ema_update(self, new_cluster_size):
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+ def embed_avg_ema_update(self, new_embed_avg):
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+ def weight_update(self, num_tokens):
+ n = self.cluster_size.sum()
+ smoothed_cluster_size = (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+ # normalize embedding average with smoothed cluster size
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+ self.weight.data.copy_(embed_normalized)
+
+
+class EMAVectorQuantizer(AbstractQuantizer):
+ def __init__(
+ self,
+ n_embed: int,
+ embedding_dim: int,
+ beta: float,
+ decay: float = 0.99,
+ eps: float = 1e-5,
+ remap: Optional[str] = None,
+ unknown_index: str = "random",
+ loss_key: str = "loss/vq",
+ ):
+ super().__init__()
+ self.codebook_dim = embedding_dim
+ self.num_tokens = n_embed
+ self.beta = beta
+ self.loss_key = loss_key
+
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ else:
+ self.used = None
+ self.re_embed = n_embed
+ if unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ else:
+ assert unknown_index == "random" or isinstance(
+ unknown_index, int
+ ), "unknown index needs to be 'random', 'extra' or any integer"
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.remap is not None:
+ logpy.info(
+ f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
+ # reshape z -> (batch, height, width, channel) and flatten
+ # z, 'b c h w -> b h w c'
+ z = rearrange(z, "b c h w -> b h w c")
+ z_flattened = z.reshape(-1, self.codebook_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = (
+ z_flattened.pow(2).sum(dim=1, keepdim=True)
+ + self.embedding.weight.pow(2).sum(dim=1)
+ - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
+ ) # 'n d -> d n'
+
+ encoding_indices = torch.argmin(d, dim=1)
+
+ z_q = self.embedding(encoding_indices).view(z.shape)
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+ if self.training and self.embedding.update:
+ # EMA cluster size
+ encodings_sum = encodings.sum(0)
+ self.embedding.cluster_size_ema_update(encodings_sum)
+ # EMA embedding average
+ embed_sum = encodings.transpose(0, 1) @ z_flattened
+ self.embedding.embed_avg_ema_update(embed_sum)
+ # normalize embed_avg and update weight
+ self.embedding.weight_update(self.num_tokens)
+
+ # compute loss for embedding
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ # z_q, 'b h w c -> b c h w'
+ z_q = rearrange(z_q, "b h w c -> b c h w")
+
+ out_dict = {
+ self.loss_key: loss,
+ "encodings": encodings,
+ "encoding_indices": encoding_indices,
+ "perplexity": perplexity,
+ }
+
+ return z_q, out_dict
+
+
+class VectorQuantizerWithInputProjection(VectorQuantizer):
+ def __init__(
+ self,
+ input_dim: int,
+ n_codes: int,
+ codebook_dim: int,
+ beta: float = 1.0,
+ output_dim: Optional[int] = None,
+ **kwargs,
+ ):
+ super().__init__(n_codes, codebook_dim, beta, **kwargs)
+ self.proj_in = nn.Linear(input_dim, codebook_dim)
+ self.output_dim = output_dim
+ if output_dim is not None:
+ self.proj_out = nn.Linear(codebook_dim, output_dim)
+ else:
+ self.proj_out = nn.Identity()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
+ rearr = False
+ in_shape = z.shape
+
+ if z.ndim > 3:
+ rearr = self.output_dim is not None
+ z = rearrange(z, "b c ... -> b (...) c")
+ z = self.proj_in(z)
+ z_q, loss_dict = super().forward(z)
+
+ z_q = self.proj_out(z_q)
+ if rearr:
+ if len(in_shape) == 4:
+ z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
+ elif len(in_shape) == 5:
+ z_q = rearrange(z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2])
+ else:
+ raise NotImplementedError(f"rearranging not available for {len(in_shape)}-dimensional input.")
+
+ return z_q, loss_dict
diff --git a/sat/sgm/modules/autoencoding/temporal_ae.py b/sat/sgm/modules/autoencoding/temporal_ae.py
new file mode 100644
index 0000000000000000000000000000000000000000..a45ef9d62efc30a0abd6d6e730254d6439ff419b
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/temporal_ae.py
@@ -0,0 +1,331 @@
+from typing import Callable, Iterable, Union
+
+import torch
+from einops import rearrange, repeat
+
+from sgm.modules.diffusionmodules.model import (
+ XFORMERS_IS_AVAILABLE,
+ AttnBlock,
+ Decoder,
+ MemoryEfficientAttnBlock,
+ ResnetBlock,
+)
+from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
+from sgm.modules.video_attention import VideoTransformerBlock
+from sgm.util import partialclass
+
+
+class VideoResBlock(ResnetBlock):
+ def __init__(
+ self,
+ out_channels,
+ *args,
+ dropout=0.0,
+ video_kernel_size=3,
+ alpha=0.0,
+ merge_strategy="learned",
+ **kwargs,
+ ):
+ super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
+ if video_kernel_size is None:
+ video_kernel_size = [3, 1, 1]
+ self.time_stack = ResBlock(
+ channels=out_channels,
+ emb_channels=0,
+ dropout=dropout,
+ dims=3,
+ use_scale_shift_norm=False,
+ use_conv=False,
+ up=False,
+ down=False,
+ kernel_size=video_kernel_size,
+ use_checkpoint=False,
+ skip_t_emb=True,
+ )
+
+ self.merge_strategy = merge_strategy
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned":
+ self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def get_alpha(self, bs):
+ if self.merge_strategy == "fixed":
+ return self.mix_factor
+ elif self.merge_strategy == "learned":
+ return torch.sigmoid(self.mix_factor)
+ else:
+ raise NotImplementedError()
+
+ def forward(self, x, temb, skip_video=False, timesteps=None):
+ if timesteps is None:
+ timesteps = self.timesteps
+
+ b, c, h, w = x.shape
+
+ x = super().forward(x, temb)
+
+ if not skip_video:
+ x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+
+ x = self.time_stack(x, temb)
+
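+            # blend the temporal branch with the purely spatial features using the (optionally learned) mix factor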
+ alpha = self.get_alpha(bs=b // timesteps)
+ x = alpha * x + (1.0 - alpha) * x_mix
+
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ return x
+
+
+class AE3DConv(torch.nn.Conv2d):
+ def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ if isinstance(video_kernel_size, Iterable):
+ padding = [int(k // 2) for k in video_kernel_size]
+ else:
+ padding = int(video_kernel_size // 2)
+
+ self.time_mix_conv = torch.nn.Conv3d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=video_kernel_size,
+ padding=padding,
+ )
+
+ def forward(self, input, timesteps, skip_video=False):
+ x = super().forward(input)
+ if skip_video:
+ return x
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+ x = self.time_mix_conv(x)
+ return rearrange(x, "b c t h w -> (b t) c h w")
+
+
+class VideoBlock(AttnBlock):
+ def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"):
+ super().__init__(in_channels)
+ # no context, single headed, as in base class
+ self.time_mix_block = VideoTransformerBlock(
+ dim=in_channels,
+ n_heads=1,
+ d_head=in_channels,
+ checkpoint=False,
+ ff_in=True,
+ attn_mode="softmax",
+ )
+
+ time_embed_dim = self.in_channels * 4
+ self.video_time_embed = torch.nn.Sequential(
+ torch.nn.Linear(self.in_channels, time_embed_dim),
+ torch.nn.SiLU(),
+ torch.nn.Linear(time_embed_dim, self.in_channels),
+ )
+
+ self.merge_strategy = merge_strategy
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned":
+ self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def forward(self, x, timesteps, skip_video=False):
+ if skip_video:
+ return super().forward(x)
+
+ x_in = x
+ x = self.attention(x)
+ h, w = x.shape[2:]
+ x = rearrange(x, "b c h w -> b (h w) c")
+
+ x_mix = x
+ num_frames = torch.arange(timesteps, device=x.device)
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+ num_frames = rearrange(num_frames, "b t -> (b t)")
+ t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
+ emb = self.video_time_embed(t_emb) # b, n_channels
+ emb = emb[:, None, :]
+ x_mix = x_mix + emb
+
+ alpha = self.get_alpha()
+ x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
+ x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
+
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ x = self.proj_out(x)
+
+ return x_in + x
+
+ def get_alpha(
+ self,
+ ):
+ if self.merge_strategy == "fixed":
+ return self.mix_factor
+ elif self.merge_strategy == "learned":
+ return torch.sigmoid(self.mix_factor)
+ else:
+ raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
+
+
+class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
+ def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"):
+ super().__init__(in_channels)
+ # no context, single headed, as in base class
+ self.time_mix_block = VideoTransformerBlock(
+ dim=in_channels,
+ n_heads=1,
+ d_head=in_channels,
+ checkpoint=False,
+ ff_in=True,
+ attn_mode="softmax-xformers",
+ )
+
+ time_embed_dim = self.in_channels * 4
+ self.video_time_embed = torch.nn.Sequential(
+ torch.nn.Linear(self.in_channels, time_embed_dim),
+ torch.nn.SiLU(),
+ torch.nn.Linear(time_embed_dim, self.in_channels),
+ )
+
+ self.merge_strategy = merge_strategy
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned":
+ self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def forward(self, x, timesteps, skip_time_block=False):
+ if skip_time_block:
+ return super().forward(x)
+
+ x_in = x
+ x = self.attention(x)
+ h, w = x.shape[2:]
+ x = rearrange(x, "b c h w -> b (h w) c")
+
+ x_mix = x
+ num_frames = torch.arange(timesteps, device=x.device)
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+ num_frames = rearrange(num_frames, "b t -> (b t)")
+ t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
+ emb = self.video_time_embed(t_emb) # b, n_channels
+ emb = emb[:, None, :]
+ x_mix = x_mix + emb
+
+ alpha = self.get_alpha()
+ x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
+ x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
+
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ x = self.proj_out(x)
+
+ return x_in + x
+
+ def get_alpha(
+ self,
+ ):
+ if self.merge_strategy == "fixed":
+ return self.mix_factor
+ elif self.merge_strategy == "learned":
+ return torch.sigmoid(self.mix_factor)
+ else:
+ raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
+
+
+def make_time_attn(
+ in_channels,
+ attn_type="vanilla",
+ attn_kwargs=None,
+ alpha: float = 0,
+ merge_strategy: str = "learned",
+):
+ assert attn_type in [
+ "vanilla",
+ "vanilla-xformers",
+ ], f"attn_type {attn_type} not supported for spatio-temporal attention"
+ print(f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels")
+ if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
+ print(
+ f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
+ f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
+ )
+ attn_type = "vanilla"
+
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return partialclass(VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy)
+ elif attn_type == "vanilla-xformers":
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+ return partialclass(
+ MemoryEfficientVideoBlock,
+ in_channels,
+ alpha=alpha,
+ merge_strategy=merge_strategy,
+ )
+ else:
+ raise NotImplementedError()
+
+
+class Conv2DWrapper(torch.nn.Conv2d):
+ def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
+ return super().forward(input)
+
+
+class VideoDecoder(Decoder):
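+ # time_mode selects which sub-modules are given temporal counterparts:
+ # "all" = temporal attention and temporal (3D) convolutions,
+ # "conv-only" = only the convolutions, "attn-only" = only the attention blocks.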
+ available_time_modes = ["all", "conv-only", "attn-only"]
+
+ def __init__(
+ self,
+ *args,
+ video_kernel_size: Union[int, list] = 3,
+ alpha: float = 0.0,
+ merge_strategy: str = "learned",
+ time_mode: str = "conv-only",
+ **kwargs,
+ ):
+ self.video_kernel_size = video_kernel_size
+ self.alpha = alpha
+ self.merge_strategy = merge_strategy
+ self.time_mode = time_mode
+ assert (
+ self.time_mode in self.available_time_modes
+ ), f"time_mode parameter has to be in {self.available_time_modes}"
+ super().__init__(*args, **kwargs)
+
+ def get_last_layer(self, skip_time_mix=False, **kwargs):
+ if self.time_mode == "attn-only":
+ raise NotImplementedError("TODO")
+ else:
+ return self.conv_out.time_mix_conv.weight if not skip_time_mix else self.conv_out.weight
+
+ def _make_attn(self) -> Callable:
+ if self.time_mode not in ["conv-only", "only-last-conv"]:
+ return partialclass(
+ make_time_attn,
+ alpha=self.alpha,
+ merge_strategy=self.merge_strategy,
+ )
+ else:
+ return super()._make_attn()
+
+ def _make_conv(self) -> Callable:
+ if self.time_mode != "attn-only":
+ return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
+ else:
+ return Conv2DWrapper
+
+ def _make_resblock(self) -> Callable:
+ if self.time_mode not in ["attn-only", "only-last-conv"]:
+ return partialclass(
+ VideoResBlock,
+ video_kernel_size=self.video_kernel_size,
+ alpha=self.alpha,
+ merge_strategy=self.merge_strategy,
+ )
+ else:
+ return super()._make_resblock()
diff --git a/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py b/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f9a46988703904f1e3a0b5f8f28f33cce4537bd
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py
@@ -0,0 +1,495 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+from .movq_enc_3d import CausalConv3d, Upsample3D, DownSample3D
+
+
+def cast_tuple(t, length=1):
+ return t if isinstance(t, tuple) else ((t,) * length)
+
+
+def divisible_by(num, den):
+ return (num % den) == 0
+
+
+def is_odd(n):
+ return not divisible_by(n, 2)
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ Build sinusoidal timestep embeddings, matching the implementation in
+ Denoising Diffusion Probabilistic Models (adapted from Fairseq). This matches
+ the implementation in tensor2tensor, but differs slightly from the description
+ in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+class SpatialNorm3D(nn.Module):
+ def __init__(
+ self,
+ f_channels,
+ zq_channels,
+ norm_layer=nn.GroupNorm,
+ freeze_norm_layer=False,
+ add_conv=False,
+ pad_mode="constant",
+ **norm_layer_params,
+ ):
+ super().__init__()
+ self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
+ if freeze_norm_layer:
+ for p in self.norm_layer.parameters():
+ p.requires_grad = False
+ self.add_conv = add_conv
+ if self.add_conv:
+ self.conv = CausalConv3d(zq_channels, zq_channels, kernel_size=3, pad_mode=pad_mode)
+ self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode)
+ self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode)
+
+ def forward(self, f, zq):
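+ # Resize zq to match f. When there is more than one frame, the first frame and
+ # the remaining frames are interpolated separately, matching the first-frame
+ # handling used by the other causal 3D modules.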
+ if zq.shape[2] > 1:
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
+ f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
+ zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
+ zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
+ zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
+ zq = torch.cat([zq_first, zq_rest], dim=2)
+ else:
+ zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
+ if self.add_conv:
+ zq = self.conv(zq)
+ norm_f = self.norm_layer(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
+
+def Normalize3D(in_channels, zq_ch, add_conv):
+ return SpatialNorm3D(
+ in_channels,
+ zq_ch,
+ norm_layer=nn.GroupNorm,
+ freeze_norm_layer=False,
+ add_conv=add_conv,
+ num_groups=32,
+ eps=1e-6,
+ affine=True,
+ )
+
+
+class ResnetBlock3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ zq_ch=None,
+ add_conv=False,
+ pad_mode="constant",
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize3D(out_channels, zq_ch, add_conv=add_conv)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+ else:
+ self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, temb, zq):
+ h = x
+ h = self.norm1(h, zq)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
+
+ h = self.norm2(h, zq)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class AttnBlock2D(nn.Module):
+ def __init__(self, in_channels, zq_ch=None, add_conv=False):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, zq):
+ h_ = x
+ h_ = self.norm(h_, zq)
+
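+ # fold the time axis into the batch: attention is computed independently per frame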
+ t = h_.shape[2]
+ h_ = rearrange(h_, "b c t h w -> (b t) c h w")
+
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t)
+
+ return x + h_
+
+
+class MOVQDecoder3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ zq_ch=None,
+ add_conv=False,
+ pad_mode="first",
+ temporal_compress_times=4,
+ **ignorekwargs,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # log2 of temporal_compress_times
+ self.temporal_compress_level = int(np.log2(temporal_compress_times))
+
+ if zq_ch is None:
+ zq_ch = z_channels
+
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+
+ self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ pad_mode=pad_mode,
+ )
+
+ self.mid.block_2 = ResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ pad_mode=pad_mode,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ pad_mode=pad_mode,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
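+ # only log2(temporal_compress_times) of the upsampling stages also expand
+ # the temporal axis; the remaining stages upsample spatially only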
+ if i_level < self.num_resolutions - self.temporal_compress_level:
+ up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False)
+ else:
+ up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
+ self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode)
+
+ def forward(self, z, use_cp=False):
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ t = z.shape[2]
+ # z to block_in
+
+ zq = z
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb, zq)
+ # h = self.mid.attn_1(h, zq)
+ h = self.mid.block_2(h, temb, zq)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, zq)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, zq)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h, zq)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.conv.weight
+
+
+class NewDecoder3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ zq_ch=None,
+ add_conv=False,
+ pad_mode="first",
+ temporal_compress_times=4,
+ post_quant_conv=False,
+ **ignorekwargs,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # log2 of temporal_compress_times
+ self.temporal_compress_level = int(np.log2(temporal_compress_times))
+
+ if zq_ch is None:
+ zq_ch = z_channels
+ if post_quant_conv:
+ self.post_quant_conv = CausalConv3d(zq_ch, z_channels, kernel_size=3, pad_mode=pad_mode)
+ else:
+ self.post_quant_conv = None
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ # self.conv_in = torch.nn.Conv3d(z_channels,
+ # block_in,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ pad_mode=pad_mode,
+ )
+ # remove attention block
+ # self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv)
+ self.mid.block_2 = ResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ pad_mode=pad_mode,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ pad_mode=pad_mode,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ if i_level < self.num_resolutions - self.temporal_compress_level:
+ up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False)
+ else:
+ up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
+ # self.conv_out = torch.nn.Conv3d(block_in,
+ # out_ch,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode)
+
+ def forward(self, z):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ t = z.shape[2]
+ # z to block_in
+
+ zq = z
+ if self.post_quant_conv is not None:
+ z = self.post_quant_conv(z)
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb, zq)
+ # h = self.mid.attn_1(h, zq)
+ h = self.mid.block_2(h, temb, zq)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, zq)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, zq)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h, zq)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.conv.weight
diff --git a/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py b/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b9a663b9c0bfdc364d839285c2cb314661a6c4c
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py
@@ -0,0 +1,535 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from beartype import beartype
+from beartype.typing import Union, Tuple, Optional, List
+from einops import rearrange
+
+from .movq_enc_3d import CausalConv3d, Upsample3D, DownSample3D
+
+
+def cast_tuple(t, length=1):
+ return t if isinstance(t, tuple) else ((t,) * length)
+
+
+def divisible_by(num, den):
+ return (num % den) == 0
+
+
+def is_odd(n):
+ return not divisible_by(n, 2)
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ Build sinusoidal timestep embeddings, matching the implementation in
+ Denoising Diffusion Probabilistic Models (adapted from Fairseq). This matches
+ the implementation in tensor2tensor, but differs slightly from the description
+ in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+class SpatialNorm3D(nn.Module):
+ def __init__(
+ self,
+ f_channels,
+ zq_channels,
+ norm_layer=nn.GroupNorm,
+ freeze_norm_layer=False,
+ add_conv=False,
+ pad_mode="constant",
+ **norm_layer_params,
+ ):
+ super().__init__()
+ self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
+ if freeze_norm_layer:
+ for p in self.norm_layer.parameters():
+ p.requires_grad = False
+ self.add_conv = add_conv
+ if self.add_conv:
+ # self.conv = nn.Conv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
+ self.conv = CausalConv3d(zq_channels, zq_channels, kernel_size=3, pad_mode=pad_mode)
+ # self.conv_y = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+ # self.conv_b = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+ self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode)
+ self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode)
+
+ def forward(self, f, zq):
+ if zq.shape[2] > 1:
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
+ f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
+ zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
+ zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
+ zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
+ zq = torch.cat([zq_first, zq_rest], dim=2)
+ else:
+ zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
+ if self.add_conv:
+ zq = self.conv(zq)
+ norm_f = self.norm_layer(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
+
+def Normalize3D(in_channels, zq_ch, add_conv):
+ return SpatialNorm3D(
+ in_channels,
+ zq_ch,
+ norm_layer=nn.GroupNorm,
+ freeze_norm_layer=False,
+ add_conv=add_conv,
+ num_groups=32,
+ eps=1e-6,
+ affine=True,
+ )
+
+
+class ResnetBlock3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ zq_ch=None,
+ add_conv=False,
+ pad_mode="constant",
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
+ # self.conv1 = torch.nn.Conv3d(in_channels,
+ # out_channels,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize3D(out_channels, zq_ch, add_conv=add_conv)
+ self.dropout = torch.nn.Dropout(dropout)
+ # self.conv2 = torch.nn.Conv3d(out_channels,
+ # out_channels,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ # self.conv_shortcut = torch.nn.Conv3d(in_channels,
+ # out_channels,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+ else:
+ self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+ # self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
+
+ def forward(self, x, temb, zq):
+ h = x
+ h = self.norm1(h, zq)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
+
+ h = self.norm2(h, zq)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class AttnBlock2D(nn.Module):
+ def __init__(self, in_channels, zq_ch=None, add_conv=False):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, zq):
+ h_ = x
+ h_ = self.norm(h_, zq)
+
+ t = h_.shape[2]
+ h_ = rearrange(h_, "b c t h w -> (b t) c h w")
+
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t)
+
+ return x + h_
+
+
+class MOVQDecoder3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ zq_ch=None,
+ add_conv=False,
+ pad_mode="first",
+ temporal_compress_times=4,
+ **ignorekwargs,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # log2 of temporal_compress_times
+ self.temporal_compress_level = int(np.log2(temporal_compress_times))
+
+ if zq_ch is None:
+ zq_ch = z_channels
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ # self.conv_in = torch.nn.Conv3d(z_channels,
+ # block_in,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ pad_mode=pad_mode,
+ )
+ # remove attention block
+ # self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv)
+ self.mid.block_2 = ResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ pad_mode=pad_mode,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ pad_mode=pad_mode,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ if i_level < self.num_resolutions - self.temporal_compress_level:
+ up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False)
+ else:
+ up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
+ # self.conv_out = torch.nn.Conv3d(block_in,
+ # out_ch,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode)
+
+ def forward(self, z, use_cp=False):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ t = z.shape[2]
+ # z to block_in
+
+ zq = z
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb, zq)
+ # h = self.mid.attn_1(h, zq)
+ h = self.mid.block_2(h, temb, zq)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, zq)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, zq)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h, zq)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.conv.weight
+
+
+class NewDecoder3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ zq_ch=None,
+ add_conv=False,
+ pad_mode="first",
+ temporal_compress_times=4,
+ post_quant_conv=False,
+ **ignorekwargs,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # log2 of temporal_compress_times
+ self.temporal_compress_level = int(np.log2(temporal_compress_times))
+
+ if zq_ch is None:
+ zq_ch = z_channels
+ if post_quant_conv:
+ self.post_quant_conv = CausalConv3d(zq_ch, z_channels, kernel_size=3, pad_mode=pad_mode)
+ else:
+ self.post_quant_conv = None
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ # self.conv_in = torch.nn.Conv3d(z_channels,
+ # block_in,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ pad_mode=pad_mode,
+ )
+ # remove attention block
+ # self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv)
+ self.mid.block_2 = ResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ pad_mode=pad_mode,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ pad_mode=pad_mode,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ if i_level < self.num_resolutions - self.temporal_compress_level:
+ up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False)
+ else:
+ up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
+ # self.conv_out = torch.nn.Conv3d(block_in,
+ # out_ch,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode)
+
+ def forward(self, z):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ t = z.shape[2]
+ # z to block_in
+
+ zq = z
+ if self.post_quant_conv is not None:
+ z = self.post_quant_conv(z)
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb, zq)
+ # h = self.mid.attn_1(h, zq)
+ h = self.mid.block_2(h, temb, zq)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, zq)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, zq)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h, zq)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.conv.weight
diff --git a/sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py b/sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..99b358df877c6f59f9ad62c1bb168339042f1900
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py
@@ -0,0 +1,413 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from beartype import beartype
+from beartype.typing import Union, Tuple, Optional, List
+from einops import rearrange
+
+
+def cast_tuple(t, length=1):
+ return t if isinstance(t, tuple) else ((t,) * length)
+
+
+def divisible_by(num, den):
+ return (num % den) == 0
+
+
+def is_odd(n):
+ return not divisible_by(n, 2)
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ Build sinusoidal timestep embeddings, matching the implementation in
+ Denoising Diffusion Probabilistic Models (adapted from Fairseq). This matches
+ the implementation in tensor2tensor, but differs slightly from the description
+ in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+class CausalConv3d(nn.Module):
+ @beartype
+ def __init__(
+ self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs
+ ):
+ super().__init__()
+ kernel_size = cast_tuple(kernel_size, 3)
+
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
+
+ assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
+
+ dilation = kwargs.pop("dilation", 1)
+ stride = kwargs.pop("stride", 1)
+
+ self.pad_mode = pad_mode
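+ # time is padded only on the "past" side (by time_pad frames), while height and
+ # width use symmetric padding; this keeps the convolution causal along time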
+ time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
+ height_pad = height_kernel_size // 2
+ width_pad = width_kernel_size // 2
+
+ self.height_pad = height_pad
+ self.width_pad = width_pad
+ self.time_pad = time_pad
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
+
+ stride = (stride, 1, 1)
+ dilation = (dilation, 1, 1)
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+ def forward(self, x):
+ if self.pad_mode == "constant":
+ # zero-pad height/width symmetrically and pad time only on the "past" side
+ x = F.pad(x, self.time_causal_padding, mode="constant", value=0)
+ elif self.pad_mode == "first":
+ pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2)
+ x = torch.cat([pad_x, x], dim=2)
+ causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
+ x = F.pad(x, causal_padding_2d, mode="constant", value=0)
+ elif self.pad_mode == "reflect":
+ # reflect padding
+ reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2])
+ if reflect_x.shape[2] < self.time_pad:
+ reflect_x = torch.cat(
+ [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2
+ )
+ x = torch.cat([reflect_x, x], dim=2)
+ causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
+ x = F.pad(x, causal_padding_2d, mode="constant", value=0)
+ else:
+ raise ValueError("Invalid pad mode")
+ return self.conv(x)
+
+
+def Normalize3D(in_channels): # same for 3D and 2D
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample3D(nn.Module):
+ def __init__(self, in_channels, with_conv, compress_time=False):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+ self.compress_time = compress_time
+
+ def forward(self, x):
+ if self.compress_time:
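+ # temporal + spatial upsampling: the first frame is only upsampled spatially,
+ # the remaining frames are also doubled along time (t -> 2t - 1 overall)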
+ if x.shape[2] > 1:
+ # split first frame
+ x_first, x_rest = x[:, :, 0], x[:, :, 1:]
+
+ x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
+ x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
+ x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
+ else:
+ x = x.squeeze(2)
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = x[:, :, None, :, :]
+ else:
+ # only interpolate 2D
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+
+ if self.with_conv:
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = self.conv(x)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ return x
+
+
+class DownSample3D(nn.Module):
+ def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
+ super().__init__()
+ self.with_conv = with_conv
+ if out_channels is None:
+ out_channels = in_channels
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
+ self.compress_time = compress_time
+
+ def forward(self, x):
+ if self.compress_time:
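+ # temporal pooling: keep the first frame unchanged and average-pool the
+ # remaining frames in pairs along the time axis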
+ h, w = x.shape[-2:]
+ x = rearrange(x, "b c t h w -> (b h w) c t")
+
+ # split first frame
+ x_first, x_rest = x[..., 0], x[..., 1:]
+
+ if x_rest.shape[-1] > 0:
+ x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
+ x = torch.cat([x_first[..., None], x_rest], dim=-1)
+ x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
+
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = self.conv(x)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ else:
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ return x
+
+
+class ResnetBlock3D(nn.Module):
+ def __init__(
+ self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512, pad_mode="constant"
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize3D(in_channels)
+ # self.conv1 = torch.nn.Conv3d(in_channels,
+ # out_channels,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize3D(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ # self.conv2 = torch.nn.Conv3d(out_channels,
+ # out_channels,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ # self.conv_shortcut = torch.nn.Conv3d(in_channels,
+ # out_channels,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+ else:
+ self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+ # self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class AttnBlock2D(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize3D(in_channels)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+
+ t = h_.shape[2]
+ h_ = rearrange(h_, "b c t h w -> (b t) c h w")
+
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+
+ # # original version, nan in fp16
+ # w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ # w_ = w_ * (int(c)**(-0.5))
+ # # implement c**-0.5 on q
+ q = q * (int(c) ** (-0.5))
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t)
+
+ return x + h_
+
+
+class Encoder3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ pad_mode="first",
+ temporal_compress_times=4,
+ **ignore_kwargs,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # log2 of temporal_compress_times
+ self.temporal_compress_level = int(np.log2(temporal_compress_times))
+
+ # downsampling
+ # self.conv_in = torch.nn.Conv3d(in_channels,
+ # self.ch,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv_in = CausalConv3d(in_channels, self.ch, kernel_size=3, pad_mode=pad_mode)
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ pad_mode=pad_mode,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock2D(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ if i_level < self.temporal_compress_level:
+ down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
+ else:
+ down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock3D(
+ in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
+ )
+ # remove attention block
+ # self.mid.attn_1 = AttnBlock2D(block_in)
+ self.mid.block_2 = ResnetBlock3D(
+ in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
+ )
+
+ # end
+ self.norm_out = Normalize3D(block_in)
+ # self.conv_out = torch.nn.Conv3d(block_in,
+ # 2*z_channels if double_z else z_channels,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv_out = CausalConv3d(
+ block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, pad_mode=pad_mode
+ )
+
+ def forward(self, x, use_cp=False):
+ # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ # h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
diff --git a/sat/sgm/modules/autoencoding/vqvae/movq_modules.py b/sat/sgm/modules/autoencoding/vqvae/movq_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..2773b0f2a67a6b4a68579c38962a6852e789e209
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/vqvae/movq_modules.py
@@ -0,0 +1,368 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ Build sinusoidal timestep embeddings, matching the implementation in
+ Denoising Diffusion Probabilistic Models (adapted from Fairseq). This matches
+ the implementation in tensor2tensor, but differs slightly from the description
+ in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+class SpatialNorm(nn.Module):
+ def __init__(
+ self,
+ f_channels,
+ zq_channels,
+ norm_layer=nn.GroupNorm,
+ freeze_norm_layer=False,
+ add_conv=False,
+ **norm_layer_params,
+ ):
+ super().__init__()
+ self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
+ if freeze_norm_layer:
+ for p in self.norm_layer.parameters():
+ p.requires_grad = False
+ self.add_conv = add_conv
+ if self.add_conv:
+ self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, f, zq):
+ f_size = f.shape[-2:]
+ zq = torch.nn.functional.interpolate(zq, size=f_size, mode="nearest")
+ if self.add_conv:
+ zq = self.conv(zq)
+ norm_f = self.norm_layer(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
+
+def Normalize(in_channels, zq_ch, add_conv):
+ return SpatialNorm(
+ in_channels,
+ zq_ch,
+ norm_layer=nn.GroupNorm,
+ freeze_norm_layer=False,
+ add_conv=add_conv,
+ num_groups=32,
+ eps=1e-6,
+ affine=True,
+ )
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ zq_ch=None,
+ add_conv=False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv)
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, temb, zq):
+ h = x
+ h = self.norm1(h, zq)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h, zq)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels, zq_ch=None, add_conv=False):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, zq):
+ h_ = x
+ h_ = self.norm(h_, zq)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class MOVQDecoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ zq_ch=None,
+ add_conv=False,
+ **ignorekwargs,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ )
+ self.mid.attn_1 = AttnBlock(block_in, zq_ch, add_conv=add_conv)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in, zq_ch, add_conv=add_conv))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in, zq_ch, add_conv=add_conv)
+ self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, z, zq):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb, zq)
+ h = self.mid.attn_1(h, zq)
+ h = self.mid.block_2(h, temb, zq)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, zq)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, zq)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h, zq)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def forward_with_features_output(self, z, zq):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+ output_features = {}
+
+ # z to block_in
+ h = self.conv_in(z)
+ output_features["conv_in"] = h
+
+ # middle
+ h = self.mid.block_1(h, temb, zq)
+ output_features["mid_block_1"] = h
+ h = self.mid.attn_1(h, zq)
+ output_features["mid_attn_1"] = h
+ h = self.mid.block_2(h, temb, zq)
+ output_features["mid_block_2"] = h
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, zq)
+ output_features[f"up_{i_level}_block_{i_block}"] = h
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, zq)
+ output_features[f"up_{i_level}_attn_{i_block}"] = h
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+ output_features[f"up_{i_level}_upsample"] = h
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h, zq)
+ output_features["norm_out"] = h
+ h = nonlinearity(h)
+ output_features["nonlinearity"] = h
+ h = self.conv_out(h)
+ output_features["conv_out"] = h
+
+ return h, output_features
diff --git a/sat/sgm/modules/autoencoding/vqvae/quantize.py b/sat/sgm/modules/autoencoding/vqvae/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..54ea128fc3279c04048b2c89b7d33b551ea4db58
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/vqvae/quantize.py
@@ -0,0 +1,241 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from torch import einsum
+from einops import rearrange
+
+
+class VectorQuantizer2(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+ """
+
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
+ assert return_logits == False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, "b c h w -> b h w c").contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
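+ # d has shape (B*H*W, n_e); the einsum term contributes the -2 * <z, e> cross term of the squared distance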
+ d = (
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
+ + torch.sum(self.embedding.weight**2, dim=1)
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
+ )
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
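+ # In the standard VQ-VAE objective, beta weights the commitment term ||z - sg[z_q]||^2;
+ # the legacy (default) branch applies beta to the codebook term instead (see NOTE above).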
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
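+ # straight-through estimator: the forward pass uses z_q, while gradients flow to the encoder as if quantization were the identity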
+
+ # reshape back to match original input shape
+ z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantize(nn.Module):
+ """
+ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
+ Gumbel Softmax trick quantizer
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
+ https://arxiv.org/abs/1611.01144
+ """
+
+ def __init__(
+ self,
+ num_hiddens,
+ embedding_dim,
+ n_embed,
+ straight_through=True,
+ kl_weight=5e-4,
+ temp_init=1.0,
+ use_vqinterface=True,
+ remap=None,
+ unknown_index="random",
+ ):
+ super().__init__()
+
+ self.embedding_dim = embedding_dim
+ self.n_embed = n_embed
+
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
+ self.embed = nn.Embedding(n_embed, embedding_dim)
+
+ self.use_vqinterface = use_vqinterface
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(
+ f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, return_logits=False):
+ # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
+ hard = self.straight_through if self.training else True
+ temp = self.temperature if temp is None else temp
+
+ logits = self.proj(z)
+ if self.remap is not None:
+ # continue only with used logits
+ full_zeros = torch.zeros_like(logits)
+ logits = logits[:, self.used, ...]
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+ if self.remap is not None:
+ # go back to all entries but unused set to zero
+ full_zeros[:, self.used, ...] = soft_one_hot
+ soft_one_hot = full_zeros
+ z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
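+ # equivalent to KL(q(y|x) || uniform prior): log(qy * n_embed) = log(qy) - log(1 / n_embed)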
+
+ ind = soft_one_hot.argmax(dim=1)
+ if self.remap is not None:
+ ind = self.remap_to_used(ind)
+ if self.use_vqinterface:
+ if return_logits:
+ return z_q, diff, (None, None, ind), logits
+ return z_q, diff, (None, None, ind)
+ return z_q, diff, ind
+
+ def get_codebook_entry(self, indices, shape):
+ b, h, w, c = shape
+ assert b * h * w == indices.shape[0]
+ indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
+ if self.remap is not None:
+ indices = self.unmap_to_all(indices)
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
+ z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
+ return z_q
diff --git a/sat/sgm/modules/autoencoding/vqvae/vqvae_blocks.py b/sat/sgm/modules/autoencoding/vqvae/vqvae_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42154f1c9bcacdd8b2cd80452ba9104b0352fcc
--- /dev/null
+++ b/sat/sgm/modules/autoencoding/vqvae/vqvae_blocks.py
@@ -0,0 +1,402 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ Build sinusoidal timestep embeddings, following Denoising Diffusion Probabilistic Models
+ (adapted from Fairseq). This matches the implementation in tensor2tensor, but differs
+ slightly from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+
+ # # original version, nan in fp16
+ # w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ # w_ = w_ * (int(c)**(-0.5))
+ # # implement c**-0.5 on q
+ q = q * (int(c) ** (-0.5))
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ **ignore_kwargs,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
+ )
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
+ )
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def forward_with_features_output(self, x):
+ # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+
+ # timestep embedding
+ temb = None
+ output_features = {}
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ output_features["conv_in"] = hs[-1]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ output_features["down{}_block{}".format(i_level, i_block)] = h
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ output_features["down{}_attn{}".format(i_level, i_block)] = h
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+ output_features["down{}_downsample".format(i_level)] = hs[-1]
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ output_features["mid_block_1"] = h
+ h = self.mid.attn_1(h)
+ output_features["mid_attn_1"] = h
+ h = self.mid.block_2(h, temb)
+ output_features["mid_block_2"] = h
+
+ # end
+ h = self.norm_out(h)
+ output_features["norm_out"] = h
+ h = nonlinearity(h)
+ output_features["nonlinearity"] = h
+ h = self.conv_out(h)
+ output_features["conv_out"] = h
+
+ return h, output_features
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ **ignorekwargs,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
+ )
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, z):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
diff --git a/sat/sgm/modules/cp_enc_dec.py b/sat/sgm/modules/cp_enc_dec.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a65d61dafe279ff753557130ff67225757ba222
--- /dev/null
+++ b/sat/sgm/modules/cp_enc_dec.py
@@ -0,0 +1,897 @@
+import math
+import torch
+import torch.distributed
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from beartype import beartype
+from beartype.typing import Union, Tuple, Optional, List
+from einops import rearrange
+
+from ..util import (
+ get_context_parallel_group,
+ get_context_parallel_rank,
+ get_context_parallel_world_size,
+ get_context_parallel_group_rank,
+)
+
+from ..util import SafeConv3d as Conv3d
+# Fallback if SafeConv3d is not available:
+# from torch.nn import Conv3d
+
+_USE_CP = True
+
+
+def cast_tuple(t, length=1):
+ return t if isinstance(t, tuple) else ((t,) * length)
+
+
+def divisible_by(num, den):
+ return (num % den) == 0
+
+
+def is_odd(n):
+ return not divisible_by(n, 2)
+
+
+def exists(v):
+ return v is not None
+
+
+def pair(t):
+ return t if isinstance(t, tuple) else (t, t)
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ Build sinusoidal timestep embeddings, following Denoising Diffusion Probabilistic Models
+ (adapted from Fairseq). This matches the implementation in tensor2tensor, but differs
+ slightly from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def leaky_relu(p=0.1):
+ return nn.LeakyReLU(p)
+
+
+def _split(input_, dim):
+ cp_world_size = get_context_parallel_world_size()
+
+ if cp_world_size == 1:
+ return input_
+
+ cp_rank = get_context_parallel_rank()
+
+ # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
+
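+ # keep the first (causal) frame aside and split the remaining frames evenly along `dim`;
+ # rank 0 prepends the first frame to its shard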
+ input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
+ input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
+ dim_size = input_.size()[dim] // cp_world_size
+
+ input_list = torch.split(input_, dim_size, dim=dim)
+ output = input_list[cp_rank]
+
+ if cp_rank == 0:
+ output = torch.cat([input_first_frame_, output], dim=dim)
+ output = output.contiguous()
+
+ # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
+
+ return output
+
+
+def _gather(input_, dim):
+ cp_world_size = get_context_parallel_world_size()
+
+ # Bypass the function if context parallel is 1
+ if cp_world_size == 1:
+ return input_
+
+ group = get_context_parallel_group()
+ cp_rank = get_context_parallel_rank()
+
+ # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
+
+ input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
+ if cp_rank == 0:
+ input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
+
+ tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [
+ torch.empty_like(input_) for _ in range(cp_world_size - 1)
+ ]
+
+ if cp_rank == 0:
+ input_ = torch.cat([input_first_frame_, input_], dim=dim)
+
+ tensor_list[cp_rank] = input_
+ torch.distributed.all_gather(tensor_list, input_, group=group)
+
+ output = torch.cat(tensor_list, dim=dim).contiguous()
+
+ # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape)
+
+ return output
+
+
+def _conv_split(input_, dim, kernel_size):
+ cp_world_size = get_context_parallel_world_size()
+
+ # Bypass the function if context parallel is 1
+ if cp_world_size == 1:
+ return input_
+
+ # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
+
+ cp_rank = get_context_parallel_rank()
+
+ dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
+
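+ # shards overlap by (kernel_size - 1) frames so every rank keeps the left temporal
+ # context that a causal convolution with this kernel size needs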
+ if cp_rank == 0:
+ output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
+ else:
+ output = input_.transpose(dim, 0)[cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size].transpose(
+ dim, 0
+ )
+ output = output.contiguous()
+
+ # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
+
+ return output
+
+
+def _conv_gather(input_, dim, kernel_size):
+ cp_world_size = get_context_parallel_world_size()
+
+ # Bypass the function if context parallel is 1
+ if cp_world_size == 1:
+ return input_
+
+ group = get_context_parallel_group()
+ cp_rank = get_context_parallel_rank()
+
+ # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
+
+ input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
+ if cp_rank == 0:
+ input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
+ else:
+ input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim).contiguous()
+
+ tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
+ torch.empty_like(input_) for _ in range(cp_world_size - 1)
+ ]
+ if cp_rank == 0:
+ input_ = torch.cat([input_first_kernel_, input_], dim=dim)
+
+ tensor_list[cp_rank] = input_
+ torch.distributed.all_gather(tensor_list, input_, group=group)
+
+ # Note: torch.cat already creates a contiguous tensor.
+ output = torch.cat(tensor_list, dim=dim).contiguous()
+
+ # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
+
+ return output
+
+
+def _pass_from_previous_rank(input_, dim, kernel_size):
+ # Bypass the function if kernel size is 1
+ if kernel_size == 1:
+ return input_
+
+ group = get_context_parallel_group()
+ cp_rank = get_context_parallel_rank()
+ cp_group_rank = get_context_parallel_group_rank()
+ cp_world_size = get_context_parallel_world_size()
+
+ # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
+
+ global_rank = torch.distributed.get_rank()
+ global_world_size = torch.distributed.get_world_size()
+
+ input_ = input_.transpose(0, dim)
+
+ # pass from last rank
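+ # halo exchange for the causal temporal conv: each rank sends its last (kernel_size - 1)
+ # frames to the next rank as left context; rank 0 has no predecessor and replicates its
+ # first frame instead (causal padding)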
+ send_rank = global_rank + 1
+ recv_rank = global_rank - 1
+ if send_rank % cp_world_size == 0:
+ send_rank -= cp_world_size
+ if recv_rank % cp_world_size == cp_world_size - 1:
+ recv_rank += cp_world_size
+
+ if cp_rank < cp_world_size - 1:
+ req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
+ if cp_rank > 0:
+ recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
+ req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
+
+ if cp_rank == 0:
+ input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
+ else:
+ req_recv.wait()
+ input_ = torch.cat([recv_buffer, input_], dim=0)
+
+ input_ = input_.transpose(0, dim).contiguous()
+
+ # print('out _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
+
+ return input_
+
+
+def _drop_from_previous_rank(input_, dim, kernel_size):
+ input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
+ return input_
+
+
+class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, dim, kernel_size):
+ ctx.dim = dim
+ ctx.kernel_size = kernel_size
+ return _conv_split(input_, dim, kernel_size)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
+
+
+class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, dim, kernel_size):
+ ctx.dim = dim
+ ctx.kernel_size = kernel_size
+ return _conv_gather(input_, dim, kernel_size)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
+
+
+class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, dim, kernel_size):
+ ctx.dim = dim
+ ctx.kernel_size = kernel_size
+ return _pass_from_previous_rank(input_, dim, kernel_size)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None
+
+
+def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
+ return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)
+
+
+def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
+ return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)
+
+
+def conv_pass_from_last_rank(input_, dim, kernel_size):
+ return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
+
+
+class ContextParallelCausalConv3d(nn.Module):
+ def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs):
+ super().__init__()
+ kernel_size = cast_tuple(kernel_size, 3)
+
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
+
+ assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
+
+ time_pad = time_kernel_size - 1
+ height_pad = height_kernel_size // 2
+ width_pad = width_kernel_size // 2
+
+ self.height_pad = height_pad
+ self.width_pad = width_pad
+ self.time_pad = time_pad
+ self.time_kernel_size = time_kernel_size
+ self.temporal_dim = 2
+
+ stride = (stride, stride, stride)
+ dilation = (1, 1, 1)
+ self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+ def forward(self, input_):
+ # temporal padding inside
+ if _USE_CP:
+ input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
+ else:
+ input_ = input_.transpose(0, self.temporal_dim)
+ input_parallel = torch.cat([input_[:1]] * (self.time_kernel_size - 1) + [input_], dim=0)
+ input_parallel = input_parallel.transpose(0, self.temporal_dim)
+ padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
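+ # height/width are zero-padded symmetrically here; the temporal axis was causally padded above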
+ input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
+ output_parallel = self.conv(input_parallel)
+ output = output_parallel
+ return output
+
+
+class ContextParallelGroupNorm(torch.nn.GroupNorm):
+ def forward(self, input_):
+ if _USE_CP:
+ input_ = conv_gather_from_context_parallel_region(input_, dim=2, kernel_size=1)
+ output = super().forward(input_)
+ if _USE_CP:
+ output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1)
+ return output
+
+
+def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
+ if gather:
+ return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ else:
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class SpatialNorm3D(nn.Module):
+ def __init__(
+ self,
+ f_channels,
+ zq_channels,
+ freeze_norm_layer=False,
+ add_conv=False,
+ pad_mode="constant",
+ gather=False,
+ **norm_layer_params,
+ ):
+ super().__init__()
+ if gather:
+ self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
+ else:
+ self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
+ # self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
+ if freeze_norm_layer:
+ for p in self.norm_layer.parameters():
+ p.requires_grad = False
+
+ self.add_conv = add_conv
+ if add_conv:
+ self.conv = ContextParallelCausalConv3d(
+ chan_in=zq_channels,
+ chan_out=zq_channels,
+ kernel_size=3,
+ )
+
+ self.conv_y = ContextParallelCausalConv3d(
+ chan_in=zq_channels,
+ chan_out=f_channels,
+ kernel_size=1,
+ )
+ self.conv_b = ContextParallelCausalConv3d(
+ chan_in=zq_channels,
+ chan_out=f_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, f, zq):
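+ # resize zq to match f; on rank 0 the first (causal) frame is interpolated separately
+ # so temporal interpolation does not mix it with later frames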
+ if f.shape[2] == 1 and not _USE_CP:
+ zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
+ elif get_context_parallel_rank() == 0:
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
+ f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
+ zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
+ zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
+ zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
+ zq = torch.cat([zq_first, zq_rest], dim=2)
+ else:
+ zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
+
+ if self.add_conv:
+ zq = self.conv(zq)
+
+ # f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
+ norm_f = self.norm_layer(f)
+ # norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
+
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
+
+def Normalize3D(
+ in_channels,
+ zq_ch,
+ add_conv,
+ gather=False,
+):
+ return SpatialNorm3D(
+ in_channels,
+ zq_ch,
+ gather=gather,
+ # norm_layer=nn.GroupNorm,
+ freeze_norm_layer=False,
+ add_conv=add_conv,
+ num_groups=32,
+ eps=1e-6,
+ affine=True,
+ )
+
+
+class Upsample3D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ with_conv,
+ compress_time=False,
+ ):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+ self.compress_time = compress_time
+
+ def forward(self, x):
+ if self.compress_time:
+ if x.shape[2] == 1 and not _USE_CP:
+ x = torch.nn.functional.interpolate(x[:, :, 0], scale_factor=2.0, mode="nearest")[:, :, None, :, :]
+ elif get_context_parallel_rank() == 0:
+ # split first frame
+ x_first, x_rest = x[:, :, 0], x[:, :, 1:]
+
+ x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
+ x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
+ x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
+ else:
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+
+ else:
+ # only interpolate 2D
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+
+ if self.with_conv:
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = self.conv(x)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ return x
+
+
+class DownSample3D(nn.Module):
+ def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
+ super().__init__()
+ self.with_conv = with_conv
+ if out_channels is None:
+ out_channels = in_channels
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
+ self.compress_time = compress_time
+
+ def forward(self, x):
+ if self.compress_time and x.shape[2] > 1:
+ h, w = x.shape[-2:]
+ x = rearrange(x, "b c t h w -> (b h w) c t")
+
+ if x.shape[-1] % 2 == 1:
+ # split first frame
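+ # odd temporal length: the first frame is kept as-is and the remaining frames are average-pooled pairwise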
+ x_first, x_rest = x[..., 0], x[..., 1:]
+
+ if x_rest.shape[-1] > 0:
+ x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
+ x = torch.cat([x_first[..., None], x_rest], dim=-1)
+ x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
+ else:
+ x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
+ x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
+
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = self.conv(x)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ else:
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ return x
+
+
+class ContextParallelResnetBlock3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ zq_ch=None,
+ add_conv=False,
+ gather_norm=False,
+ normalization=Normalize,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = normalization(
+ in_channels,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ gather=gather_norm,
+ )
+
+ self.conv1 = ContextParallelCausalConv3d(
+ chan_in=in_channels,
+ chan_out=out_channels,
+ kernel_size=3,
+ )
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = normalization(
+ out_channels,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ gather=gather_norm,
+ )
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = ContextParallelCausalConv3d(
+ chan_in=out_channels,
+ chan_out=out_channels,
+ kernel_size=3,
+ )
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = ContextParallelCausalConv3d(
+ chan_in=in_channels,
+ chan_out=out_channels,
+ kernel_size=3,
+ )
+ else:
+ self.nin_shortcut = Conv3d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ def forward(self, x, temb, zq=None):
+ h = x
+
+ # if isinstance(self.norm1, torch.nn.GroupNorm):
+ # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
+ if zq is not None:
+ h = self.norm1(h, zq)
+ else:
+ h = self.norm1(h)
+ # if isinstance(self.norm1, torch.nn.GroupNorm):
+ # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
+
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
+
+ # if isinstance(self.norm2, torch.nn.GroupNorm):
+ # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
+ if zq is not None:
+ h = self.norm2(h, zq)
+ else:
+ h = self.norm2(h)
+ # if isinstance(self.norm2, torch.nn.GroupNorm):
+ # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
+
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class ContextParallelEncoder3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ pad_mode="first",
+ temporal_compress_times=4,
+ gather_norm=False,
+ **ignore_kwargs,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # log2 of temporal_compress_times
+ self.temporal_compress_level = int(np.log2(temporal_compress_times))
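+ # only the first `temporal_compress_level` resolution levels also downsample along time
+ # (2x each), so the total temporal compression equals temporal_compress_times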
+
+ self.conv_in = ContextParallelCausalConv3d(
+ chan_in=in_channels,
+ chan_out=self.ch,
+ kernel_size=3,
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ContextParallelResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_out,
+ dropout=dropout,
+ temb_channels=self.temb_ch,
+ gather_norm=gather_norm,
+ )
+ )
+ block_in = block_out
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ if i_level < self.temporal_compress_level:
+ down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
+ else:
+ down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ContextParallelResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ gather_norm=gather_norm,
+ )
+
+ self.mid.block_2 = ContextParallelResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ gather_norm=gather_norm,
+ )
+
+ # end
+ self.norm_out = Normalize(block_in, gather=gather_norm)
+
+ self.conv_out = ContextParallelCausalConv3d(
+ chan_in=block_in,
+ chan_out=2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ )
+
+ def forward(self, x, use_cp=True):
+ global _USE_CP
+ _USE_CP = use_cp
+
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
+ h = self.norm_out(h)
+ # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
+
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+
+ return h
+
+
+class ContextParallelDecoder3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ zq_ch=None,
+ add_conv=False,
+ pad_mode="first",
+ temporal_compress_times=4,
+ gather_norm=False,
+ **ignorekwargs,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # log2 of temporal_compress_times
+ self.temporal_compress_level = int(np.log2(temporal_compress_times))
+
+ if zq_ch is None:
+ zq_ch = z_channels
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+
+ self.conv_in = ContextParallelCausalConv3d(
+ chan_in=z_channels,
+ chan_out=block_in,
+ kernel_size=3,
+ )
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ContextParallelResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ normalization=Normalize3D,
+ gather_norm=gather_norm,
+ )
+
+ self.mid.block_2 = ContextParallelResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ normalization=Normalize3D,
+ gather_norm=gather_norm,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ContextParallelResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ normalization=Normalize3D,
+ gather_norm=gather_norm,
+ )
+ )
+ block_in = block_out
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ if i_level < self.num_resolutions - self.temporal_compress_level:
+ up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
+ else:
+ up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
+ self.up.insert(0, up)
+
+ self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
+
+ self.conv_out = ContextParallelCausalConv3d(
+ chan_in=block_in,
+ chan_out=out_ch,
+ kernel_size=3,
+ )
+
+ def forward(self, z, use_cp=True):
+ global _USE_CP
+ _USE_CP = use_cp
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ t = z.shape[2]
+ # z to block_in
+
+ zq = z
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb, zq)
+ h = self.mid.block_2(h, temb, zq)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, zq)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, zq)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h, zq)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ _USE_CP = True
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.conv.weight
diff --git a/sat/sgm/modules/diffusionmodules/__init__.py b/sat/sgm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fccebf954f5760fa559b17755e743c41daa1a824
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/__init__.py
@@ -0,0 +1,6 @@
+from .denoiser import Denoiser
+from .discretizer import Discretization
+from .model import Decoder, Encoder, Model
+from .openaimodel import UNetModel
+from .sampling import BaseDiffusionSampler
+from .wrappers import OpenAIWrapper
diff --git a/sat/sgm/modules/diffusionmodules/denoiser.py b/sat/sgm/modules/diffusionmodules/denoiser.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cc01e36f86117183ba8f6c5ee74f4c4cd579aed
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/denoiser.py
@@ -0,0 +1,72 @@
+from typing import Dict, Union
+
+import torch
+import torch.nn as nn
+
+from ...util import append_dims, instantiate_from_config
+
+
+class Denoiser(nn.Module):
+ def __init__(self, weighting_config, scaling_config):
+ super().__init__()
+
+ self.weighting = instantiate_from_config(weighting_config)
+ self.scaling = instantiate_from_config(scaling_config)
+
+ def possibly_quantize_sigma(self, sigma):
+ return sigma
+
+ def possibly_quantize_c_noise(self, c_noise):
+ return c_noise
+
+ def w(self, sigma):
+ return self.weighting(sigma)
+
+ def forward(
+ self,
+ network: nn.Module,
+ input: torch.Tensor,
+ sigma: torch.Tensor,
+ cond: Dict,
+ **additional_model_inputs,
+ ) -> torch.Tensor:
+ sigma = self.possibly_quantize_sigma(sigma)
+ sigma_shape = sigma.shape
+ sigma = append_dims(sigma, input.ndim)
+ c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs)
+ c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
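+ # EDM-style preconditioning: D(x) = c_skip * x + c_out * F(c_in * x, c_noise)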
+ return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip
+
+
+class DiscreteDenoiser(Denoiser):
+ def __init__(
+ self,
+ weighting_config,
+ scaling_config,
+ num_idx,
+ discretization_config,
+ do_append_zero=False,
+ quantize_c_noise=True,
+ flip=True,
+ ):
+ super().__init__(weighting_config, scaling_config)
+ sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
+ self.sigmas = sigmas
+ # self.register_buffer("sigmas", sigmas)
+ self.quantize_c_noise = quantize_c_noise
+
+ def sigma_to_idx(self, sigma):
+ dists = sigma - self.sigmas.to(sigma.device)[:, None]
+ return dists.abs().argmin(dim=0).view(sigma.shape)
+
+ def idx_to_sigma(self, idx):
+ return self.sigmas.to(idx.device)[idx]
+
+ def possibly_quantize_sigma(self, sigma):
+ return self.idx_to_sigma(self.sigma_to_idx(sigma))
+
+ def possibly_quantize_c_noise(self, c_noise):
+ if self.quantize_c_noise:
+ return self.sigma_to_idx(c_noise)
+ else:
+ return c_noise
diff --git a/sat/sgm/modules/diffusionmodules/denoiser_scaling.py b/sat/sgm/modules/diffusionmodules/denoiser_scaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cb9643014435d908dbbd30c30c02c632373846b
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/denoiser_scaling.py
@@ -0,0 +1,60 @@
+from abc import ABC, abstractmethod
+from typing import Any, Tuple
+
+import torch
+
+
+class DenoiserScaling(ABC):
+ @abstractmethod
+ def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ pass
+
+
+class EDMScaling:
+ def __init__(self, sigma_data: float = 0.5):
+ self.sigma_data = sigma_data
+
+ def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
+ c_noise = 0.25 * sigma.log()
+ return c_skip, c_out, c_in, c_noise
+
+
+class EpsScaling:
+ def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = torch.ones_like(sigma, device=sigma.device)
+ c_out = -sigma
+ c_in = 1 / (sigma**2 + 1.0) ** 0.5
+ c_noise = sigma.clone()
+ return c_skip, c_out, c_in, c_noise
+
+
+class VScaling:
+ def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = 1.0 / (sigma**2 + 1.0)
+ c_out = -sigma / (sigma**2 + 1.0) ** 0.5
+ c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
+ c_noise = sigma.clone()
+ return c_skip, c_out, c_in, c_noise
+
+
+class VScalingWithEDMcNoise(DenoiserScaling):
+ def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = 1.0 / (sigma**2 + 1.0)
+ c_out = -sigma / (sigma**2 + 1.0) ** 0.5
+ c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
+ c_noise = 0.25 * sigma.log()
+ return c_skip, c_out, c_in, c_noise
+
+
+class VideoScaling: # similar to VScaling
+ def __call__(
+ self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = alphas_cumprod_sqrt
+ c_out = -((1 - alphas_cumprod_sqrt**2) ** 0.5)
+ c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device)
+ c_noise = additional_model_inputs["idx"].clone()
+ return c_skip, c_out, c_in, c_noise
diff --git a/sat/sgm/modules/diffusionmodules/denoiser_weighting.py b/sat/sgm/modules/diffusionmodules/denoiser_weighting.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8b03ca58f17ea3d7374f4bbb7bf1d2994755e00
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/denoiser_weighting.py
@@ -0,0 +1,24 @@
+import torch
+
+
+class UnitWeighting:
+ def __call__(self, sigma):
+ return torch.ones_like(sigma, device=sigma.device)
+
+
+class EDMWeighting:
+ def __init__(self, sigma_data=0.5):
+ self.sigma_data = sigma_data
+
+ def __call__(self, sigma):
+ return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
+
+
+class VWeighting(EDMWeighting):
+ def __init__(self):
+ super().__init__(sigma_data=1.0)
+
+
+class EpsWeighting:
+ def __call__(self, sigma):
+ return sigma**-2.0
diff --git a/sat/sgm/modules/diffusionmodules/discretizer.py b/sat/sgm/modules/diffusionmodules/discretizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a86b7d8dfb06aafef388ba28b369c611148ca300
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/discretizer.py
@@ -0,0 +1,126 @@
+from abc import abstractmethod
+from functools import partial
+
+import numpy as np
+import torch
+
+from ...modules.diffusionmodules.util import make_beta_schedule
+from ...util import append_zero
+
+
+def generate_roughly_equally_spaced_steps(num_substeps: int, max_step: int) -> np.ndarray:
+ return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
+
+
+class Discretization:
+ def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False):
+ if return_idx:
+ sigmas, idx = self.get_sigmas(n, device=device, return_idx=return_idx)
+ else:
+ sigmas = self.get_sigmas(n, device=device, return_idx=return_idx)
+ sigmas = append_zero(sigmas) if do_append_zero else sigmas
+ if return_idx:
+ return sigmas if not flip else torch.flip(sigmas, (0,)), idx
+ else:
+ return sigmas if not flip else torch.flip(sigmas, (0,))
+
+ @abstractmethod
+ def get_sigmas(self, n, device):
+ pass
+
+
+class EDMDiscretization(Discretization):
+ def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
+ self.sigma_min = sigma_min
+ self.sigma_max = sigma_max
+ self.rho = rho
+
+ def get_sigmas(self, n, device="cpu"):
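+ # Karras et al. (EDM) noise schedule: interpolate linearly in sigma^(1/rho), then raise to the power rho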
+ ramp = torch.linspace(0, 1, n, device=device)
+ min_inv_rho = self.sigma_min ** (1 / self.rho)
+ max_inv_rho = self.sigma_max ** (1 / self.rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
+ return sigmas
+
+
+class LegacyDDPMDiscretization(Discretization):
+ def __init__(
+ self,
+ linear_start=0.00085,
+ linear_end=0.0120,
+ num_timesteps=1000,
+ ):
+ super().__init__()
+ self.num_timesteps = num_timesteps
+ betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ def get_sigmas(self, n, device="cpu"):
+ if n < self.num_timesteps:
+ timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
+ alphas_cumprod = self.alphas_cumprod[timesteps]
+ elif n == self.num_timesteps:
+ alphas_cumprod = self.alphas_cumprod
+ else:
+ raise ValueError
+
+ to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
+ sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ return torch.flip(sigmas, (0,)) # sigma_t: 14.4 -> 0.029
+
+
+class ZeroSNRDDPMDiscretization(Discretization):
+ def __init__(
+ self,
+ linear_start=0.00085,
+ linear_end=0.0120,
+ num_timesteps=1000,
+ shift_scale=1.0, # noise schedule t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale)
+ keep_start=False,
+ post_shift=False,
+ ):
+ super().__init__()
+ if keep_start and not post_shift:
+ linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start)
+ self.num_timesteps = num_timesteps
+ betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ # SNR shift
+ if not post_shift:
+ self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod)
+
+ self.post_shift = post_shift
+ self.shift_scale = shift_scale
+
+ def get_sigmas(self, n, device="cpu", return_idx=False):
+ if n < self.num_timesteps:
+ timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
+ alphas_cumprod = self.alphas_cumprod[timesteps]
+ elif n == self.num_timesteps:
+ alphas_cumprod = self.alphas_cumprod
+ else:
+ raise ValueError
+
+ to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
+ alphas_cumprod = to_torch(alphas_cumprod)
+ alphas_cumprod_sqrt = alphas_cumprod.sqrt()
+ alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone()
+ alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()
+
+ alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
+ alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T)
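+ # rescale so the terminal step has exactly zero SNR (sqrt(alpha_T) -> 0) while sqrt(alpha_0) is unchanged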
+
+ if self.post_shift:
+ alphas_cumprod_sqrt = (
+ alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2)
+ ) ** 0.5
+
+ if return_idx:
+ return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps
+ else:
+ return torch.flip(alphas_cumprod_sqrt, (0,)) # sqrt(alpha_t): 0 -> 0.99
diff --git a/sat/sgm/modules/diffusionmodules/guiders.py b/sat/sgm/modules/diffusionmodules/guiders.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ce657c39258681fdadbec874ae2a7b6e26b8294
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/guiders.py
@@ -0,0 +1,87 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Tuple, Union
+from functools import partial
+import math
+
+import torch
+from einops import rearrange, repeat
+
+from ...util import append_dims, default, instantiate_from_config
+
+
+class Guider(ABC):
+ @abstractmethod
+ def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
+ pass
+
+ def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, uc: Dict) -> Tuple[torch.Tensor, float, Dict]:
+ pass
+
+
+class VanillaCFG:
+ """
+ Implements classifier-free guidance (CFG); the conditional and unconditional branches are run as a single batched (parallelized) forward pass.
+ """
+
+ def __init__(self, scale, dyn_thresh_config=None):
+ self.scale = scale
+ scale_schedule = lambda scale, sigma: scale # independent of step
+ self.scale_schedule = partial(scale_schedule, scale)
+ self.dyn_thresh = instantiate_from_config(
+ default(
+ dyn_thresh_config,
+ {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"},
+ )
+ )
+
+ def __call__(self, x, sigma, scale=None):
+ x_u, x_c = x.chunk(2)
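+ # the CFG combination itself (typically x_u + scale * (x_c - x_u)) is delegated to the dyn_thresh callable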
+ scale_value = default(scale, self.scale_schedule(sigma))
+ x_pred = self.dyn_thresh(x_u, x_c, scale_value)
+ return x_pred
+
+ def prepare_inputs(self, x, s, c, uc):
+ c_out = dict()
+
+ for k in c:
+ if k in ["vector", "crossattn", "concat"]:
+ c_out[k] = torch.cat((uc[k], c[k]), 0)
+ else:
+ assert c[k] == uc[k]
+ c_out[k] = c[k]
+ return torch.cat([x] * 2), torch.cat([s] * 2), c_out
+
+
+class DynamicCFG(VanillaCFG):
+ def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
+ super().__init__(scale, dyn_thresh_config)
+ scale_schedule = (
+ lambda scale, sigma, step_index: 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2
+ )
+ self.scale_schedule = partial(scale_schedule, scale)
+ self.dyn_thresh = instantiate_from_config(
+ default(
+ dyn_thresh_config,
+ {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"},
+ )
+ )
+
+ def __call__(self, x, sigma, step_index, scale=None):
+ x_u, x_c = x.chunk(2)
+ scale_value = self.scale_schedule(sigma, step_index.item())
+ x_pred = self.dyn_thresh(x_u, x_c, scale_value)
+ return x_pred
+
+
+class IdentityGuider:
+ def __call__(self, x, sigma):
+ return x
+
+ def prepare_inputs(self, x, s, c, uc):
+ c_out = dict()
+
+ for k in c:
+ c_out[k] = c[k]
+
+ return x, s, c_out
diff --git a/sat/sgm/modules/diffusionmodules/lora.py b/sat/sgm/modules/diffusionmodules/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccd72a19f615162832d7c5d1c215dd7921833c7
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/lora.py
@@ -0,0 +1,362 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class LoRALinearLayer(nn.Module):
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
+ super().__init__()
+
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+ self.network_alpha = network_alpha
+ self.rank = rank
+ self.out_features = out_features
+ self.in_features = in_features
+
+ nn.init.normal_(self.down.weight, std=1 / rank)
+ nn.init.zeros_(self.up.weight)
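+        # `up` is zero-initialised, so the LoRA branch is a no-op at the start of training;
+        # the learned update is effectively up @ down, rescaled by network_alpha / rank when set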
+
+ def forward(self, hidden_states):
+ orig_dtype = hidden_states.dtype
+ dtype = self.down.weight.dtype
+
+ down_hidden_states = self.down(hidden_states.to(dtype))
+ up_hidden_states = self.up(down_hidden_states)
+
+ if self.network_alpha is not None:
+ up_hidden_states *= self.network_alpha / self.rank
+
+ return up_hidden_states.to(orig_dtype)
+
+
+class LoRAConv2dLayer(nn.Module):
+ def __init__(
+ self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
+ ):
+ super().__init__()
+
+ self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
+ # according to the official kohya_ss trainer kernel_size are always fixed for the up layer
+ # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
+ self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
+
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+ self.network_alpha = network_alpha
+ self.rank = rank
+
+ nn.init.normal_(self.down.weight, std=1 / rank)
+ nn.init.zeros_(self.up.weight)
+
+ def forward(self, hidden_states):
+ orig_dtype = hidden_states.dtype
+ dtype = self.down.weight.dtype
+
+ down_hidden_states = self.down(hidden_states.to(dtype))
+ up_hidden_states = self.up(down_hidden_states)
+
+ if self.network_alpha is not None:
+ up_hidden_states *= self.network_alpha / self.rank
+
+ return up_hidden_states.to(orig_dtype)
+
+
+class LoRACompatibleConv(nn.Conv2d):
+ """
+ A convolutional layer that can be used with LoRA.
+ """
+
+ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.lora_layer = lora_layer
+ self.scale = scale
+
+ def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
+ self.lora_layer = lora_layer
+
+ def _fuse_lora(self, lora_scale=1.0):
+ if self.lora_layer is None:
+ return
+
+ dtype, device = self.weight.data.dtype, self.weight.data.device
+
+ w_orig = self.weight.data.float()
+ w_up = self.lora_layer.up.weight.data.float()
+ w_down = self.lora_layer.down.weight.data.float()
+
+ if self.lora_layer.network_alpha is not None:
+ w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
+
+ fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
+ fusion = fusion.reshape((w_orig.shape))
+ fused_weight = w_orig + (lora_scale * fusion)
+ self.weight.data = fused_weight.to(device=device, dtype=dtype)
+
+ # we can drop the lora layer now
+ self.lora_layer = None
+
+ # offload the up and down matrices to CPU to not blow the memory
+ self.w_up = w_up.cpu()
+ self.w_down = w_down.cpu()
+ self._lora_scale = lora_scale
+
+ def _unfuse_lora(self):
+ if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+ return
+
+ fused_weight = self.weight.data
+ dtype, device = fused_weight.data.dtype, fused_weight.data.device
+
+ self.w_up = self.w_up.to(device=device).float()
+ self.w_down = self.w_down.to(device).float()
+
+ fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
+ fusion = fusion.reshape((fused_weight.shape))
+ unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
+ self.weight.data = unfused_weight.to(device=device, dtype=dtype)
+
+ self.w_up = None
+ self.w_down = None
+
+ def forward(self, hidden_states, scale: float = None):
+ if scale is None:
+ scale = self.scale
+ if self.lora_layer is None:
+ # make sure to the functional Conv2D function as otherwise torch.compile's graph will break
+ # see: https://github.com/huggingface/diffusers/pull/4315
+ return F.conv2d(
+ hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+ )
+ else:
+ return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
+
+
+class LoRACompatibleLinear(nn.Linear):
+ """
+ A Linear layer that can be used with LoRA.
+ """
+
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.lora_layer = lora_layer
+ self.scale = scale
+
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
+ self.lora_layer = lora_layer
+
+ def _fuse_lora(self, lora_scale=1.0):
+ if self.lora_layer is None:
+ return
+
+ dtype, device = self.weight.data.dtype, self.weight.data.device
+
+ w_orig = self.weight.data.float()
+ w_up = self.lora_layer.up.weight.data.float()
+ w_down = self.lora_layer.down.weight.data.float()
+
+ if self.lora_layer.network_alpha is not None:
+ w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
+
+ fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+ self.weight.data = fused_weight.to(device=device, dtype=dtype)
+
+ # we can drop the lora layer now
+ self.lora_layer = None
+
+ # offload the up and down matrices to CPU to not blow the memory
+ self.w_up = w_up.cpu()
+ self.w_down = w_down.cpu()
+ self._lora_scale = lora_scale
+
+ def _unfuse_lora(self):
+ if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+ return
+
+ fused_weight = self.weight.data
+ dtype, device = fused_weight.dtype, fused_weight.device
+
+ w_up = self.w_up.to(device=device).float()
+ w_down = self.w_down.to(device).float()
+
+ unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+ self.weight.data = unfused_weight.to(device=device, dtype=dtype)
+
+ self.w_up = None
+ self.w_down = None
+
+ def forward(self, hidden_states, scale: float = None):
+ if scale is None:
+ scale = self.scale
+ if self.lora_layer is None:
+ out = super().forward(hidden_states)
+ return out
+ else:
+ out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
+ return out
+
+
+def _find_children(
+ model,
+ search_class: List[Type[nn.Module]] = [nn.Linear],
+):
+ """
+ Find all modules of a certain class (or union of classes).
+
+    Returns all matching modules, along with the parent of those modules and the
+ names they are referenced by.
+ """
+ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
+ for parent in model.modules():
+ for name, module in parent.named_children():
+ if any([isinstance(module, _class) for _class in search_class]):
+ yield parent, name, module
+
+
+def _find_modules_v2(
+ model,
+ ancestor_class: Optional[Set[str]] = None,
+ search_class: List[Type[nn.Module]] = [nn.Linear],
+ exclude_children_of: Optional[List[Type[nn.Module]]] = [
+ LoRACompatibleLinear,
+ LoRACompatibleConv,
+ LoRALinearLayer,
+ LoRAConv2dLayer,
+ ],
+):
+ """
+ Find all modules of a certain class (or union of classes) that are direct or
+ indirect descendants of other modules of a certain class (or union of classes).
+
+    Returns all matching modules, along with the parent of those modules and the
+ names they are referenced by.
+ """
+
+ # Get the targets we should replace all linears under
+ if ancestor_class is not None:
+ ancestors = (module for module in model.modules() if module.__class__.__name__ in ancestor_class)
+ else:
+        # this, in case you want to naively iterate over all modules.
+ ancestors = [module for module in model.modules()]
+
+ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
+ for ancestor in ancestors:
+ for fullname, module in ancestor.named_modules():
+ if any([isinstance(module, _class) for _class in search_class]):
+ # Find the direct parent if this is a descendant, not a child, of target
+ *path, name = fullname.split(".")
+ parent = ancestor
+ flag = False
+ while path:
+ try:
+ parent = parent.get_submodule(path.pop(0))
+                    except AttributeError:  # submodule path does not resolve under this ancestor
+ flag = True
+ break
+ if flag:
+ continue
+ # Skip this linear if it's a child of a LoraInjectedLinear
+ if exclude_children_of and any([isinstance(parent, _class) for _class in exclude_children_of]):
+ continue
+ # Otherwise, yield it
+ yield parent, name, module
+
+
+_find_modules = _find_modules_v2
+
+
+def inject_trainable_lora_extended(
+ model: nn.Module,
+ target_replace_module: Set[str] = None,
+ rank: int = 4,
+ scale: float = 1.0,
+):
+ for _module, name, _child_module in _find_modules(
+ model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
+ ):
+ if _child_module.__class__ == nn.Linear:
+ weight = _child_module.weight
+ bias = _child_module.bias
+ lora_layer = LoRALinearLayer(
+ in_features=_child_module.in_features,
+ out_features=_child_module.out_features,
+ rank=rank,
+ )
+ _tmp = (
+ LoRACompatibleLinear(
+ _child_module.in_features,
+ _child_module.out_features,
+ lora_layer=lora_layer,
+ scale=scale,
+ )
+ .to(weight.dtype)
+ .to(weight.device)
+ )
+ _tmp.weight = weight
+ if bias is not None:
+ _tmp.bias = bias
+ elif _child_module.__class__ == nn.Conv2d:
+ weight = _child_module.weight
+ bias = _child_module.bias
+ lora_layer = LoRAConv2dLayer(
+ in_features=_child_module.in_channels,
+ out_features=_child_module.out_channels,
+ rank=rank,
+ kernel_size=_child_module.kernel_size,
+ stride=_child_module.stride,
+ padding=_child_module.padding,
+ )
+ _tmp = (
+ LoRACompatibleConv(
+ _child_module.in_channels,
+ _child_module.out_channels,
+ kernel_size=_child_module.kernel_size,
+ stride=_child_module.stride,
+ padding=_child_module.padding,
+ lora_layer=lora_layer,
+ scale=scale,
+ )
+ .to(weight.dtype)
+ .to(weight.device)
+ )
+ _tmp.weight = weight
+ if bias is not None:
+ _tmp.bias = bias
+ else:
+ continue
+
+ _module._modules[name] = _tmp
+ # print('injecting lora layer to', _module, name)
+
+ return
+
+
+def update_lora_scale(
+ model: nn.Module,
+ target_module: Set[str] = None,
+ scale: float = 1.0,
+):
+ for _module, name, _child_module in _find_modules(
+ model, target_module, search_class=[LoRACompatibleLinear, LoRACompatibleConv]
+ ):
+ _child_module.scale = scale
+
+ return
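+
+
+# Illustrative usage sketch (the target module name below is a placeholder, not defined here):
+#
+#     inject_trainable_lora_extended(model, target_replace_module={"BasicTransformerBlock"}, rank=4, scale=1.0)
+#     update_lora_scale(model, scale=0.8)  # adjust the LoRA contribution at inference time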
diff --git a/sat/sgm/modules/diffusionmodules/loss.py b/sat/sgm/modules/diffusionmodules/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc0a304f881be58e3573fa19b2399dfbe767fa04
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/loss.py
@@ -0,0 +1,132 @@
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from omegaconf import ListConfig
+import math
+
+from ...modules.diffusionmodules.sampling import VideoDDIMSampler, VPSDEDPMPP2MSampler
+from ...util import append_dims, instantiate_from_config
+from ...modules.autoencoding.lpips.loss.lpips import LPIPS
+
+# import rearrange
+from einops import rearrange
+import random
+from sat import mpu
+
+
+class StandardDiffusionLoss(nn.Module):
+ def __init__(
+ self,
+ sigma_sampler_config,
+ type="l2",
+ offset_noise_level=0.0,
+ batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
+ ):
+ super().__init__()
+
+ assert type in ["l2", "l1", "lpips"]
+
+ self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
+
+ self.type = type
+ self.offset_noise_level = offset_noise_level
+
+ if type == "lpips":
+ self.lpips = LPIPS().eval()
+
+ if not batch2model_keys:
+ batch2model_keys = []
+
+ if isinstance(batch2model_keys, str):
+ batch2model_keys = [batch2model_keys]
+
+ self.batch2model_keys = set(batch2model_keys)
+
+ def __call__(self, network, denoiser, conditioner, input, batch):
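+        # training step: sample a noise level, perturb the input, denoise it, and weight the
+        # reconstruction error by denoiser.w(sigma)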
+ cond = conditioner(batch)
+ additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}
+
+ sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
+ noise = torch.randn_like(input)
+ if self.offset_noise_level > 0.0:
+ noise = (
+ noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
+ )
+ noise = noise.to(input.dtype)
+ noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
+ model_output = denoiser(network, noised_input, sigmas, cond, **additional_model_inputs)
+ w = append_dims(denoiser.w(sigmas), input.ndim)
+ return self.get_loss(model_output, input, w)
+
+ def get_loss(self, model_output, target, w):
+ if self.type == "l2":
+ return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)
+ elif self.type == "l1":
+ return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)
+ elif self.type == "lpips":
+ loss = self.lpips(model_output, target).reshape(-1)
+ return loss
+
+
+class VideoDiffusionLoss(StandardDiffusionLoss):
+ def __init__(self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs):
+ self.fixed_frames = fixed_frames
+ self.block_scale = block_scale
+ self.block_size = block_size
+ self.min_snr_value = min_snr_value
+ super().__init__(**kwargs)
+
+ def __call__(self, network, denoiser, conditioner, input, batch):
+ cond = conditioner(batch)
+ additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}
+
+ alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)
+ alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
+ idx = idx.to(input.device)
+
+ noise = torch.randn_like(input)
+
+        # broadcast the noise, timestep indices and noise levels from the first rank of each
+        # model-parallel group so that every shard trains on identical values
+ mp_size = mpu.get_model_parallel_world_size()
+ global_rank = torch.distributed.get_rank() // mp_size
+ src = global_rank * mp_size
+ torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group())
+ torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group())
+ torch.distributed.broadcast(alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group())
+
+ additional_model_inputs["idx"] = idx
+
+ if self.offset_noise_level > 0.0:
+ noise = (
+ noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
+ )
+
+ noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims(
+ (1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim
+ )
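+        # i.e. the variance-preserving forward process x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps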
+
+ model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)
+ w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred
+
+ if self.min_snr_value is not None:
+            # clamp the per-element weight at min_snr_value (the builtin min() cannot compare a tensor to a scalar)
+            w = torch.clamp(w, max=self.min_snr_value)
+ return self.get_loss(model_output, input, w)
+
+ def get_loss(self, model_output, target, w):
+ if self.type == "l2":
+ return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)
+ elif self.type == "l1":
+ return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)
+ elif self.type == "lpips":
+ loss = self.lpips(model_output, target).reshape(-1)
+ return loss
+
+
+def get_3d_position_ids(frame_len, h, w):
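+    # returns a (frame_len * h * w, 3) tensor of (t, y, x) indices, one row per voxel in raster order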
+ i = torch.arange(frame_len).view(frame_len, 1, 1).expand(frame_len, h, w)
+ j = torch.arange(h).view(1, h, 1).expand(frame_len, h, w)
+ k = torch.arange(w).view(1, 1, w).expand(frame_len, h, w)
+ position_ids = torch.stack([i, j, k], dim=-1).reshape(-1, 3)
+ return position_ids
diff --git a/sat/sgm/modules/diffusionmodules/model.py b/sat/sgm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..466f01ac967bcc6d240d1eda06b308b46ac07bce
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/model.py
@@ -0,0 +1,683 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+from typing import Any, Callable, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from packaging import version
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILABLE = True
+except:
+ XFORMERS_IS_AVAILABLE = False
+ print("no module 'xformers'. Processing without...")
+
+from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ b, c, h, w = q.shape
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
+        # compute attention; the scale defaults to dim ** -0.5
+        h_ = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+
+ def forward(self, x, **kwargs):
+ h_ = x
+ h_ = self.attention(h_)
+ h_ = self.proj_out(h_)
+ return x + h_
+
+
+class MemoryEfficientAttnBlock(nn.Module):
+ """
+ Uses xformers efficient implementation,
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ Note: this is a single-head self-attention operation
+ """
+
+ #
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.attention_op: Optional[Any] = None
+
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ B, C, H, W = q.shape
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
+
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(B, t.shape[1], 1, C)
+ .permute(0, 2, 1, 3)
+ .reshape(B * 1, t.shape[1], C)
+ .contiguous(),
+ (q, k, v),
+ )
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C)
+ return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
+
+ def forward(self, x, **kwargs):
+ h_ = x
+ h_ = self.attention(h_)
+ h_ = self.proj_out(h_)
+ return x + h_
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+    def forward(self, x, context=None, mask=None, **unused_kwargs):
+        b, c, h, w = x.shape
+        x_in = x
+        x = rearrange(x, "b c h w -> b (h w) c")
+        out = super().forward(x, context=context, mask=mask)
+        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
+        # add the residual against the original (b, c, h, w) input rather than the flattened tensor
+        return x_in + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+ assert attn_type in [
+ "vanilla",
+ "vanilla-xformers",
+ "memory-efficient-cross-attn",
+ "linear",
+ "none",
+ ], f"attn_type {attn_type} unknown"
+ if version.parse(torch.__version__) < version.parse("2.0.0") and attn_type != "none":
+ assert XFORMERS_IS_AVAILABLE, (
+ f"We do not support vanilla attention in {torch.__version__} anymore, "
+ f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
+ )
+ attn_type = "vanilla-xformers"
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return AttnBlock(in_channels)
+ elif attn_type == "vanilla-xformers":
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+ return MemoryEfficientAttnBlock(in_channels)
+ elif type == "memory-efficient-cross-attn":
+ attn_kwargs["query_dim"] = in_channels
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ use_timestep=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch * 4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList(
+ [
+ torch.nn.Linear(self.ch, self.temb_ch),
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
+ ]
+ )
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ skip_in = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch * in_ch_mult[i_level]
+ block.append(
+ ResnetBlock(
+ in_channels=block_in + skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x, t=None, context=None):
+ # assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignore_kwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ tanh_out=False,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignorekwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+
+ make_attn_cls = self._make_attn()
+ make_resblock_cls = self._make_resblock()
+ make_conv_cls = self._make_conv()
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
+ self.mid.block_2 = make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn_cls(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = make_conv_cls(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def _make_attn(self) -> Callable:
+ return make_attn
+
+ def _make_resblock(self) -> Callable:
+ return ResnetBlock
+
+ def _make_conv(self) -> Callable:
+ return torch.nn.Conv2d
+
+ def get_last_layer(self, **kwargs):
+ return self.conv_out.weight
+
+ def forward(self, z, **kwargs):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb, **kwargs)
+ h = self.mid.attn_1(h, **kwargs)
+ h = self.mid.block_2(h, temb, **kwargs)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, **kwargs)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h, **kwargs)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
diff --git a/sat/sgm/modules/diffusionmodules/openaimodel.py b/sat/sgm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb5be798afc4fd76a9cf54025e955bda6decdeca
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,1249 @@
+import os
+import math
+from abc import abstractmethod
+from functools import partial
+from typing import Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from ...modules.attention import SpatialTransformer
+from ...modules.diffusionmodules.util import (
+ avg_pool_nd,
+ checkpoint,
+ conv_nd,
+ linear,
+ normalization,
+ timestep_embedding,
+ zero_module,
+)
+from ...modules.diffusionmodules.lora import inject_trainable_lora_extended, update_lora_scale
+from ...modules.video_attention import SpatialVideoTransformer
+from ...util import default, exists
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(
+ self,
+ x: th.Tensor,
+ emb: th.Tensor,
+ context: Optional[th.Tensor] = None,
+ image_only_indicator: Optional[th.Tensor] = None,
+ time_context: Optional[int] = None,
+ num_video_frames: Optional[int] = None,
+ ):
+ from ...modules.diffusionmodules.video_model import VideoResBlock
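+        # imported locally, presumably to avoid a circular import with video_model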
+
+ for layer in self:
+ module = layer
+
+ if isinstance(module, TimestepBlock) and not isinstance(module, VideoResBlock):
+ x = layer(x, emb)
+ elif isinstance(module, VideoResBlock):
+ x = layer(x, emb, num_video_frames, image_only_indicator)
+ elif isinstance(module, SpatialVideoTransformer):
+ x = layer(
+ x,
+ context,
+ time_context,
+ num_video_frames,
+ image_only_indicator,
+ )
+ elif isinstance(module, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ self.third_up = third_up
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ t_factor = 1 if not self.third_up else 2
+ x = F.interpolate(
+ x,
+ (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
+ mode="nearest",
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class TransposedUpsample(nn.Module):
+ "Learned 2x upsampling without padding"
+
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)
+
+ def forward(self, x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
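+        # for 3D inputs, stride (1, 2, 2) downsamples only the spatial axes; third_down=True also halves time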
+ if use_conv:
+ print(f"Building a Downsample layer with {dims} dims.")
+ print(
+ f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
+ f"kernel-size: 3, stride: {stride}, padding: {padding}"
+ )
+ if dims == 3:
+ print(f" --> Downsampling third axis (time): {third_down}")
+ self.op = conv_nd(
+ dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding,
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ kernel_size=3,
+ exchange_temb_dims=False,
+ skip_t_emb=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.exchange_temb_dims = exchange_temb_dims
+
+ if isinstance(kernel_size, Iterable):
+ padding = [k // 2 for k in kernel_size]
+ else:
+ padding = kernel_size // 2
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.skip_t_emb = skip_t_emb
+ self.emb_out_channels = 2 * self.out_channels if use_scale_shift_norm else self.out_channels
+ if self.skip_t_emb:
+ print(f"Skipping timestep embedding in {self.__class__.__name__}")
+ assert not self.use_scale_shift_norm
+ self.emb_layers = None
+ self.exchange_temb_dims = False
+ else:
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ self.emb_out_channels,
+ ),
+ )
+
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(
+ dims,
+ self.out_channels,
+ self.out_channels,
+ kernel_size,
+ padding=padding,
+ )
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding)
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+
+ if self.skip_t_emb:
+ emb_out = th.zeros_like(h)
+ else:
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ if self.exchange_temb_dims:
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x, **kwargs):
+ # TODO add crossframe attention and use mixed checkpoint
+ return checkpoint(
+ self._forward, (x,), self.parameters(), True
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ # return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial**2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class Timestep(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, t):
+ return timestep_embedding(t, self.dim)
+
+
+str_to_dtype = {"fp32": th.float32, "fp16": th.float16, "bf16": th.bfloat16}
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ spatial_transformer_attn_type="softmax",
+ adm_in_channels=None,
+ use_fairscale_checkpoint=False,
+ offload_to_cpu=False,
+ transformer_depth_middle=None,
+ dtype="fp32",
+ lora_init=False,
+ lora_rank=4,
+ lora_scale=1.0,
+ lora_weight_path=None,
+ ):
+ super().__init__()
+ from omegaconf.listconfig import ListConfig
+
+ self.dtype = str_to_dtype[dtype]
+
+ if use_spatial_transformer:
+ assert (
+ context_dim is not None
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
+
+ if context_dim is not None:
+ assert (
+ use_spatial_transformer
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set"
+
+ if num_head_channels == -1:
+ assert num_heads != -1, "Either num_heads or num_head_channels has to be set"
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(transformer_depth, int):
+ transformer_depth = len(channel_mult) * [transformer_depth]
+ elif isinstance(transformer_depth, ListConfig):
+ transformer_depth = list(transformer_depth)
+ transformer_depth_middle = default(transformer_depth_middle, transformer_depth[-1])
+
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError(
+ "provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult"
+ )
+ self.num_res_blocks = num_res_blocks
+ # self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(
+ map(
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
+ range(len(num_attention_blocks)),
+ )
+ )
+ print(
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set."
+ ) # todo: convert to warning
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ if use_fp16:
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
+ # self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ assert use_fairscale_checkpoint != use_checkpoint or not (use_checkpoint or use_fairscale_checkpoint)
+
+ self.use_fairscale_checkpoint = False
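+        # fairscale activation checkpointing is hard-disabled here; because the flag is False,
+        # the partial(checkpoint_wrapper, ...) branch below is never evaluated (checkpoint_wrapper
+        # is not imported in this file)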
+ checkpoint_wrapper_fn = (
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
+ if self.use_fairscale_checkpoint
+ else lambda x: x
+ )
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = checkpoint_wrapper_fn(
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "timestep":
+ self.label_emb = checkpoint_wrapper_fn(
+ nn.Sequential(
+ Timestep(model_channels),
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ ),
+ )
+ )
+ elif self.num_classes == "sequential":
+ assert adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ linear(adm_in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ )
+ if resblock_updown
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ),
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer( # always uses a self-attn
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth_middle,
+ context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ ),
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+ layers.append(
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = checkpoint_wrapper_fn(
+ nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = checkpoint_wrapper_fn(
+ nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+ )
+
+ if lora_init:
+ self._init_lora(lora_rank, lora_scale, lora_weight_path)
+
+ def _init_lora(self, rank, scale, ckpt_dir=None):
+ inject_trainable_lora_extended(self, target_replace_module=None, rank=rank, scale=scale)
+
+ if ckpt_dir is not None:
+ with open(os.path.join(ckpt_dir, "latest")) as latest_file:
+ latest = latest_file.read().strip()
+ ckpt_path = os.path.join(ckpt_dir, latest, "mp_rank_00_model_states.pt")
+ print(f"loading lora from {ckpt_path}")
+ sd = th.load(ckpt_path)["module"]
+ sd = {
+ key[len("model.diffusion_model") :]: sd[key] for key in sd if key.startswith("model.diffusion_model")
+ }
+ self.load_state_dict(sd, strict=False)
+
+ def _update_scale(self, scale):
+ update_lora_scale(self, scale)
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ # h = x.type(self.dtype)
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ raise NotImplementedError("predict_codebook_ids is no longer supported")
+ else:
+ return self.out(h)
+
+
+class NoTimeUNetModel(UNetModel):
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ timesteps = th.zeros_like(timesteps)
+ return super().forward(x, timesteps, context, y, **kwargs)
+
+
+class EncoderUNetModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+ For usage, see UNet.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool="adaptive",
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.pool = pool
+ if pool == "adaptive":
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
+ nn.Flatten(),
+ )
+ elif pool == "attention":
+ assert num_head_channels != -1
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
+ )
+ elif pool == "spatial":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ elif pool == "spatial_v2":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ normalization(2048),
+ nn.SiLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ else:
+ raise NotImplementedError(f"Unexpected {pool} pooling")
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ results = []
+ # h = x.type(self.dtype)
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = th.cat(results, axis=-1)
+ return self.out(h)
+ else:
+ h = h.type(x.dtype)
+ return self.out(h)
+
+
+if __name__ == "__main__":
+
+ class Dummy(nn.Module):
+ def __init__(self, in_channels=3, model_channels=64):
+ super().__init__()
+ self.input_blocks = nn.ModuleList(
+ [TimestepEmbedSequential(conv_nd(2, in_channels, model_channels, 3, padding=1))]
+ )
+
+ model = UNetModel(
+ use_checkpoint=True,
+ image_size=64,
+ in_channels=4,
+ out_channels=4,
+ model_channels=128,
+ attention_resolutions=[4, 2],
+ num_res_blocks=2,
+ channel_mult=[1, 2, 4],
+ num_head_channels=64,
+ use_spatial_transformer=False,
+ use_linear_in_transformer=True,
+ transformer_depth=1,
+ legacy=False,
+ ).cuda()
+ x = th.randn(11, 4, 64, 64).cuda()
+ t = th.randint(low=0, high=10, size=(11,), device="cuda")
+ o = model(x, t)
+ print("done.")
diff --git a/sat/sgm/modules/diffusionmodules/sampling.py b/sat/sgm/modules/diffusionmodules/sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0f18302b2b36c607a15bdbf27b489ca5a447712
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/sampling.py
@@ -0,0 +1,763 @@
+"""
+Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
+"""
+
+from typing import Dict, Union
+
+import torch
+from omegaconf import ListConfig, OmegaConf
+from tqdm import tqdm
+
+from ...modules.diffusionmodules.sampling_utils import (
+ get_ancestral_step,
+ linear_multistep_coeff,
+ to_d,
+ to_neg_log_sigma,
+ to_sigma,
+)
+from ...util import append_dims, default, instantiate_from_config
+from ...util import SeededNoise
+
+from .guiders import DynamicCFG
+
+DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
+
+
+class BaseDiffusionSampler:
+ def __init__(
+ self,
+ discretization_config: Union[Dict, ListConfig, OmegaConf],
+ num_steps: Union[int, None] = None,
+ guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
+ verbose: bool = False,
+ device: str = "cuda",
+ ):
+ self.num_steps = num_steps
+ self.discretization = instantiate_from_config(discretization_config)
+ self.guider = instantiate_from_config(
+ default(
+ guider_config,
+ DEFAULT_GUIDER,
+ )
+ )
+ self.verbose = verbose
+ self.device = device
+
+ def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
+ sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device)
+ uc = default(uc, cond)
+
+ x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
+ num_sigmas = len(sigmas)
+
+ s_in = x.new_ones([x.shape[0]]).float()
+
+ return x, s_in, sigmas, num_sigmas, cond, uc
+
+ def denoise(self, x, denoiser, sigma, cond, uc):
+ denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
+ denoised = self.guider(denoised, sigma)
+ return denoised
+
+ def get_sigma_gen(self, num_sigmas):
+ sigma_generator = range(num_sigmas - 1)
+ if self.verbose:
+ print("#" * 30, " Sampling setting ", "#" * 30)
+ print(f"Sampler: {self.__class__.__name__}")
+ print(f"Discretization: {self.discretization.__class__.__name__}")
+ print(f"Guider: {self.guider.__class__.__name__}")
+ sigma_generator = tqdm(
+ sigma_generator,
+ total=num_sigmas,
+ desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
+ )
+ return sigma_generator
+
+
+class SingleStepDiffusionSampler(BaseDiffusionSampler):
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
+ raise NotImplementedError
+
+ def euler_step(self, x, d, dt):
+ return x + dt * d
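+ # Explicit Euler step x <- x + dt * d, where the EDM samplers pass
+ # d = (x - denoised) / sigma_hat and dt = next_sigma - sigma_hat
+ # (cf. Karras et al., 2022, Algorithm 2).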
+
+
+class EDMSampler(SingleStepDiffusionSampler):
+ def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.s_churn = s_churn
+ self.s_tmin = s_tmin
+ self.s_tmax = s_tmax
+ self.s_noise = s_noise
+
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
+ sigma_hat = sigma * (gamma + 1.0)
+ if gamma > 0:
+ eps = torch.randn_like(x) * self.s_noise
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
+
+ denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
+ d = to_d(x, sigma_hat, denoised)
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
+
+ euler_step = self.euler_step(x, d, dt)
+ x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
+ return x
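+ # "Churn": the noise level is inflated to sigma_hat = sigma * (1 + gamma) by adding
+ # fresh noise with variance sigma_hat**2 - sigma**2 before the Euler step, then an
+ # optional correction step (see the subclasses) refines the update.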
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+
+ for i in self.get_sigma_gen(num_sigmas):
+ gamma = (
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
+ )
+ x = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ gamma,
+ )
+
+ return x
+
+
+class DDIMSampler(SingleStepDiffusionSampler):
+ def __init__(self, s_noise=0.1, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.s_noise = s_noise
+
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+ d = to_d(x, sigma, denoised)
+ dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim)
+
+ euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)
+
+ x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
+ return x
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+
+ for i in self.get_sigma_gen(num_sigmas):
+ x = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ self.s_noise,
+ )
+
+ return x
+
+
+class AncestralSampler(SingleStepDiffusionSampler):
+ def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.eta = eta
+ self.s_noise = s_noise
+ self.noise_sampler = lambda x: torch.randn_like(x)
+
+ def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
+ d = to_d(x, sigma, denoised)
+ dt = append_dims(sigma_down - sigma, x.ndim)
+
+ return self.euler_step(x, d, dt)
+
+ def ancestral_step(self, x, sigma, next_sigma, sigma_up):
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0,
+ x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
+ x,
+ )
+ return x
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+
+ for i in self.get_sigma_gen(num_sigmas):
+ x = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ )
+
+ return x
+
+
+class LinearMultistepSampler(BaseDiffusionSampler):
+ def __init__(
+ self,
+ order=4,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.order = order
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+
+ ds = []
+ sigmas_cpu = sigmas.detach().cpu().numpy()
+ for i in self.get_sigma_gen(num_sigmas):
+ sigma = s_in * sigmas[i]
+ denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs)
+ denoised = self.guider(denoised, sigma)
+ d = to_d(x, sigma, denoised)
+ ds.append(d)
+ if len(ds) > self.order:
+ ds.pop(0)
+ cur_order = min(i + 1, self.order)
+ coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
+
+ return x
+
+
+class EulerEDMSampler(EDMSampler):
+ def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
+ return euler_step
+
+
+class HeunEDMSampler(EDMSampler):
+ def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
+ if torch.sum(next_sigma) < 1e-14:
+ # Save a network evaluation if all noise levels are 0
+ return euler_step
+ else:
+ denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
+ d_new = to_d(euler_step, next_sigma, denoised)
+ d_prime = (d + d_new) / 2.0
+
+ # apply correction if noise level is not 0
+ x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)
+ return x
+
+
+class EulerAncestralSampler(AncestralSampler):
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+ x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
+
+ return x
+
+
+class DPMPP2SAncestralSampler(AncestralSampler):
+ def get_variables(self, sigma, sigma_down):
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
+ h = t_next - t
+ s = t + 0.5 * h
+ return h, s, t, t_next
+
+ def get_mult(self, h, s, t, t_next):
+ mult1 = to_sigma(s) / to_sigma(t)
+ mult2 = (-0.5 * h).expm1()
+ mult3 = to_sigma(t_next) / to_sigma(t)
+ mult4 = (-h).expm1()
+
+ return mult1, mult2, mult3, mult4
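+ # With t = -log(sigma), h = t_next - t and midpoint s = t + h/2, these are the
+ # exponential-integrator coefficients of the single-step DPM-Solver++(2S) update
+ # in this sigma parameterisation: x_s = (sigma_s / sigma_t) * x - expm1(-h/2) * denoised,
+ # and analogously for the full step to t_next with mult3/mult4.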
+
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+ x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
+
+ if torch.sum(sigma_down) < 1e-14:
+ # Save a network evaluation if all noise levels are 0
+ x = x_euler
+ else:
+ h, s, t, t_next = self.get_variables(sigma, sigma_down)
+ mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)]
+
+ x2 = mult[0] * x - mult[1] * denoised
+ denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
+ x_dpmpp2s = mult[2] * x - mult[3] * denoised2
+
+ # apply correction if noise level is not 0
+ x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
+
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
+ return x
+
+
+class DPMPP2MSampler(BaseDiffusionSampler):
+ def get_variables(self, sigma, next_sigma, previous_sigma=None):
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
+ h = t_next - t
+
+ if previous_sigma is not None:
+ h_last = t - to_neg_log_sigma(previous_sigma)
+ r = h_last / h
+ return h, r, t, t_next
+ else:
+ return h, None, t, t_next
+
+ def get_mult(self, h, r, t, t_next, previous_sigma):
+ mult1 = to_sigma(t_next) / to_sigma(t)
+ mult2 = (-h).expm1()
+
+ if previous_sigma is not None:
+ mult3 = 1 + 1 / (2 * r)
+ mult4 = 1 / (2 * r)
+ return mult1, mult2, mult3, mult4
+ else:
+ return mult1, mult2
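+ # Multistep DPM-Solver++(2M) coefficients: mult1/mult2 give the first-order update
+ # x <- (sigma_next / sigma) * x - expm1(-h) * D, while mult3/mult4 extrapolate the
+ # current and previous denoiser outputs with step-size ratio r = h_last / h.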
+
+ def sampler_step(
+ self,
+ old_denoised,
+ previous_sigma,
+ sigma,
+ next_sigma,
+ denoiser,
+ x,
+ cond,
+ uc=None,
+ ):
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+
+ h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
+ mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
+
+ x_standard = mult[0] * x - mult[1] * denoised
+ if old_denoised is None or torch.sum(next_sigma) < 1e-14:
+ # Save a network evaluation if all noise levels are 0 or on the first step
+ return x_standard, denoised
+ else:
+ denoised_d = mult[2] * denoised - mult[3] * old_denoised
+ x_advanced = mult[0] * x - mult[1] * denoised_d
+
+ # apply correction if noise level is not 0 and not first step
+ x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
+
+ return x, denoised
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+
+ old_denoised = None
+ for i in self.get_sigma_gen(num_sigmas):
+ x, old_denoised = self.sampler_step(
+ old_denoised,
+ None if i == 0 else s_in * sigmas[i - 1],
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc=uc,
+ )
+
+ return x
+
+
+class SDEDPMPP2MSampler(BaseDiffusionSampler):
+ def get_variables(self, sigma, next_sigma, previous_sigma=None):
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
+ h = t_next - t
+
+ if previous_sigma is not None:
+ h_last = t - to_neg_log_sigma(previous_sigma)
+ r = h_last / h
+ return h, r, t, t_next
+ else:
+ return h, None, t, t_next
+
+ def get_mult(self, h, r, t, t_next, previous_sigma):
+ mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp()
+ mult2 = (-2 * h).expm1()
+
+ if previous_sigma is not None:
+ mult3 = 1 + 1 / (2 * r)
+ mult4 = 1 / (2 * r)
+ return mult1, mult2, mult3, mult4
+ else:
+ return mult1, mult2
+
+ def sampler_step(
+ self,
+ old_denoised,
+ previous_sigma,
+ sigma,
+ next_sigma,
+ denoiser,
+ x,
+ cond,
+ uc=None,
+ ):
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+
+ h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
+ mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
+ mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
+
+ x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
+ if old_denoised is None or torch.sum(next_sigma) < 1e-14:
+ # Save a network evaluation if all noise levels are 0 or on the first step
+ return x_standard, denoised
+ else:
+ denoised_d = mult[2] * denoised - mult[3] * old_denoised
+ x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)
+
+ # apply correction if noise level is not 0 and not first step
+ x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
+
+ return x, denoised
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+
+ old_denoised = None
+ for i in self.get_sigma_gen(num_sigmas):
+ x, old_denoised = self.sampler_step(
+ old_denoised,
+ None if i == 0 else s_in * sigmas[i - 1],
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc=uc,
+ )
+
+ return x
+
+
+class SdeditEDMSampler(EulerEDMSampler):
+ def __init__(self, edit_ratio=0.5, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.edit_ratio = edit_ratio
+
+ def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None):
+ randn_unit = randn.clone()
+ randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps)
+
+ if num_steps is None:
+ num_steps = self.num_steps
+ if edit_ratio is None:
+ edit_ratio = self.edit_ratio
+ x = None
+
+ for i in self.get_sigma_gen(num_sigmas):
+ if i / num_steps < edit_ratio:
+ continue
+ if x is None:
+ x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape))
+
+ gamma = (
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
+ )
+ x = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ gamma,
+ )
+
+ return x
+
+
+class VideoDDIMSampler(BaseDiffusionSampler):
+ def __init__(self, fixed_frames=0, sdedit=False, **kwargs):
+ super().__init__(**kwargs)
+ self.fixed_frames = fixed_frames
+ self.sdedit = sdedit
+
+ def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
+ alpha_cumprod_sqrt, timesteps = self.discretization(
+ self.num_steps if num_steps is None else num_steps,
+ device=self.device,
+ return_idx=True,
+ do_append_zero=False,
+ )
+ alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])])
+ timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))])
+
+ uc = default(uc, cond)
+
+ num_sigmas = len(alpha_cumprod_sqrt)
+
+ s_in = x.new_ones([x.shape[0]])
+
+ return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps
+
+ def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None):
+ additional_model_inputs = {}
+
+ if not isinstance(scale, torch.Tensor) and scale == 1:
+ additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep
+ if scale_emb is not None:
+ additional_model_inputs["scale_emb"] = scale_emb
+ denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32)
+ else:
+ additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
+ denoised = denoiser(
+ *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs
+ ).to(torch.float32)
+ if isinstance(self.guider, DynamicCFG):
+ denoised = self.guider(
+ denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale
+ )
+ else:
+ denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale)
+ return denoised
+
+ def sampler_step(
+ self,
+ alpha_cumprod_sqrt,
+ next_alpha_cumprod_sqrt,
+ denoiser,
+ x,
+ cond,
+ uc=None,
+ idx=None,
+ timestep=None,
+ scale=None,
+ scale_emb=None,
+ ):
+ denoised = self.denoise(
+ x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
+ ).to(torch.float32)
+
+ a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
+ b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
+
+ x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
+ return x
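+ # Deterministic DDIM update expressed with sqrt(alpha_cumprod):
+ # a_t = sqrt((1 - alpha_next) / (1 - alpha_t)) rescales the noise component and
+ # b_t = sqrt(alpha_next) - sqrt(alpha_t) * a_t weights the predicted x0, giving
+ # x_next = a_t * x + b_t * denoised (eta = 0, i.e. no extra noise is injected).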
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
+ x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ for i in self.get_sigma_gen(num_sigmas):
+ x = self.sampler_step(
+ s_in * alpha_cumprod_sqrt[i],
+ s_in * alpha_cumprod_sqrt[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ idx=self.num_steps - i,
+ timestep=timesteps[-(i + 1)],
+ scale=scale,
+ scale_emb=scale_emb,
+ )
+
+ return x
+
+
+class VPSDEDPMPP2MSampler(VideoDDIMSampler):
+ def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
+ alpha_cumprod = alpha_cumprod_sqrt**2
+ lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
+ next_alpha_cumprod = next_alpha_cumprod_sqrt**2
+ lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
+ h = lamb_next - lamb
+
+ if previous_alpha_cumprod_sqrt is not None:
+ previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
+ lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
+ h_last = lamb - lamb_previous
+ r = h_last / h
+ return h, r, lamb, lamb_next
+ else:
+ return h, None, lamb, lamb_next
+
+ def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
+ mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp()
+ mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt
+
+ if previous_alpha_cumprod_sqrt is not None:
+ mult3 = 1 + 1 / (2 * r)
+ mult4 = 1 / (2 * r)
+ return mult1, mult2, mult3, mult4
+ else:
+ return mult1, mult2
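+ # Here lamb = log(sqrt(alpha_cumprod) / sigma) is the half-log-SNR; mult1 carries
+ # the SDE contraction factor exp(-h), and the noise injected in sampler_step below
+ # has variance (1 - next_alpha_cumprod) * (1 - exp(-2h)), i.e. the SDE variant of
+ # DPM-Solver++(2M).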
+
+ def sampler_step(
+ self,
+ old_denoised,
+ previous_alpha_cumprod_sqrt,
+ alpha_cumprod_sqrt,
+ next_alpha_cumprod_sqrt,
+ denoiser,
+ x,
+ cond,
+ uc=None,
+ idx=None,
+ timestep=None,
+ scale=None,
+ scale_emb=None,
+ ):
+ denoised = self.denoise(
+ x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
+ ).to(torch.float32)
+ if idx == 1:
+ return denoised, denoised
+
+ h, r, lamb, lamb_next = self.get_variables(
+ alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
+ )
+ mult = [
+ append_dims(mult, x.ndim)
+ for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
+ ]
+ mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
+
+ x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
+ if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
+ # Save a network evaluation if all noise levels are 0 or on the first step
+ return x_standard, denoised
+ else:
+ denoised_d = mult[2] * denoised - mult[3] * old_denoised
+ x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)
+
+ x = x_advanced
+
+ return x, denoised
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
+ x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ if self.fixed_frames > 0:
+ prefix_frames = x[:, : self.fixed_frames]
+ old_denoised = None
+ for i in self.get_sigma_gen(num_sigmas):
+ if self.fixed_frames > 0:
+ if self.sdedit:
+ rd = torch.randn_like(prefix_frames)
+ noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(
+ s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape)
+ )
+ x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1)
+ else:
+ x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
+ x, old_denoised = self.sampler_step(
+ old_denoised,
+ None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
+ s_in * alpha_cumprod_sqrt[i],
+ s_in * alpha_cumprod_sqrt[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc=uc,
+ idx=self.num_steps - i,
+ timestep=timesteps[-(i + 1)],
+ scale=scale,
+ scale_emb=scale_emb,
+ )
+
+ if self.fixed_frames > 0:
+ x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
+
+ return x
+
+
+class VPODEDPMPP2MSampler(VideoDDIMSampler):
+ def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
+ alpha_cumprod = alpha_cumprod_sqrt**2
+ lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
+ next_alpha_cumprod = next_alpha_cumprod_sqrt**2
+ lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
+ h = lamb_next - lamb
+
+ if previous_alpha_cumprod_sqrt is not None:
+ previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
+ lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
+ h_last = lamb - lamb_previous
+ r = h_last / h
+ return h, r, lamb, lamb_next
+ else:
+ return h, None, lamb, lamb_next
+
+ def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
+ mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
+ mult2 = (-h).expm1() * next_alpha_cumprod_sqrt
+
+ if previous_alpha_cumprod_sqrt is not None:
+ mult3 = 1 + 1 / (2 * r)
+ mult4 = 1 / (2 * r)
+ return mult1, mult2, mult3, mult4
+ else:
+ return mult1, mult2
+
+ def sampler_step(
+ self,
+ old_denoised,
+ previous_alpha_cumprod_sqrt,
+ alpha_cumprod_sqrt,
+ next_alpha_cumprod_sqrt,
+ denoiser,
+ x,
+ cond,
+ uc=None,
+ idx=None,
+ timestep=None,
+ ):
+ denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32)
+ if idx == 1:
+ return denoised, denoised
+
+ h, r, lamb, lamb_next = self.get_variables(
+ alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
+ )
+ mult = [
+ append_dims(mult, x.ndim)
+ for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
+ ]
+
+ x_standard = mult[0] * x - mult[1] * denoised
+ if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
+ # Save a network evaluation if all noise levels are 0 or on the first step
+ return x_standard, denoised
+ else:
+ denoised_d = mult[2] * denoised - mult[3] * old_denoised
+ x_advanced = mult[0] * x - mult[1] * denoised_d
+
+ x = x_advanced
+
+ return x, denoised
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
+ x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ old_denoised = None
+ for i in self.get_sigma_gen(num_sigmas):
+ x, old_denoised = self.sampler_step(
+ old_denoised,
+ None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
+ s_in * alpha_cumprod_sqrt[i],
+ s_in * alpha_cumprod_sqrt[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc=uc,
+ idx=self.num_steps - i,
+ timestep=timesteps[-(i + 1)],
+ )
+
+ return x
diff --git a/sat/sgm/modules/diffusionmodules/sampling_utils.py b/sat/sgm/modules/diffusionmodules/sampling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bb1fa829659394f673e19e6144d9ab3a1faf5a1
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/sampling_utils.py
@@ -0,0 +1,155 @@
+import torch
+from scipy import integrate
+
+from ...util import append_dims
+from einops import rearrange
+
+
+class NoDynamicThresholding:
+ def __call__(self, uncond, cond, scale):
+ scale = append_dims(scale, cond.ndim) if isinstance(scale, torch.Tensor) else scale
+ return uncond + scale * (cond - uncond)
+
+
+class StaticThresholding:
+ def __call__(self, uncond, cond, scale):
+ result = uncond + scale * (cond - uncond)
+ result = torch.clamp(result, min=-1.0, max=1.0)
+ return result
+
+
+def dynamic_threshold(x, p=0.95):
+ N, T, C, H, W = x.shape
+ x = rearrange(x, "n t c h w -> n c (t h w)")
+ l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device), dim=-1, keepdim=True)
+ s = torch.maximum(-l, r)
+ threshold_mask = (s > 1).expand(-1, -1, H * W * T)
+ if threshold_mask.any():
+ x = torch.where(threshold_mask, x.clamp(min=-1 * s, max=s), x)
+ x = rearrange(x, "n c (t h w) -> n t c h w", t=T, h=H, w=W)
+ return x
+
+
+def dynamic_thresholding2(x0):
+ p = 0.995 # hyperparameter from the "Imagen" paper (Saharia et al., 2022)
+ origin_dtype = x0.dtype
+ x0 = x0.to(torch.float32)
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim())
+ x0 = torch.clamp(x0, -s, s) # / s
+ return x0.to(origin_dtype)
+
+
+def latent_dynamic_thresholding(x0):
+ p = 0.9995
+ origin_dtype = x0.dtype
+ x0 = x0.to(torch.float32)
+ s = torch.quantile(torch.abs(x0), p, dim=2)
+ s = append_dims(s, x0.dim())
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0.to(origin_dtype)
+
+
+def dynamic_thresholding3(x0):
+ p = 0.995 # hyperparameter from the "Imagen" paper (Saharia et al., 2022)
+ origin_dtype = x0.dtype
+ x0 = x0.to(torch.float32)
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim())
+ x0 = torch.clamp(x0, -s, s) # / s
+ return x0.to(origin_dtype)
+
+
+class DynamicThresholding:
+ def __call__(self, uncond, cond, scale):
+ mean = uncond.mean()
+ std = uncond.std()
+ result = uncond + scale * (cond - uncond)
+ result_mean, result_std = result.mean(), result.std()
+ result = (result - result_mean) / result_std * std
+ # result = dynamic_thresholding3(result)
+ return result
+
+
+class DynamicThresholdingV1:
+ def __init__(self, scale_factor):
+ self.scale_factor = scale_factor
+
+ def __call__(self, uncond, cond, scale):
+ result = uncond + scale * (cond - uncond)
+ unscaled_result = result / self.scale_factor
+ B, T, C, H, W = unscaled_result.shape
+ flattened = rearrange(unscaled_result, "b t c h w -> b c (t h w)")
+ means = flattened.mean(dim=2).unsqueeze(2)
+ recentered = flattened - means
+ magnitudes = recentered.abs().max()
+ normalized = recentered / magnitudes
+ thresholded = latent_dynamic_thresholding(normalized)
+ denormalized = thresholded * magnitudes
+ uncentered = denormalized + means
+ unflattened = rearrange(uncentered, "b c (t h w) -> b t c h w", t=T, h=H, w=W)
+ scaled_result = unflattened * self.scale_factor
+ return scaled_result
+
+
+class DynamicThresholdingV2:
+ def __call__(self, uncond, cond, scale):
+ B, T, C, H, W = uncond.shape
+ diff = cond - uncond
+ mim_target = uncond + diff * 4.0
+ cfg_target = uncond + diff * 8.0
+
+ mim_flattened = rearrange(mim_target, "b t c h w -> b c (t h w)")
+ cfg_flattened = rearrange(cfg_target, "b t c h w -> b c (t h w)")
+ mim_means = mim_flattened.mean(dim=2).unsqueeze(2)
+ cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2)
+ mim_centered = mim_flattened - mim_means
+ cfg_centered = cfg_flattened - cfg_means
+
+ mim_scaleref = mim_centered.std(dim=2).unsqueeze(2)
+ cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2)
+
+ cfg_renormalized = cfg_centered / cfg_scaleref * mim_scaleref
+
+ result = cfg_renormalized + cfg_means
+ unflattened = rearrange(result, "b c (t h w) -> b t c h w", t=T, h=H, w=W)
+
+ return unflattened
+
+
+def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
+ if order - 1 > i:
+ raise ValueError(f"Order {order} too high for step {i}")
+
+ def fn(tau):
+ prod = 1.0
+ for k in range(order):
+ if j == k:
+ continue
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
+ return prod
+
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
+
+
+def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
+ if not eta:
+ return sigma_to, 0.0
+ sigma_up = torch.minimum(
+ sigma_to,
+ eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
+ )
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+ return sigma_down, sigma_up
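+ # The ancestral split satisfies sigma_down**2 + sigma_up**2 = sigma_to**2, so a
+ # deterministic step to sigma_down plus fresh noise of std sigma_up lands the sample
+ # at the marginal noise level sigma_to; eta interpolates between the pure ODE step
+ # (eta=0) and the fully ancestral step (eta=1).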
+
+
+def to_d(x, sigma, denoised):
+ return (x - denoised) / append_dims(sigma, x.ndim)
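+ # Converts a denoiser output into the Karras ODE derivative
+ # dx/dsigma = (x - denoised) / sigma used by the samplers in sampling.py.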
+
+
+def to_neg_log_sigma(sigma):
+ return sigma.log().neg()
+
+
+def to_sigma(neg_log_sigma):
+ return neg_log_sigma.neg().exp()
diff --git a/sat/sgm/modules/diffusionmodules/sigma_sampling.py b/sat/sgm/modules/diffusionmodules/sigma_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..770de4254e54d594e7a46663ea58d4f2f660187e
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/sigma_sampling.py
@@ -0,0 +1,80 @@
+import torch
+import torch.distributed
+
+from sat import mpu
+
+from ...util import default, instantiate_from_config
+
+
+class EDMSampling:
+ def __init__(self, p_mean=-1.2, p_std=1.2):
+ self.p_mean = p_mean
+ self.p_std = p_std
+
+ def __call__(self, n_samples, rand=None):
+ log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
+ return log_sigma.exp()
+
+
+class DiscreteSampling:
+ def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False):
+ self.num_idx = num_idx
+ self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
+ world_size = mpu.get_data_parallel_world_size()
+ self.uniform_sampling = uniform_sampling
+ if self.uniform_sampling:
+ i = 1
+ while True:
+ if world_size % i != 0 or num_idx % (world_size // i) != 0:
+ i += 1
+ else:
+ self.group_num = world_size // i
+ break
+
+ assert self.group_num > 0
+ assert world_size % self.group_num == 0
+ self.group_width = world_size // self.group_num # the number of rank in one group
+ self.sigma_interval = self.num_idx // self.group_num
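+ # With uniform_sampling, the data-parallel ranks are split into `group_num` groups
+ # and each group draws timestep indices only from its own contiguous slice of
+ # [0, num_idx), so a global batch covers the noise levels more evenly.
+ # Illustrative example: world_size=8, num_idx=1000 -> group_num=8, sigma_interval=125.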
+
+ def idx_to_sigma(self, idx):
+ return self.sigmas[idx]
+
+ def __call__(self, n_samples, rand=None, return_idx=False):
+ if self.uniform_sampling:
+ rank = mpu.get_data_parallel_rank()
+ group_index = rank // self.group_width
+ idx = default(
+ rand,
+ torch.randint(
+ group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)
+ ),
+ )
+ else:
+ idx = default(
+ rand,
+ torch.randint(0, self.num_idx, (n_samples,)),
+ )
+ if return_idx:
+ return self.idx_to_sigma(idx), idx
+ else:
+ return self.idx_to_sigma(idx)
+
+
+class PartialDiscreteSampling:
+ def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
+ self.total_num_idx = total_num_idx
+ self.partial_num_idx = partial_num_idx
+ self.sigmas = instantiate_from_config(discretization_config)(
+ total_num_idx, do_append_zero=do_append_zero, flip=flip
+ )
+
+ def idx_to_sigma(self, idx):
+ return self.sigmas[idx]
+
+ def __call__(self, n_samples, rand=None):
+ idx = default(
+ rand,
+ # torch.randint(self.total_num_idx-self.partial_num_idx, self.total_num_idx, (n_samples,)),
+ torch.randint(0, self.partial_num_idx, (n_samples,)),
+ )
+ return self.idx_to_sigma(idx)
diff --git a/sat/sgm/modules/diffusionmodules/util.py b/sat/sgm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..abf72a758fbf8a6f08145100223fea074fa64015
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/util.py
@@ -0,0 +1,328 @@
+"""
+adopted from
+https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+and
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+and
+https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+
+thanks!
+"""
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+
+def make_beta_schedule(
+ schedule,
+ n_timestep,
+ linear_start=1e-4,
+ linear_end=2e-2,
+):
+ if schedule == "linear":
+ betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2
+ return betas.numpy()
+ raise NotImplementedError(f"unknown beta schedule: {schedule}")
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def mixed_checkpoint(func, inputs: dict, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
+ borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
+ it also works with non-tensor inputs
+ :param func: the function to evaluate.
+ :param inputs: the argument dictionary to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
+ tensor_inputs = [inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)]
+ non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)]
+ non_tensor_inputs = [inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)]
+ args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
+ return MixedCheckpointFunction.apply(
+ func,
+ len(tensor_inputs),
+ len(non_tensor_inputs),
+ tensor_keys,
+ non_tensor_keys,
+ *args,
+ )
+ else:
+ return func(**inputs)
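+ # Illustrative usage with a hypothetical block whose forward takes both tensors and
+ # plain Python values (names are assumptions, not part of this module):
+ #   out = mixed_checkpoint(block.forward,
+ #                          {"x": x, "emb": emb, "num_frames": 16},
+ #                          block.parameters(), flag=True)
+ # The tensor entries are checkpointed; `num_frames` is passed through untouched.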
+
+
+class MixedCheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ run_function,
+ length_tensors,
+ length_non_tensors,
+ tensor_keys,
+ non_tensor_keys,
+ *args,
+ ):
+ ctx.end_tensors = length_tensors
+ ctx.end_non_tensors = length_tensors + length_non_tensors
+ ctx.gpu_autocast_kwargs = {
+ "enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors
+
+ ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))}
+ ctx.input_non_tensors = {
+ key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]))
+ }
+ ctx.run_function = run_function
+ ctx.input_params = list(args[ctx.end_non_tensors :])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(**ctx.input_tensors, **ctx.input_non_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
+ ctx.input_tensors = {key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors}
+
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = {key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors}
+ # shallow_copies.update(additional_args)
+ output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ list(ctx.input_tensors.values()) + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (
+ (None, None, None, None, None)
+ + input_grads[: ctx.end_tensors]
+ + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
+ + input_grads[ctx.end_tensors :]
+ )
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ ctx.gpu_autocast_kwargs = {
+ "enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtype=torch.float32):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+ device=timesteps.device
+ )
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, "b -> b d", d=dim)
+ return embedding.to(dtype)
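+ # Shape sketch: timesteps of shape [N] and dim=320 yield an [N, 320] embedding whose
+ # first 160 columns are cosines and last 160 are sines over frequencies
+ # max_period**(-k/160), k = 0..159 (assuming repeat_only=False and an even dim).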
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class AlphaBlender(nn.Module):
+ strategies = ["learned", "fixed", "learned_with_images"]
+
+ def __init__(
+ self,
+ alpha: float,
+ merge_strategy: str = "learned_with_images",
+ rearrange_pattern: str = "b t -> (b t) 1 1",
+ ):
+ super().__init__()
+ self.merge_strategy = merge_strategy
+ self.rearrange_pattern = rearrange_pattern
+
+ assert merge_strategy in self.strategies, f"merge_strategy needs to be in {self.strategies}"
+
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
+ self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
+ if self.merge_strategy == "fixed":
+ alpha = self.mix_factor
+ elif self.merge_strategy == "learned":
+ alpha = torch.sigmoid(self.mix_factor)
+ elif self.merge_strategy == "learned_with_images":
+ assert image_only_indicator is not None, "need image_only_indicator ..."
+ alpha = torch.where(
+ image_only_indicator.bool(),
+ torch.ones(1, 1, device=image_only_indicator.device),
+ rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
+ )
+ alpha = rearrange(alpha, self.rearrange_pattern)
+ else:
+ raise NotImplementedError
+ return alpha
+
+ def forward(
+ self,
+ x_spatial: torch.Tensor,
+ x_temporal: torch.Tensor,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ alpha = self.get_alpha(image_only_indicator)
+ x = alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
+ return x
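+ # Usage sketch (values are illustrative): with merge_strategy="fixed" and alpha=0.7,
+ # forward(x_spatial, x_temporal) returns 0.7 * x_spatial + 0.3 * x_temporal; with
+ # "learned_with_images", frames flagged by image_only_indicator fall back to the
+ # purely spatial branch (alpha = 1).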
diff --git a/sat/sgm/modules/diffusionmodules/wrappers.py b/sat/sgm/modules/diffusionmodules/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0b78ffd502fc67238752fbd5de7ee7c6661ce05
--- /dev/null
+++ b/sat/sgm/modules/diffusionmodules/wrappers.py
@@ -0,0 +1,41 @@
+import torch
+import torch.nn as nn
+from packaging import version
+
+OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
+
+
+class IdentityWrapper(nn.Module):
+ def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ compile = (
+ torch.compile
+ if (version.parse(torch.__version__) >= version.parse("2.0.0")) and compile_model
+ else lambda x: x
+ )
+ self.diffusion_model = compile(diffusion_model)
+ self.dtype = dtype
+
+ def forward(self, *args, **kwargs):
+ return self.diffusion_model(*args, **kwargs)
+
+
+class OpenAIWrapper(IdentityWrapper):
+ def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor:
+ for key in c:
+ c[key] = c[key].to(self.dtype)
+
+ if x.dim() == 4:
+ x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
+ elif x.dim() == 5:
+ x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=2)
+ else:
+ raise ValueError("Input tensor must be 4D or 5D")
+
+ return self.diffusion_model(
+ x,
+ timesteps=t,
+ context=c.get("crossattn", None),
+ y=c.get("vector", None),
+ **kwargs,
+ )
diff --git a/sat/sgm/modules/distributions/__init__.py b/sat/sgm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sat/sgm/modules/distributions/distributions.py b/sat/sgm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..0338a861f90e4fbd1f8f4ba8712dde1316b4fa58
--- /dev/null
+++ b/sat/sgm/modules/distributions/distributions.py
@@ -0,0 +1,94 @@
+import numpy as np
+import torch
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ # x = self.mean + self.std * torch.randn(self.mean.shape).to(
+ # device=self.parameters.device
+ # )
+ x = self.mean + self.std * torch.randn_like(self.mean).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
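+        # Closed-form KL between diagonal Gaussians; with other=None the reference is the standard normal N(0, I).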
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
+
+ return 0.5 * (
+ -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/sat/sgm/modules/ema.py b/sat/sgm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f1f7606c2c9b68ebd2302215a9e08f9f31ed8ab
--- /dev/null
+++ b/sat/sgm/modules/ema.py
@@ -0,0 +1,82 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+
+ self.m_name2s_name = {}
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer(
+ "num_updates",
+ torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
+ )
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+                # strip "." from the name, since "." is not allowed in buffer names
+ s_name = name.replace(".", "")
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def reset_num_updates(self):
+ del self.num_updates
+ self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
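+            # Warm-up: the effective decay grows with the update count and is capped at the configured value.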
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+                    assert key not in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+                assert key not in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/sat/sgm/modules/encoders/__init__.py b/sat/sgm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sat/sgm/modules/encoders/modules.py b/sat/sgm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8a16fcc90040a6ea61d68ef9db3f5bc75beb8da
--- /dev/null
+++ b/sat/sgm/modules/encoders/modules.py
@@ -0,0 +1,281 @@
+import math
+from contextlib import nullcontext
+from functools import partial
+from typing import Dict, List, Optional, Tuple, Union
+
+import kornia
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from omegaconf import ListConfig
+from torch.utils.checkpoint import checkpoint
+from transformers import (
+ T5EncoderModel,
+ T5Tokenizer,
+)
+
+from ...util import (
+ append_dims,
+ autocast,
+ count_params,
+ default,
+ disabled_train,
+ expand_dims_like,
+ instantiate_from_config,
+)
+
+
+class AbstractEmbModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self._is_trainable = None
+ self._ucg_rate = None
+ self._input_key = None
+
+ @property
+ def is_trainable(self) -> bool:
+ return self._is_trainable
+
+ @property
+ def ucg_rate(self) -> Union[float, torch.Tensor]:
+ return self._ucg_rate
+
+ @property
+ def input_key(self) -> str:
+ return self._input_key
+
+ @is_trainable.setter
+ def is_trainable(self, value: bool):
+ self._is_trainable = value
+
+ @ucg_rate.setter
+ def ucg_rate(self, value: Union[float, torch.Tensor]):
+ self._ucg_rate = value
+
+ @input_key.setter
+ def input_key(self, value: str):
+ self._input_key = value
+
+ @is_trainable.deleter
+ def is_trainable(self):
+ del self._is_trainable
+
+ @ucg_rate.deleter
+ def ucg_rate(self):
+ del self._ucg_rate
+
+ @input_key.deleter
+ def input_key(self):
+ del self._input_key
+
+
+class GeneralConditioner(nn.Module):
+ OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
+ KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
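+    # Embedder outputs are routed by tensor rank (2D -> "vector", 3D -> "crossattn", 4D/5D -> "concat"),
+    # and KEY2CATDIM gives the axis along which outputs sharing the same key are concatenated.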
+
+ def __init__(self, emb_models: Union[List, ListConfig], cor_embs=[], cor_p=[]):
+ super().__init__()
+ embedders = []
+ for n, embconfig in enumerate(emb_models):
+ embedder = instantiate_from_config(embconfig)
+ assert isinstance(
+ embedder, AbstractEmbModel
+ ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
+ embedder.is_trainable = embconfig.get("is_trainable", False)
+ embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
+ if not embedder.is_trainable:
+ embedder.train = disabled_train
+ for param in embedder.parameters():
+ param.requires_grad = False
+ embedder.eval()
+ print(
+ f"Initialized embedder #{n}: {embedder.__class__.__name__} "
+ f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
+ )
+
+ if "input_key" in embconfig:
+ embedder.input_key = embconfig["input_key"]
+ elif "input_keys" in embconfig:
+ embedder.input_keys = embconfig["input_keys"]
+ else:
+ raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")
+
+ embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
+ if embedder.legacy_ucg_val is not None:
+ embedder.ucg_prng = np.random.RandomState()
+
+ embedders.append(embedder)
+ self.embedders = nn.ModuleList(embedders)
+
+ if len(cor_embs) > 0:
+ assert len(cor_p) == 2 ** len(cor_embs)
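+            # cor_p assigns a probability to each of the 2**len(cor_embs) joint keep/drop patterns of the
+            # correlated embedders; forward() decodes the sampled pattern index bit by bit (rand_idx % 2).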
+ self.cor_embs = cor_embs
+ self.cor_p = cor_p
+
+ def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
+ assert embedder.legacy_ucg_val is not None
+ p = embedder.ucg_rate
+ val = embedder.legacy_ucg_val
+ for i in range(len(batch[embedder.input_key])):
+ if embedder.ucg_prng.choice(2, p=[1 - p, p]):
+ batch[embedder.input_key][i] = val
+ return batch
+
+ def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict, cond_or_not) -> Dict:
+ assert embedder.legacy_ucg_val is not None
+ val = embedder.legacy_ucg_val
+ for i in range(len(batch[embedder.input_key])):
+ if cond_or_not[i]:
+ batch[embedder.input_key][i] = val
+ return batch
+
+ def get_single_embedding(
+ self,
+ embedder,
+ batch,
+ output,
+ cond_or_not: Optional[np.ndarray] = None,
+ force_zero_embeddings: Optional[List] = None,
+ ):
+ embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
+ with embedding_context():
+ if hasattr(embedder, "input_key") and (embedder.input_key is not None):
+ if embedder.legacy_ucg_val is not None:
+ if cond_or_not is None:
+ batch = self.possibly_get_ucg_val(embedder, batch)
+ else:
+ batch = self.surely_get_ucg_val(embedder, batch, cond_or_not)
+ emb_out = embedder(batch[embedder.input_key])
+ elif hasattr(embedder, "input_keys"):
+ emb_out = embedder(*[batch[k] for k in embedder.input_keys])
+ assert isinstance(
+ emb_out, (torch.Tensor, list, tuple)
+ ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
+ if not isinstance(emb_out, (list, tuple)):
+ emb_out = [emb_out]
+ for emb in emb_out:
+ out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
+ if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
+ if cond_or_not is None:
+ emb = (
+ expand_dims_like(
+ torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)),
+ emb,
+ )
+ * emb
+ )
+ else:
+ emb = (
+ expand_dims_like(
+ torch.tensor(1 - cond_or_not, dtype=emb.dtype, device=emb.device),
+ emb,
+ )
+ * emb
+ )
+ if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings:
+ emb = torch.zeros_like(emb)
+ if out_key in output:
+ output[out_key] = torch.cat((output[out_key], emb), self.KEY2CATDIM[out_key])
+ else:
+ output[out_key] = emb
+ return output
+
+ def forward(self, batch: Dict, force_zero_embeddings: Optional[List] = None) -> Dict:
+ output = dict()
+ if force_zero_embeddings is None:
+ force_zero_embeddings = []
+
+ if len(self.cor_embs) > 0:
+ batch_size = len(batch[list(batch.keys())[0]])
+ rand_idx = np.random.choice(len(self.cor_p), size=(batch_size,), p=self.cor_p)
+ for emb_idx in self.cor_embs:
+ cond_or_not = rand_idx % 2
+ rand_idx //= 2
+ output = self.get_single_embedding(
+ self.embedders[emb_idx],
+ batch,
+ output=output,
+ cond_or_not=cond_or_not,
+ force_zero_embeddings=force_zero_embeddings,
+ )
+
+ for i, embedder in enumerate(self.embedders):
+ if i in self.cor_embs:
+ continue
+ output = self.get_single_embedding(
+ embedder, batch, output=output, force_zero_embeddings=force_zero_embeddings
+ )
+ return output
+
+ def get_unconditional_conditioning(self, batch_c, batch_uc=None, force_uc_zero_embeddings=None):
+ if force_uc_zero_embeddings is None:
+ force_uc_zero_embeddings = []
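+        # Temporarily zero the dropout rates and disable correlated dropping so that the conditional (c)
+        # and unconditional (uc) embeddings are computed deterministically; everything is restored below.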
+ ucg_rates = list()
+ for embedder in self.embedders:
+ ucg_rates.append(embedder.ucg_rate)
+ embedder.ucg_rate = 0.0
+ cor_embs = self.cor_embs
+ cor_p = self.cor_p
+ self.cor_embs = []
+ self.cor_p = []
+
+ c = self(batch_c)
+ uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
+
+ for embedder, rate in zip(self.embedders, ucg_rates):
+ embedder.ucg_rate = rate
+ self.cor_embs = cor_embs
+ self.cor_p = cor_p
+
+ return c, uc
+
+
+class FrozenT5Embedder(AbstractEmbModel):
+ """Uses the T5 transformer encoder for text"""
+
+ def __init__(
+ self,
+ model_dir="google/t5-v1_1-xxl",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ cache_dir=None,
+ ):
+ super().__init__()
+        if model_dir != "google/t5-v1_1-xxl":
+ self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
+ self.transformer = T5EncoderModel.from_pretrained(model_dir)
+ else:
+ self.tokenizer = T5Tokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
+ self.transformer = T5EncoderModel.from_pretrained(model_dir, cache_dir=cache_dir)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ # @autocast
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ with torch.autocast("cuda", enabled=False):
+ outputs = self.transformer(input_ids=tokens)
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
diff --git a/sat/sgm/modules/video_attention.py b/sat/sgm/modules/video_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f968d72e7d1ef5f11a68c289e76a8a1c9817312
--- /dev/null
+++ b/sat/sgm/modules/video_attention.py
@@ -0,0 +1,293 @@
+import torch
+
+from ..modules.attention import *
+from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding
+
+
+class TimeMixSequential(nn.Sequential):
+ def forward(self, x, context=None, timesteps=None):
+ for layer in self:
+ x = layer(x, context, timesteps)
+
+ return x
+
+
+class VideoTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention,
+ "softmax-xformers": MemoryEfficientCrossAttention,
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ timesteps=None,
+ ff_in=False,
+ inner_dim=None,
+ attn_mode="softmax",
+ disable_self_attn=False,
+ disable_temporal_crossattention=False,
+ switch_temporal_ca_to_sa=False,
+ ):
+ super().__init__()
+
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+
+ self.ff_in = ff_in or inner_dim is not None
+ if inner_dim is None:
+ inner_dim = dim
+
+ assert int(n_heads * d_head) == inner_dim
+
+ self.is_res = inner_dim == dim
+
+ if self.ff_in:
+ self.norm_in = nn.LayerNorm(dim)
+ self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff)
+
+ self.timesteps = timesteps
+ self.disable_self_attn = disable_self_attn
+ if self.disable_self_attn:
+ self.attn1 = attn_cls(
+ query_dim=inner_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ context_dim=context_dim,
+ dropout=dropout,
+ ) # is a cross-attention
+ else:
+ self.attn1 = attn_cls(
+ query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+
+ self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)
+
+ if disable_temporal_crossattention:
+ if switch_temporal_ca_to_sa:
+ raise ValueError
+ else:
+ self.attn2 = None
+ else:
+ self.norm2 = nn.LayerNorm(inner_dim)
+ if switch_temporal_ca_to_sa:
+ self.attn2 = attn_cls(
+ query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+ else:
+ self.attn2 = attn_cls(
+ query_dim=inner_dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ ) # is self-attn if context is none
+
+ self.norm1 = nn.LayerNorm(inner_dim)
+ self.norm3 = nn.LayerNorm(inner_dim)
+ self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
+
+ self.checkpoint = checkpoint
+ if self.checkpoint:
+ print(f"{self.__class__.__name__} is using checkpointing")
+
+ def forward(self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None) -> torch.Tensor:
+ if self.checkpoint:
+ return checkpoint(self._forward, x, context, timesteps)
+ else:
+ return self._forward(x, context, timesteps=timesteps)
+
+ def _forward(self, x, context=None, timesteps=None):
+ assert self.timesteps or timesteps
+ assert not (self.timesteps and timesteps) or self.timesteps == timesteps
+ timesteps = self.timesteps or timesteps
+ B, S, C = x.shape
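+        # Fold the spatial axis into the batch so that attention in this block runs along the temporal axis.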
+ x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
+
+ if self.ff_in:
+ x_skip = x
+ x = self.ff_in(self.norm_in(x))
+ if self.is_res:
+ x += x_skip
+
+ if self.disable_self_attn:
+ x = self.attn1(self.norm1(x), context=context) + x
+ else:
+ x = self.attn1(self.norm1(x)) + x
+
+ if self.attn2 is not None:
+ if self.switch_temporal_ca_to_sa:
+ x = self.attn2(self.norm2(x)) + x
+ else:
+ x = self.attn2(self.norm2(x), context=context) + x
+ x_skip = x
+ x = self.ff(self.norm3(x))
+ if self.is_res:
+ x += x_skip
+
+ x = rearrange(x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps)
+ return x
+
+ def get_last_layer(self):
+ return self.ff.net[-1].weight
+
+
+str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
+
+
+class SpatialVideoTransformer(SpatialTransformer):
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ use_linear=False,
+ context_dim=None,
+ use_spatial_context=False,
+ timesteps=None,
+ merge_strategy: str = "fixed",
+ merge_factor: float = 0.5,
+ time_context_dim=None,
+ ff_in=False,
+ checkpoint=False,
+ time_depth=1,
+ attn_mode="softmax",
+ disable_self_attn=False,
+ disable_temporal_crossattention=False,
+ max_time_embed_period: int = 10000,
+ dtype="fp32",
+ ):
+ super().__init__(
+ in_channels,
+ n_heads,
+ d_head,
+ depth=depth,
+ dropout=dropout,
+ attn_type=attn_mode,
+ use_checkpoint=checkpoint,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ disable_self_attn=disable_self_attn,
+ )
+ self.time_depth = time_depth
+ self.depth = depth
+ self.max_time_embed_period = max_time_embed_period
+
+ time_mix_d_head = d_head
+ n_time_mix_heads = n_heads
+
+ time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
+
+ inner_dim = n_heads * d_head
+ if use_spatial_context:
+ time_context_dim = context_dim
+
+ self.time_stack = nn.ModuleList(
+ [
+ VideoTransformerBlock(
+ inner_dim,
+ n_time_mix_heads,
+ time_mix_d_head,
+ dropout=dropout,
+ context_dim=time_context_dim,
+ timesteps=timesteps,
+ checkpoint=checkpoint,
+ ff_in=ff_in,
+ inner_dim=time_mix_inner_dim,
+ attn_mode=attn_mode,
+ disable_self_attn=disable_self_attn,
+ disable_temporal_crossattention=disable_temporal_crossattention,
+ )
+ for _ in range(self.depth)
+ ]
+ )
+
+ assert len(self.time_stack) == len(self.transformer_blocks)
+
+ self.use_spatial_context = use_spatial_context
+ self.in_channels = in_channels
+
+ time_embed_dim = self.in_channels * 4
+ self.time_pos_embed = nn.Sequential(
+ linear(self.in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, self.in_channels),
+ )
+
+ self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy)
+ self.dtype = str_to_dtype[dtype]
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: Optional[torch.Tensor] = None,
+ time_context: Optional[torch.Tensor] = None,
+ timesteps: Optional[int] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ _, _, h, w = x.shape
+ x_in = x
+ spatial_context = None
+ if exists(context):
+ spatial_context = context
+
+ if self.use_spatial_context:
+ assert context.ndim == 3, f"n dims of spatial context should be 3 but are {context.ndim}"
+
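+                # Reuse each clip's first-frame spatial context as the temporal context and tile it over all
+                # h * w positions (the temporal blocks treat (b h w) as their batch axis).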
+ time_context = context
+ time_context_first_timestep = time_context[::timesteps]
+ time_context = repeat(time_context_first_timestep, "b ... -> (b n) ...", n=h * w)
+ elif time_context is not None and not self.use_spatial_context:
+ time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
+ if time_context.ndim == 2:
+ time_context = rearrange(time_context, "b c -> b 1 c")
+
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c")
+ if self.use_linear:
+ x = self.proj_in(x)
+
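+        # Embed the frame index (0..timesteps-1) sinusoidally, project it with an MLP, and add it to the
+        # temporal branch as a per-frame positional signal.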
+ num_frames = torch.arange(timesteps, device=x.device)
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+ num_frames = rearrange(num_frames, "b t -> (b t)")
+ t_emb = timestep_embedding(
+ num_frames,
+ self.in_channels,
+ repeat_only=False,
+ max_period=self.max_time_embed_period,
+ dtype=self.dtype,
+ )
+ emb = self.time_pos_embed(t_emb)
+ emb = emb[:, None, :]
+
+ for it_, (block, mix_block) in enumerate(zip(self.transformer_blocks, self.time_stack)):
+ x = block(
+ x,
+ context=spatial_context,
+ )
+
+ x_mix = x
+ x_mix = x_mix + emb
+
+ x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
+ x = self.time_mixer(
+ x_spatial=x,
+ x_temporal=x_mix,
+ image_only_indicator=image_only_indicator,
+ )
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ if not self.use_linear:
+ x = self.proj_out(x)
+ out = x + x_in
+ return out
diff --git a/sat/sgm/util.py b/sat/sgm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..b93a04930b62d5cf5d9361b3153883cdebfdc28a
--- /dev/null
+++ b/sat/sgm/util.py
@@ -0,0 +1,383 @@
+import functools
+import importlib
+import os
+from functools import partial
+from inspect import isfunction
+
+import fsspec
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from safetensors.torch import load_file as load_safetensors
+import torch.distributed
+
+_CONTEXT_PARALLEL_GROUP = None
+_CONTEXT_PARALLEL_SIZE = None
+
+
+def is_context_parallel_initialized():
+ if _CONTEXT_PARALLEL_GROUP is None:
+ return False
+ else:
+ return True
+
+
+def set_context_parallel_group(size, group):
+ global _CONTEXT_PARALLEL_GROUP
+ global _CONTEXT_PARALLEL_SIZE
+ _CONTEXT_PARALLEL_GROUP = group
+ _CONTEXT_PARALLEL_SIZE = size
+
+
+def initialize_context_parallel(context_parallel_size):
+ global _CONTEXT_PARALLEL_GROUP
+ global _CONTEXT_PARALLEL_SIZE
+
+ assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
+ _CONTEXT_PARALLEL_SIZE = context_parallel_size
+
+ rank = torch.distributed.get_rank()
+ world_size = torch.distributed.get_world_size()
+
+ for i in range(0, world_size, context_parallel_size):
+ ranks = range(i, i + context_parallel_size)
+ group = torch.distributed.new_group(ranks)
+ if rank in ranks:
+ _CONTEXT_PARALLEL_GROUP = group
+ break
+
+
+def get_context_parallel_group():
+ assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
+
+ return _CONTEXT_PARALLEL_GROUP
+
+
+def get_context_parallel_world_size():
+ assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
+
+ return _CONTEXT_PARALLEL_SIZE
+
+
+def get_context_parallel_rank():
+ assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
+
+ rank = torch.distributed.get_rank()
+ cp_rank = rank % _CONTEXT_PARALLEL_SIZE
+ return cp_rank
+
+
+def get_context_parallel_group_rank():
+ assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
+
+ rank = torch.distributed.get_rank()
+ cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE
+
+ return cp_group_rank
+
+
+class SafeConv3d(torch.nn.Conv3d):
+ def forward(self, input):
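+        # Rough activation-memory estimate in GiB (2 bytes per element, i.e. assuming half precision);
+        # above ~2 GiB, the input is split along the temporal axis with chunks overlapping by
+        # kernel_size - 1 frames so the concatenated outputs line up with an unchunked convolution
+        # (assuming stride 1 along time).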
+ memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+ if memory_count > 2:
+ # print(f"WARNING: Conv3d with {memory_count:.2f}GB")
+ kernel_size = self.kernel_size[0]
+ part_num = int(memory_count / 2) + 1
+ input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW
+ if kernel_size > 1:
+ input_chunks = [input_chunks[0]] + [
+ torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
+ for i in range(1, len(input_chunks))
+ ]
+
+ output_chunks = []
+ for input_chunk in input_chunks:
+ output_chunks.append(super(SafeConv3d, self).forward(input_chunk))
+ output = torch.cat(output_chunks, dim=2)
+ return output
+ else:
+ return super(SafeConv3d, self).forward(input)
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def get_string_from_tuple(s):
+ try:
+ # Check if the string starts and ends with parentheses
+ if s[0] == "(" and s[-1] == ")":
+ # Convert the string to a tuple
+ t = eval(s)
+ # Check if the type of t is tuple
+ if type(t) == tuple:
+ return t[0]
+ else:
+ pass
+    except Exception:
+ pass
+ return s
+
+
+def is_power_of_two(n):
+ """
+ chat.openai.com/chat
+ Return True if n is a power of 2, otherwise return False.
+
+ The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
+ The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
+ If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
+ Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
+
+ """
+ if n <= 0:
+ return False
+ return (n & (n - 1)) == 0
+
+
+def autocast(f, enabled=True):
+ def do_autocast(*args, **kwargs):
+ with torch.cuda.amp.autocast(
+ enabled=enabled,
+ dtype=torch.get_autocast_gpu_dtype(),
+ cache_enabled=torch.is_autocast_cache_enabled(),
+ ):
+ return f(*args, **kwargs)
+
+ return do_autocast
+
+
+def load_partial_from_config(config):
+ return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
+ nc = int(40 * (wh[0] / 256))
+ if isinstance(xc[bi], list):
+ text_seq = xc[bi][0]
+ else:
+ text_seq = xc[bi]
+ lines = "\n".join(text_seq[start : start + nc] for start in range(0, len(text_seq), nc))
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def partialclass(cls, *args, **kwargs):
+ class NewCls(cls):
+ __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
+
+ return NewCls
+
+
+def make_path_absolute(path):
+ fs, p = fsspec.core.url_to_fs(path)
+ if fs.protocol == "file":
+ return os.path.abspath(p)
+ return path
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def isheatmap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+
+ return x.ndim == 2
+
+
+def isneighbors(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def expand_dims_like(x, y):
+ while x.dim() != y.dim():
+ x = x.unsqueeze(-1)
+ return x
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config, **extra_kwargs):
+ if not "target" in config:
+ if config == "__is_first_stage__":
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()), **extra_kwargs)
+
+
+def get_obj_from_str(string, reload=False, invalidate_cache=True):
+ module, cls = string.rsplit(".", 1)
+ if invalidate_cache:
+ importlib.invalidate_caches()
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def append_zero(x):
+ return torch.cat([x, x.new_zeros([1])])
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def load_model_from_config(config, ckpt, verbose=True, freeze=True):
+ print(f"Loading model from {ckpt}")
+ if ckpt.endswith("ckpt"):
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ elif ckpt.endswith("safetensors"):
+ sd = load_safetensors(ckpt)
+ else:
+ raise NotImplementedError
+
+ model = instantiate_from_config(config.model)
+
+ m, u = model.load_state_dict(sd, strict=False)
+
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ if freeze:
+ for param in model.parameters():
+ param.requires_grad = False
+
+ model.eval()
+ return model
+
+
+def get_configs_path() -> str:
+ """
+ Get the `configs` directory.
+ For a working copy, this is the one in the root of the repository,
+ but for an installed copy, it's in the `sgm` package (see pyproject.toml).
+ """
+ this_dir = os.path.dirname(__file__)
+ candidates = (
+ os.path.join(this_dir, "configs"),
+ os.path.join(this_dir, "..", "configs"),
+ )
+ for candidate in candidates:
+ candidate = os.path.abspath(candidate)
+ if os.path.isdir(candidate):
+ return candidate
+ raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
+
+
+def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
+ """
+ Will return the result of a recursive get attribute call.
+ E.g.:
+ a.b.c
+ = getattr(getattr(a, "b"), "c")
+ = get_nested_attribute(a, "b.c")
+ If any part of the attribute call is an integer x with current obj a, will
+ try to call a[x] instead of a.x first.
+ """
+ attributes = attribute_path.split(".")
+ if depth is not None and depth > 0:
+ attributes = attributes[:depth]
+ assert len(attributes) > 0, "At least one attribute should be selected"
+ current_attribute = obj
+ current_key = None
+ for level, attribute in enumerate(attributes):
+ current_key = ".".join(attributes[: level + 1])
+ try:
+ id_ = int(attribute)
+ current_attribute = current_attribute[id_]
+ except ValueError:
+ current_attribute = getattr(current_attribute, attribute)
+
+ return (current_attribute, current_key) if return_key else current_attribute
+
+
+from math import sqrt
+
+
+class SeededNoise:
+ def __init__(self, seeds, weights):
+ self.seeds = seeds
+ self.weights = weights
+ weight_square_sum = 0
+ for weight in weights:
+ weight_square_sum += weight**2
+ self.weight_square_sum_sqrt = sqrt(weight_square_sum)
+ self.cnt = 0
+
+ def __call__(self, x):
+ self.cnt += 1
+ randn_combined = torch.zeros_like(x)
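+        # Mix fixed-seed Gaussian streams; dividing by sqrt(sum of squared weights) keeps the result unit-variance.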
+ for seed, weight in zip(self.seeds, self.weights):
+ randn = np.random.RandomState(seed + self.cnt).randn(*x.shape)
+            randn = torch.from_numpy(randn).to(dtype=x.dtype, device=x.device)
+ randn_combined += randn * weight
+ randn_combined /= self.weight_square_sum_sqrt
+ return randn_combined
diff --git a/sat/sgm/webds.py b/sat/sgm/webds.py
new file mode 100644
index 0000000000000000000000000000000000000000..b99f9f337e2f19532cda3237473ee069dbbc4f9b
--- /dev/null
+++ b/sat/sgm/webds.py
@@ -0,0 +1,389 @@
+import sys
+import io
+import os
+import re
+import json
+import tarfile
+from functools import partial
+
+import webdataset as wds
+from webdataset import ResampledShards, DataPipeline, tarfile_to_samples
+from webdataset.filters import pipelinefilter
+from webdataset.tariterators import url_opener, group_by_keys
+from webdataset.handlers import reraise_exception
+from webdataset.gopen import gopen_schemes, gopen
+
+
+def pytorch_worker_info(group=None): # sourcery skip: use-contextlib-suppress
+ """Return node and worker info for PyTorch and some distributed environments."""
+ rank = 0
+ world_size = 1
+ worker = 0
+ num_workers = 1
+ try:
+ import torch.distributed
+
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ group = group or torch.distributed.group.WORLD
+ rank = torch.distributed.get_rank(group=group)
+ world_size = torch.distributed.get_world_size(group=group)
+ except ModuleNotFoundError:
+ pass
+ try:
+ import torch.utils.data
+
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is not None:
+ worker = worker_info.id
+ num_workers = worker_info.num_workers
+ except ModuleNotFoundError:
+ pass
+
+ return rank, world_size, worker, num_workers
+
+
+def pytorch_worker_seed(group=None):
+ """Compute a distinct, deterministic RNG seed for each worker and node."""
+ rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
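+    # Combine node rank and dataloader worker id into one seed (assumes fewer than 1000 workers per rank).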
+ return rank * 1000 + worker
+
+
+def worker_seed_sat(group=None, seed=0):
+ return pytorch_worker_seed(group=group) + seed * 23
+
+
+class ConfiguredResampledShards(ResampledShards):
+ def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True):
+ from sat.helpers import print_rank0
+
+ try:
+ from megatron.core.parallel_state import get_data_parallel_group
+
+ group = get_data_parallel_group()
+ print_rank0("Using megatron data parallel group.")
+        except Exception:
+ from sat.mpu import get_data_parallel_group
+
+ try:
+ group = get_data_parallel_group()
+ print_rank0("Using sat data parallel group.")
+ except AssertionError:
+ group = None
+ print_rank0("No data parallel group is specified!")
+ worker_seed_sat_this = partial(worker_seed_sat, group=group, seed=seed)
+ super().__init__(urls, nshards, worker_seed_sat_this, deterministic)
+
+
+class SimpleDistributedWebDataset(DataPipeline):
+ def __init__(self, path, process_fn, seed, *, shuffle_buffer=1000):
+        # With model parallelism, shuffling would desynchronize ranks, so shuffle_buffer is forced to 1 to disable it.
+ try:
+ from sat.mpu import get_model_parallel_world_size
+
+ if get_model_parallel_world_size() > 1:
+ shuffle_buffer = 1
+ except Exception:
+ pass
+ super().__init__(
+            ConfiguredResampledShards(path, seed),  # many shards are recommended, otherwise data may not be spread evenly across workers
+ tarfile_to_samples(),
+ wds.shuffle(shuffle_buffer),
+ process_fn,
+ )
+
+
+def tar_file_iterator_with_meta(
+ fileobj, meta_names, skip_meta=r"__[^/]*__($|/)", suffix=None, handler=reraise_exception, meta_stream=None
+):
+ """Iterate over tar file, yielding filename, content pairs for the given tar stream.
+
+ :param fileobj: byte stream suitable for tarfile
+ :param meta_names: key of different items in meta file
+ :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
+
+ """
+ stream = tarfile.open(fileobj=fileobj, mode="r|*")
+ data_dir, filename = fileobj.name.rsplit("/", 1)
+ meta_data = {} # {id: {meta_name: meta_value, meta_name2: meta_value2, ...}}
+
+ if meta_stream is None:
+ meta_file_name = filename.split(".")[0] + ".meta.jsonl"
+ meta_path = os.path.join(data_dir, meta_file_name)
+ if os.path.exists(meta_path):
+ meta_stream = open(meta_path, "r")
+ else:
+ meta_file_name = meta_stream.name
+
+ if meta_stream is not None:
+ for lineno, line in enumerate(meta_stream):
+ meta_list = []
+ try:
+ meta_list.append(json.loads(line))
+ except Exception as exn:
+ from sat.helpers import print_rank0
+
+ print_rank0(f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}", level="DEBUG")
+ continue
+ for item in meta_list:
+ if not item["key"] in meta_data:
+ meta_data[item["key"]] = {}
+ for meta_name in meta_names:
+ if meta_name in item:
+ meta_data[item["key"]][meta_name] = item[meta_name]
+ meta_stream.close()
+
+ try:
+ for tarinfo in stream:
+ fname = tarinfo.name
+ try:
+ if not tarinfo.isreg():
+ continue
+ if fname is None:
+ continue
+ if "/" not in fname and fname.startswith("__") and fname.endswith("__"):
+ # skipping metadata for now
+ continue
+ if skip_meta is not None and re.match(skip_meta, fname):
+ continue
+ if fname.endswith(".txt") and suffix is not None:
+ data = (stream.extractfile(tarinfo).read().decode() + suffix).encode()
+ else:
+ data = stream.extractfile(tarinfo).read()
+ result = dict(fname=fname, data=data)
+ yield result
+
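+                # For each sample's ".id" member, emit extra pseudo-files carrying the requested meta
+                # fields looked up from the ".meta.jsonl" sidecar (keyed by the sample id).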
+ if fname.endswith(".id"):
+ fid = fname.split(".")[0]
+ if "-$#%@&" in fid:
+ sfid = fid.split("-$#%@&")[0]
+ else:
+ sfid = fid
+ meta_data_fid = meta_data.get(sfid, {})
+ for meta_name in meta_names:
+ meta_fname = fid + "." + meta_name
+ meta = meta_data_fid.get(meta_name, None)
+ yield dict(fname=meta_fname, data=meta)
+ stream.members = []
+ except Exception as exn:
+ if hasattr(exn, "args") and len(exn.args) > 0:
+ exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
+ if handler(exn):
+ continue
+ else:
+ break
+ except Exception as exn:
+ print(exn)
+ del stream
+
+
+def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception):
+ """Expand a stream of open tar files into a stream of tar file contents.
+
+ This returns an iterator over (filename, file_contents).
+ """
+ for source in data:
+ url = source["url"]
+ try:
+ assert isinstance(source, dict)
+ assert "stream" in source
+ for sample in tar_file_iterator_with_meta(source["stream"], meta_names, meta_stream=source["meta_stream"]):
+ assert isinstance(sample, dict) and "data" in sample and "fname" in sample
+ sample["__url__"] = url
+ yield sample
+ except Exception as exn:
+ exn.args = exn.args + (source.get("stream"), source.get("url"))
+ if handler(exn):
+ continue
+ else:
+ break
+
+
+def url_opener(
+ data,
+ handler,
+ **kw,
+):
+ """Open URLs and yield a stream of url+stream pairs.
+
+ Args:
+ data: iterator over dict(url=...)
+ handler: exception handler.
+ kw: keyword arguments for gopen.gopen.
+
+ Yields:
+ a stream of url+stream pairs.
+ """
+ for sample in data:
+ assert isinstance(sample, dict), sample
+ assert "url" in sample
+ url = sample["url"]
+ try:
+ stream = gopen(url, **kw)
+ if hasattr(stream, "meta_stream"):
+ meta_stream = stream.meta_stream
+ del stream.meta_stream
+ else:
+ meta_stream = None
+ sample.update(stream=stream, meta_stream=meta_stream)
+ yield sample
+ except Exception as exn:
+ exn.args = exn.args + (url,)
+ if handler(exn):
+ continue
+ else:
+ break
+
+
+def tarfile_samples_with_meta(src, meta_names, handler=reraise_exception):
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander_with_meta(streams, meta_names, handler)
+ samples = group_by_keys(files, handler=handler)
+ return samples
+
+
+class MetaDistributedWebDataset(DataPipeline):
+ """WebDataset with meta information files
+ Extra Format:
+ in webdataset (tar), for each sample there is a '.id';
+ for each tar file, there is a '.meta.jsonl' file with the same name;
+ The '.meta.jsonl' file contains lines of json objects, each with a 'key' field to match '.id'.
+ """
+
+ def __init__(
+ self, path, process_fn, seed, *, meta_names=[], nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None
+ ):
+ # os.environ['WDS_SHOW_SEED'] = '1'
+ import torch
+
+ if torch.distributed.get_rank() == 0:
+ if include_dirs is not None: # /webdatasets/A,/webdatasets/C
+ other_paths = []
+ include_dirs = include_dirs.split(",")
+ for include_dir in include_dirs:
+ if "*" in include_dir:
+ include_dir, n = include_dir.split("*")
+ n = int(n)
+ else:
+ n = 1
+ for cur_dir, dirs, files in os.walk(include_dir):
+ for f in files:
+ if f.endswith("tar") and os.path.getsize(os.path.join(cur_dir, f)) > 0:
+ # other_paths.append(os.path.join(cur_dir,f))
+ other_paths.extend([os.path.join(cur_dir, f)] * n)
+ # print(f'Adding dataset paths {",".join(other_paths)}')
+ from braceexpand import braceexpand
+
+ if len(path) > 0: # not ""
+ path = list(braceexpand(path)) + other_paths
+ else:
+ path = other_paths
+ path = [path]
+ else:
+ path = [
+ None,
+ ]
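+        # Rank 0 assembled the concrete shard list above; broadcast it so every rank samples from the same list.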
+ torch.distributed.broadcast_object_list(path, src=0)
+ path = path[0]
+
+ tarfile_samples = partial(tarfile_samples_with_meta, meta_names=meta_names)
+ tarfile_to_samples = pipelinefilter(tarfile_samples)
+
+ # if model parallel, shuffle_buffer should be 1 to disable shuffling
+ try:
+ from sat.mpu import get_model_parallel_world_size
+
+ if get_model_parallel_world_size() > 1:
+ shuffle_buffer = 1
+ except Exception:
+ pass
+
+ super().__init__(
+ ConfiguredResampledShards(path, seed, nshards=nshards),
+ tarfile_to_samples(),
+ wds.shuffle(shuffle_buffer),
+ process_fn,
+ )
+
+
+# rclone support
+from webdataset.gopen import Pipe
+
+
+def gopen_rclone(url, mode="rb", bufsize=1024 * 1024 * 32):
+ """Open a URL with `curl`.
+
+ :param url: rclone url, e.g. data:bucket1/foo.tar. data should be configured.
+ :param mode: file mode
+ :param bufsize: buffer size
+ """
+ url = url.replace("rclone://", "")
+ if mode[0] == "r":
+ cmd = f"rclone cat '{url}'"
+ return Pipe(
+ cmd,
+ mode=mode,
+ shell=True,
+ bufsize=bufsize,
+ ignore_status=[141, 23],
+ ) # skipcq: BAN-B604
+ elif mode[0] == "w":
+ cmd = f"rclone cp - '{url}'"
+ return Pipe(
+ cmd,
+ mode=mode,
+ shell=True,
+ bufsize=bufsize,
+ ignore_status=[141, 26],
+ ) # skipcq: BAN-B604
+ else:
+ raise ValueError(f"{mode}: unknown mode")
+
+
+def gopen_boto3(url, mode="rb", bufsize=8192 * 2):
+ """Open a URL with boto3 API.
+
+    :param url: boto3 url, e.g. boto3://bucket1/foo.tar (endpoint and credentials are read from the S3_* environment variables)
+ :param mode: file mode
+ :param bufsize: buffer size
+ """
+ import boto3
+
+ # boto3.set_stream_logger('botocore', level='DEBUG')
+ if url.startswith("boto3://"):
+ url = url.replace("boto3://", "")
+ need_meta = False
+ else:
+ url = url.replace("metaboto3://", "")
+ need_meta = True
+ endpoint_url = os.environ.get("S3_ENDPOINT_URL", None)
+ access_key = os.environ.get("S3_ACCESS_KEY_ID", None)
+ secret_key = os.environ.get("S3_SECRET_ACCESS_KEY", None)
+
+ if mode[0] == "r":
+ s3_client = boto3.client(
+ "s3", endpoint_url=endpoint_url, aws_access_key_id=access_key, aws_secret_access_key=secret_key
+ )
+ bucket, key = url.split("/", 1)
+
+ if need_meta:
+ # download a meta json
+ meta_file_key = key.split(".")[0] + ".meta.jsonl"
+ meta_stream = io.BytesIO()
+ s3_client.download_fileobj(bucket, meta_file_key, meta_stream)
+ meta_stream.seek(0)
+ meta_stream.name = meta_file_key
+ else:
+ meta_stream = None
+
+ # data tar stream
+ response = s3_client.get_object(Bucket=bucket, Key=key) # Range optional
+ response["Body"].name = key # actually not used
+ response["Body"].meta_stream = meta_stream
+ return response["Body"]
+ else:
+ raise ValueError(f"{mode}: unknown mode")
+
+
+gopen_schemes["rclone"] = gopen_rclone
+gopen_schemes["boto3"] = gopen_boto3
+gopen_schemes["metaboto3"] = gopen_boto3
diff --git a/sat/train_video.py b/sat/train_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..62f2136b737c21b96f70719653771b26ef4f5c8f
--- /dev/null
+++ b/sat/train_video.py
@@ -0,0 +1,233 @@
+import os
+import argparse
+from functools import partial
+from PIL import Image
+import numpy as np
+import torch.distributed
+import torchvision
+from omegaconf import OmegaConf
+import imageio
+
+import torch
+
+from sat import mpu
+from sat.training.deepspeed_training import training_main
+
+from sgm.util import get_obj_from_str, isheatmap, exists
+
+from diffusion_video import SATVideoDiffusionEngine
+from arguments import get_args, process_config_to_args
+
+from einops import rearrange, repeat
+
+try:
+ import wandb
+except ImportError:
+ print("warning: wandb not installed")
+
+
+def print_debug(args, s):
+ if args.debug:
+ s = f"RANK:[{torch.distributed.get_rank()}]:" + s
+ print(s)
+
+
+def save_texts(texts, save_dir, iterations):
+ output_path = os.path.join(save_dir, f"{str(iterations).zfill(8)}")
+ with open(output_path, "w", encoding="utf-8") as f:
+ for text in texts:
+ f.write(text + "\n")
+
+
+def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5, args=None, key=None):
+ os.makedirs(save_path, exist_ok=True)
+
+ for i, vid in enumerate(video_batch):
+ gif_frames = []
+ for frame in vid:
+ frame = rearrange(frame, "c h w -> h w c")
+ frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
+ gif_frames.append(frame)
+ now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
+ with imageio.get_writer(now_save_path, fps=fps) as writer:
+ for frame in gif_frames:
+ writer.append_data(frame)
+ if args is not None and args.wandb:
+ wandb.log(
+ {key + f"_video_{i}": wandb.Video(now_save_path, fps=fps, format="mp4")}, step=args.iteration + 1
+ )
+
+
+def log_video(batch, model, args, only_log_video_latents=False):
+ texts = batch["txt"]
+ text_save_dir = os.path.join(args.save, "video_texts")
+ os.makedirs(text_save_dir, exist_ok=True)
+ save_texts(texts, text_save_dir, args.iteration)
+
+ gpu_autocast_kwargs = {
+ "enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
+ videos = model.log_video(batch, only_log_video_latents=only_log_video_latents)
+
+ if torch.distributed.get_rank() == 0:
+ root = os.path.join(args.save, "video")
+
+ if only_log_video_latents:
+ root = os.path.join(root, "latents")
+ filename = "{}_gs-{:06}".format("latents", args.iteration)
+ path = os.path.join(root, filename)
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
+ os.makedirs(path, exist_ok=True)
+ torch.save(videos["latents"], os.path.join(path, "latent.pt"))
+ else:
+ for k in videos:
+ N = videos[k].shape[0]
+ if not isheatmap(videos[k]):
+ videos[k] = videos[k][:N]
+ if isinstance(videos[k], torch.Tensor):
+ videos[k] = videos[k].detach().float().cpu()
+ if not isheatmap(videos[k]):
+ videos[k] = torch.clamp(videos[k], -1.0, 1.0)
+
+ num_frames = batch["num_frames"][0]
+ fps = batch["fps"][0].cpu().item()
+ if only_log_video_latents:
+ root = os.path.join(root, "latents")
+ filename = "{}_gs-{:06}".format("latents", args.iteration)
+ path = os.path.join(root, filename)
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
+ os.makedirs(path, exist_ok=True)
+ torch.save(videos["latents"], os.path.join(path, "latents.pt"))
+ else:
+ for k in videos:
+ samples = (videos[k] + 1.0) / 2.0
+ filename = "{}_gs-{:06}".format(k, args.iteration)
+
+ path = os.path.join(root, filename)
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
+ save_video_as_grid_and_mp4(samples, path, num_frames // fps, fps, args, k)
+
+
+def broad_cast_batch(batch):
+ mp_size = mpu.get_model_parallel_world_size()
+ global_rank = torch.distributed.get_rank() // mp_size
+ src = global_rank * mp_size
+
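+    # Within each model-parallel group, rank 0 holds the real batch: first broadcast the captions and the
+    # tensor shapes as Python objects so the other ranks can allocate matching buffers, then broadcast the
+    # tensors themselves.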
+ if batch["mp4"] is not None:
+ broadcast_shape = [batch["mp4"].shape, batch["fps"].shape, batch["num_frames"].shape]
+ else:
+ broadcast_shape = None
+
+ txt = [batch["txt"], broadcast_shape]
+ torch.distributed.broadcast_object_list(txt, src=src, group=mpu.get_model_parallel_group())
+ batch["txt"] = txt[0]
+
+ mp4_shape = txt[1][0]
+ fps_shape = txt[1][1]
+ num_frames_shape = txt[1][2]
+
+ if mpu.get_model_parallel_rank() != 0:
+ batch["mp4"] = torch.zeros(mp4_shape, device="cuda")
+ batch["fps"] = torch.zeros(fps_shape, device="cuda", dtype=torch.long)
+ batch["num_frames"] = torch.zeros(num_frames_shape, device="cuda", dtype=torch.long)
+
+ torch.distributed.broadcast(batch["mp4"], src=src, group=mpu.get_model_parallel_group())
+ torch.distributed.broadcast(batch["fps"], src=src, group=mpu.get_model_parallel_group())
+ torch.distributed.broadcast(batch["num_frames"], src=src, group=mpu.get_model_parallel_group())
+ return batch
+
+
+def forward_step_eval(data_iterator, model, args, timers, only_log_video_latents=False, data_class=None):
+ if mpu.get_model_parallel_rank() == 0:
+ timers("data loader").start()
+ batch_video = next(data_iterator)
+ timers("data loader").stop()
+
+ if len(batch_video["mp4"].shape) == 6:
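+            # A 6D "mp4" tensor presumably packs several clips per sample as (b, v, ...); flatten the first
+            # two axes and repeat the captions in the same (i, j) order so videos and texts stay aligned.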
+ b, v = batch_video["mp4"].shape[:2]
+ batch_video["mp4"] = batch_video["mp4"].view(-1, *batch_video["mp4"].shape[2:])
+ txt = []
+ for i in range(b):
+ for j in range(v):
+ txt.append(batch_video["txt"][j][i])
+ batch_video["txt"] = txt
+
+ for key in batch_video:
+ if isinstance(batch_video[key], torch.Tensor):
+ batch_video[key] = batch_video[key].cuda()
+ else:
+ batch_video = {"mp4": None, "fps": None, "num_frames": None, "txt": None}
+ broad_cast_batch(batch_video)
+ if mpu.get_data_parallel_rank() == 0:
+ log_video(batch_video, model, args, only_log_video_latents=only_log_video_latents)
+
+ batch_video["global_step"] = args.iteration
+ loss, loss_dict = model.shared_step(batch_video)
+ for k in loss_dict:
+ if loss_dict[k].dtype == torch.bfloat16:
+ loss_dict[k] = loss_dict[k].to(torch.float32)
+ return loss, loss_dict
+
+
+def forward_step(data_iterator, model, args, timers, data_class=None):
+ if mpu.get_model_parallel_rank() == 0:
+ timers("data loader").start()
+ batch = next(data_iterator)
+ timers("data loader").stop()
+ for key in batch:
+ if isinstance(batch[key], torch.Tensor):
+ batch[key] = batch[key].cuda()
+
+ if torch.distributed.get_rank() == 0:
+ if not os.path.exists(os.path.join(args.save, "training_config.yaml")):
+ configs = [OmegaConf.load(cfg) for cfg in args.base]
+ config = OmegaConf.merge(*configs)
+ os.makedirs(args.save, exist_ok=True)
+ OmegaConf.save(config=config, f=os.path.join(args.save, "training_config.yaml"))
+ else:
+ batch = {"mp4": None, "fps": None, "num_frames": None, "txt": None}
+
+ batch["global_step"] = args.iteration
+
+ broad_cast_batch(batch)
+
+ loss, loss_dict = model.shared_step(batch)
+
+ return loss, loss_dict
+
+
+if __name__ == "__main__":
+ if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
+ os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
+ os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
+ os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
+
+ py_parser = argparse.ArgumentParser(add_help=False)
+ known, args_list = py_parser.parse_known_args()
+ args = get_args(args_list)
+ args = argparse.Namespace(**vars(args), **vars(known))
+
+ data_class = get_obj_from_str(args.data_config["target"])
+ create_dataset_function = partial(data_class.create_dataset_function, **args.data_config["params"])
+
+ import yaml
+
+ configs = []
+ for config in args.base:
+ with open(config, "r") as f:
+ base_config = yaml.safe_load(f)
+ configs.append(base_config)
+ args.log_config = configs
+
+ training_main(
+ args,
+ model_cls=SATVideoDiffusionEngine,
+ forward_step_function=partial(forward_step, data_class=data_class),
+ forward_step_eval=partial(
+ forward_step_eval, data_class=data_class, only_log_video_latents=args.only_log_video_latents
+ ),
+ create_dataset_function=create_dataset_function,
+ )
diff --git a/sat/vae_modules/attention.py b/sat/vae_modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..52bbba5dd3617df5568330c8f425dc98aa3944a1
--- /dev/null
+++ b/sat/vae_modules/attention.py
@@ -0,0 +1,572 @@
+import math
+from inspect import isfunction
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from packaging import version
+from torch import nn
+
+if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ SDP_IS_AVAILABLE = True
+ from torch.backends.cuda import SDPBackend, sdp_kernel
+
+ BACKEND_MAP = {
+ SDPBackend.MATH: {
+ "enable_math": True,
+ "enable_flash": False,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.FLASH_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": True,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.EFFICIENT_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": False,
+ "enable_mem_efficient": True,
+ },
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
+ }
+else:
+ from contextlib import nullcontext
+
+ SDP_IS_AVAILABLE = False
+ sdp_kernel = nullcontext
+ BACKEND_MAP = {}
+ print(
+ f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
+ f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
+ )
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILABLE = True
+except:
+ XFORMERS_IS_AVAILABLE = False
+ print("no module 'xformers'. Processing without...")
+
+from modules.utils import checkpoint
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
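+        # Linear attention: softmax-normalize the keys over the spatial axis, aggregate the values into a
+        # compact per-head context matrix, then read it out with the queries; the cost is linear rather
+        # than quadratic in h * w.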
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b (h w) c")
+ k = rearrange(k, "b c h w -> b c (h w)")
+ w_ = torch.einsum("bij,bjk->bik", q, k)
+
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, "b c h w -> b c (h w)")
+ w_ = rearrange(w_, "b i j -> b j i")
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ backend=None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.backend = backend
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ ):
+ h = self.heads
+
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
+ k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
+ v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+ ## old
+ """
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ del q, k
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', sim, v)
+ """
+ ## new
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
+
+ del q, k, v
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
+
+ if additional_tokens is not None:
+ # remove additional token
+ out = out[:, n_tokens_to_mask:]
+ return self.to_out(out)
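+
+# Hedged usage sketch for CrossAttention (sizes below are illustrative assumptions;
+# requires PyTorch >= 2.0 since the forward pass goes through sdp_kernel/BACKEND_MAP):
+#
+#     attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
+#     x = torch.randn(2, 4096, 320)      # query tokens, e.g. flattened image features
+#     ctx = torch.randn(2, 77, 768)      # conditioning tokens, e.g. text embeddings
+#     out = attn(x, context=ctx)         # -> (2, 4096, 320); context=None gives self-attention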
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
+ super().__init__()
+ print(
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ f"{heads} heads with a dimension of {dim_head}."
+ )
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.attention_op: Optional[Any] = None
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ ):
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
+ k = repeat(
+ k[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+ v = repeat(
+ v[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ # TODO: Use this directly in the attention operation, as a bias
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ if additional_tokens is not None:
+ # remove additional token
+ out = out[:, n_tokens_to_mask:]
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ disable_self_attn=False,
+ attn_mode="softmax",
+ sdp_backend=None,
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
+ print(
+ f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
+ f"This is not a problem in PyTorch >= 2.0. FYI, you are running PyTorch {torch.__version__}."
+ )
+ attn_mode = "softmax"
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
+ print("We do not support vanilla attention anymore, as it is too expensive. Sorry.")
+ if not XFORMERS_IS_AVAILABLE:
+ assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+ else:
+ print("Falling back to xformers efficient attention.")
+ attn_mode = "softmax-xformers"
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
+ else:
+ assert sdp_backend is None
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None,
+ backend=sdp_backend,
+ ) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ backend=sdp_backend,
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+ if self.checkpoint:
+ print(f"{self.__class__.__name__} is using checkpointing")
+
+ def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
+ kwargs = {"x": x}
+
+ if context is not None:
+ kwargs.update({"context": context})
+
+ if additional_tokens is not None:
+ kwargs.update({"additional_tokens": additional_tokens})
+
+ if n_times_crossframe_attn_in_self:
+ kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self})
+
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
+ x = (
+ self.attn1(
+ self.norm1(x),
+ context=context if self.disable_self_attn else None,
+ additional_tokens=additional_tokens,
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
+ )
+ + x
+ )
+ x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
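+
+# Hedged usage sketch for BasicTransformerBlock (dims are illustrative assumptions):
+#
+#     block = BasicTransformerBlock(dim=320, n_heads=8, d_head=40, context_dim=768)
+#     x = torch.randn(2, 4096, 320)
+#     x = block(x, context=torch.randn(2, 77, 768))  # self-attn, cross-attn, then feed-forward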
+
+
+class BasicTransformerSingleLayerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ attn_mode="softmax",
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ )
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context) + x
+ x = self.ff(self.norm2(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding) and reshape it to (b, t, d).
+ Then apply the standard transformer action and finally reshape back to an image.
+ NEW: use_linear applies linear projections instead of the 1x1 convs for more efficiency.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None,
+ disable_self_attn=False,
+ use_linear=False,
+ attn_type="softmax",
+ use_checkpoint=True,
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
+ sdp_backend=None,
+ ):
+ super().__init__()
+ print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
+ from omegaconf import ListConfig
+
+ if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
+ context_dim = [context_dim]
+ if exists(context_dim) and isinstance(context_dim, list):
+ if depth != len(context_dim):
+ print(
+ f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+ f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
+ )
+ # depth does not match context dims.
+ assert all(
+ map(lambda x: x == context_dim[0], context_dim)
+ ), "need homogenous context_dim to match depth automatically"
+ context_dim = depth * [context_dim[0]]
+ elif context_dim is None:
+ context_dim = [None] * depth
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn,
+ attn_mode=attn_type,
+ checkpoint=use_checkpoint,
+ sdp_backend=sdp_backend,
+ )
+ for d in range(depth)
+ ]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+ else:
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ if i > 0 and len(context) == 1:
+ i = 0 # use same context for each block
+ x = block(x, context=context[i])
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
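+
+# Hedged usage sketch for SpatialTransformer on image-like features (shapes assumed):
+#
+#     st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1,
+#                             context_dim=768, use_linear=True)
+#     feats = torch.randn(2, 320, 32, 32)               # b c h w
+#     out = st(feats, context=torch.randn(2, 77, 768))  # -> (2, 320, 32, 32), residual added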
diff --git a/sat/vae_modules/autoencoder.py b/sat/vae_modules/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c0cc80e9c0569e1d0c574ad1078b0df21209a7f
--- /dev/null
+++ b/sat/vae_modules/autoencoder.py
@@ -0,0 +1,651 @@
+import logging
+import math
+import re
+import random
+from abc import abstractmethod
+from contextlib import contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch.distributed
+import torch.nn as nn
+from einops import rearrange
+from packaging import version
+
+from vae_modules.ema import LitEma
+from sgm.util import (
+ instantiate_from_config,
+ get_obj_from_str,
+ default,
+ is_context_parallel_initialized,
+ initialize_context_parallel,
+ get_context_parallel_group,
+ get_context_parallel_group_rank,
+)
+from vae_modules.cp_enc_dec import _conv_split, _conv_gather
+
+logpy = logging.getLogger(__name__)
+
+
+class AbstractAutoencoder(pl.LightningModule):
+ """
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
+ unCLIP models, etc. Hence, it is fairly general, and specific features
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
+ """
+
+ def __init__(
+ self,
+ ema_decay: Union[None, float] = None,
+ monitor: Union[None, str] = None,
+ input_key: str = "jpg",
+ ):
+ super().__init__()
+
+ self.input_key = input_key
+ self.use_ema = ema_decay is not None
+ if monitor is not None:
+ self.monitor = monitor
+
+ if self.use_ema:
+ self.model_ema = LitEma(self, decay=ema_decay)
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ self.automatic_optimization = False
+
+ # def apply_ckpt(self, ckpt: Union[None, str, dict]):
+ # if ckpt is None:
+ # return
+ # if isinstance(ckpt, str):
+ # ckpt = {
+ # "target": "sgm.modules.checkpoint.CheckpointEngine",
+ # "params": {"ckpt_path": ckpt},
+ # }
+ # engine = instantiate_from_config(ckpt)
+ # engine(self)
+
+ def apply_ckpt(self, ckpt: Union[None, str, dict]):
+ if ckpt is None:
+ return
+ self.init_from_ckpt(ckpt)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
+ print("Missing keys: ", missing_keys)
+ print("Unexpected keys: ", unexpected_keys)
+ print(f"Restored from {path}")
+
+ @abstractmethod
+ def get_input(self, batch) -> Any:
+ raise NotImplementedError()
+
+ def on_train_batch_end(self, *args, **kwargs):
+ # for EMA computation
+ if self.use_ema:
+ self.model_ema(self)
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ logpy.info(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ logpy.info(f"{context}: Restored training weights")
+
+ @abstractmethod
+ def encode(self, *args, **kwargs) -> torch.Tensor:
+ raise NotImplementedError("encode()-method of abstract base class called")
+
+ @abstractmethod
+ def decode(self, *args, **kwargs) -> torch.Tensor:
+ raise NotImplementedError("decode()-method of abstract base class called")
+
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
+ return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
+
+ def configure_optimizers(self) -> Any:
+ raise NotImplementedError()
+
+
+class AutoencodingEngine(AbstractAutoencoder):
+ """
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
+ (we also restore them explicitly as special cases for legacy reasons).
+ Regularizations such as KL or VQ are moved to the regularizer class.
+ """
+
+ def __init__(
+ self,
+ *args,
+ encoder_config: Dict,
+ decoder_config: Dict,
+ loss_config: Dict,
+ regularizer_config: Dict,
+ optimizer_config: Union[Dict, None] = None,
+ lr_g_factor: float = 1.0,
+ trainable_ae_params: Optional[List[List[str]]] = None,
+ ae_optimizer_args: Optional[List[dict]] = None,
+ trainable_disc_params: Optional[List[List[str]]] = None,
+ disc_optimizer_args: Optional[List[dict]] = None,
+ disc_start_iter: int = 0,
+ diff_boost_factor: float = 3.0,
+ ckpt_engine: Union[None, str, dict] = None,
+ ckpt_path: Optional[str] = None,
+ additional_decode_keys: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.automatic_optimization = False # pytorch lightning
+
+ self.encoder = instantiate_from_config(encoder_config)
+ self.decoder = instantiate_from_config(decoder_config)
+ self.loss = instantiate_from_config(loss_config)
+ self.regularization = instantiate_from_config(regularizer_config)
+ self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"})
+ self.diff_boost_factor = diff_boost_factor
+ self.disc_start_iter = disc_start_iter
+ self.lr_g_factor = lr_g_factor
+ self.trainable_ae_params = trainable_ae_params
+ if self.trainable_ae_params is not None:
+ self.ae_optimizer_args = default(
+ ae_optimizer_args,
+ [{} for _ in range(len(self.trainable_ae_params))],
+ )
+ assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
+ else:
+ self.ae_optimizer_args = [{}] # makes type consistent
+
+ self.trainable_disc_params = trainable_disc_params
+ if self.trainable_disc_params is not None:
+ self.disc_optimizer_args = default(
+ disc_optimizer_args,
+ [{} for _ in range(len(self.trainable_disc_params))],
+ )
+ assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
+ else:
+ self.disc_optimizer_args = [{}] # makes type consistent
+
+ if ckpt_path is not None:
+ assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
+ logpy.warn("Checkpoint path is deprecated, use `ckpt_engine` instead")
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
+ self.additional_decode_keys = set(default(additional_decode_keys, []))
+
+ def get_input(self, batch: Dict) -> torch.Tensor:
+ # assuming unified data format, dataloader returns a dict.
+ # image tensors should be scaled to -1 ... 1 and in channels-first
+ # format (e.g., bchw instead of bhwc)
+ return batch[self.input_key]
+
+ def get_autoencoder_params(self) -> list:
+ params = []
+ if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
+ params += list(self.loss.get_trainable_autoencoder_parameters())
+ if hasattr(self.regularization, "get_trainable_parameters"):
+ params += list(self.regularization.get_trainable_parameters())
+ params = params + list(self.encoder.parameters())
+ params = params + list(self.decoder.parameters())
+ return params
+
+ def get_discriminator_params(self) -> list:
+ if hasattr(self.loss, "get_trainable_parameters"):
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
+ else:
+ params = []
+ return params
+
+ def get_last_layer(self):
+ return self.decoder.get_last_layer()
+
+ def encode(
+ self,
+ x: torch.Tensor,
+ return_reg_log: bool = False,
+ unregularized: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+ z = self.encoder(x)
+ if unregularized:
+ return z, dict()
+ z, reg_log = self.regularization(z)
+ if return_reg_log:
+ return z, reg_log
+ return z
+
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
+ x = self.decoder(z, **kwargs)
+ return x
+
+ def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+ z, reg_log = self.encode(x, return_reg_log=True)
+ dec = self.decode(z, **additional_decode_kwargs)
+ return z, dec, reg_log
+
+ def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
+ x = self.get_input(batch)
+ additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
+ z, xrec, regularization_log = self(x, **additional_decode_kwargs)
+ if hasattr(self.loss, "forward_keys"):
+ extra_info = {
+ "z": z,
+ "optimizer_idx": optimizer_idx,
+ "global_step": self.global_step,
+ "last_layer": self.get_last_layer(),
+ "split": "train",
+ "regularization_log": regularization_log,
+ "autoencoder": self,
+ }
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
+ else:
+ extra_info = dict()
+
+ if optimizer_idx == 0:
+ # autoencode
+ out_loss = self.loss(x, xrec, **extra_info)
+ if isinstance(out_loss, tuple):
+ aeloss, log_dict_ae = out_loss
+ else:
+ # simple loss function
+ aeloss = out_loss
+ log_dict_ae = {"train/loss/rec": aeloss.detach()}
+
+ self.log_dict(
+ log_dict_ae,
+ prog_bar=False,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ sync_dist=False,
+ )
+ self.log(
+ "loss",
+ aeloss.mean().detach(),
+ prog_bar=True,
+ logger=False,
+ on_epoch=False,
+ on_step=True,
+ )
+ return aeloss
+ elif optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
+ # -> discriminator always needs to return a tuple
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+ else:
+ raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
+
+ def training_step(self, batch: dict, batch_idx: int):
+ opts = self.optimizers()
+ if not isinstance(opts, list):
+ # Non-adversarial case
+ opts = [opts]
+ optimizer_idx = batch_idx % len(opts)
+ if self.global_step < self.disc_start_iter:
+ optimizer_idx = 0
+ opt = opts[optimizer_idx]
+ opt.zero_grad()
+ with opt.toggle_model():
+ loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx)
+ self.manual_backward(loss)
+ opt.step()
+
+ def validation_step(self, batch: dict, batch_idx: int) -> Dict:
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+ log_dict.update(log_dict_ema)
+ return log_dict
+
+ def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
+ x = self.get_input(batch)
+
+ z, xrec, regularization_log = self(x)
+ if hasattr(self.loss, "forward_keys"):
+ extra_info = {
+ "z": z,
+ "optimizer_idx": 0,
+ "global_step": self.global_step,
+ "last_layer": self.get_last_layer(),
+ "split": "val" + postfix,
+ "regularization_log": regularization_log,
+ "autoencoder": self,
+ }
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
+ else:
+ extra_info = dict()
+ out_loss = self.loss(x, xrec, **extra_info)
+ if isinstance(out_loss, tuple):
+ aeloss, log_dict_ae = out_loss
+ else:
+ # simple loss function
+ aeloss = out_loss
+ log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
+ full_log_dict = log_dict_ae
+
+ if "optimizer_idx" in extra_info:
+ extra_info["optimizer_idx"] = 1
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
+ full_log_dict.update(log_dict_disc)
+ self.log(
+ f"val{postfix}/loss/rec",
+ log_dict_ae[f"val{postfix}/loss/rec"],
+ sync_dist=True,
+ )
+ self.log_dict(full_log_dict, sync_dist=True)
+ return full_log_dict
+
+ def get_param_groups(
+ self, parameter_names: List[List[str]], optimizer_args: List[dict]
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ groups = []
+ num_params = 0
+ for names, args in zip(parameter_names, optimizer_args):
+ params = []
+ for pattern_ in names:
+ pattern_params = []
+ pattern = re.compile(pattern_)
+ for p_name, param in self.named_parameters():
+ if re.match(pattern, p_name):
+ pattern_params.append(param)
+ num_params += param.numel()
+ if len(pattern_params) == 0:
+ logpy.warn(f"Did not find parameters for pattern {pattern_}")
+ params.extend(pattern_params)
+ groups.append({"params": params, **args})
+ return groups, num_params
+
+ def configure_optimizers(self) -> List[torch.optim.Optimizer]:
+ if self.trainable_ae_params is None:
+ ae_params = self.get_autoencoder_params()
+ else:
+ ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
+ logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
+ if self.trainable_disc_params is None:
+ disc_params = self.get_discriminator_params()
+ else:
+ disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
+ logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
+ opt_ae = self.instantiate_optimizer_from_config(
+ ae_params,
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
+ self.optimizer_config,
+ )
+ opts = [opt_ae]
+ if len(disc_params) > 0:
+ opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
+ opts.append(opt_disc)
+
+ return opts
+
+ @torch.no_grad()
+ def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
+ log = dict()
+ additional_decode_kwargs = {}
+ x = self.get_input(batch)
+ additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
+
+ _, xrec, _ = self(x, **additional_decode_kwargs)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
+ diff.clamp_(0, 1.0)
+ log["diff"] = 2.0 * diff - 1.0
+ # diff_boost shows location of small errors, by boosting their
+ # brightness.
+ log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
+ if hasattr(self.loss, "log_images"):
+ log.update(self.loss.log_images(x, xrec))
+ with self.ema_scope():
+ _, xrec_ema, _ = self(x, **additional_decode_kwargs)
+ log["reconstructions_ema"] = xrec_ema
+ diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
+ diff_ema.clamp_(0, 1.0)
+ log["diff_ema"] = 2.0 * diff_ema - 1.0
+ log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
+ if additional_log_kwargs:
+ additional_decode_kwargs.update(additional_log_kwargs)
+ _, xrec_add, _ = self(x, **additional_decode_kwargs)
+ log_str = "reconstructions-" + "-".join(
+ [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
+ )
+ log[log_str] = xrec_add
+ return log
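+
+# Hedged round-trip sketch for AutoencodingEngine (the config dicts are assumptions, not the
+# shipped configs): with instantiable encoder/decoder/loss/regularizer configs,
+#
+#     ae = AutoencodingEngine(encoder_config=enc_cfg, decoder_config=dec_cfg,
+#                             loss_config=loss_cfg, regularizer_config=reg_cfg)
+#     z, rec, reg_log = ae(images)   # images scaled to [-1, 1], channels-first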
+
+
+class AutoencodingEngineLegacy(AutoencodingEngine):
+ def __init__(self, embed_dim: int, **kwargs):
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
+ ddconfig = kwargs.pop("ddconfig")
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ckpt_engine = kwargs.pop("ckpt_engine", None)
+ super().__init__(
+ encoder_config={
+ "target": "sgm.modules.diffusionmodules.model.Encoder",
+ "params": ddconfig,
+ },
+ decoder_config={
+ "target": "sgm.modules.diffusionmodules.model.Decoder",
+ "params": ddconfig,
+ },
+ **kwargs,
+ )
+ self.quant_conv = torch.nn.Conv2d(
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
+ (1 + ddconfig["double_z"]) * embed_dim,
+ 1,
+ )
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
+
+ def get_autoencoder_params(self) -> list:
+ params = super().get_autoencoder_params()
+ return params
+
+ def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+ if self.max_batch_size is None:
+ z = self.encoder(x)
+ z = self.quant_conv(z)
+ else:
+ N = x.shape[0]
+ bs = self.max_batch_size
+ n_batches = int(math.ceil(N / bs))
+ z = list()
+ for i_batch in range(n_batches):
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
+ z_batch = self.quant_conv(z_batch)
+ z.append(z_batch)
+ z = torch.cat(z, 0)
+
+ z, reg_log = self.regularization(z)
+ if return_reg_log:
+ return z, reg_log
+ return z
+
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
+ if self.max_batch_size is None:
+ dec = self.post_quant_conv(z)
+ dec = self.decoder(dec, **decoder_kwargs)
+ else:
+ N = z.shape[0]
+ bs = self.max_batch_size
+ n_batches = int(math.ceil(N / bs))
+ dec = list()
+ for i_batch in range(n_batches):
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
+ dec.append(dec_batch)
+ dec = torch.cat(dec, 0)
+
+ return dec
+
+
+class AutoencoderKL(AutoencodingEngineLegacy):
+ def __init__(self, **kwargs):
+ if "lossconfig" in kwargs:
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
+ super().__init__(
+ regularizer_config={"target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")},
+ **kwargs,
+ )
+
+
+class IdentityFirstStage(AbstractAutoencoder):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def get_input(self, x: Any) -> Any:
+ return x
+
+ def encode(self, x: Any, *args, **kwargs) -> Any:
+ return x
+
+ def decode(self, x: Any, *args, **kwargs) -> Any:
+ return x
+
+
+class VideoAutoencodingEngine(AutoencodingEngine):
+ def __init__(
+ self,
+ ckpt_path: Union[None, str] = None,
+ ignore_keys: Union[Tuple, list] = (),
+ image_video_weights=[1, 1],
+ only_train_decoder=False,
+ context_parallel_size=0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.context_parallel_size = context_parallel_size
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
+ return self.log_images(batch, additional_log_kwargs, **kwargs)
+
+ def get_input(self, batch: dict) -> torch.Tensor:
+ if self.context_parallel_size > 0:
+ if not is_context_parallel_initialized():
+ initialize_context_parallel(self.context_parallel_size)
+
+ batch = batch[self.input_key]
+
+ global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
+ torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
+
+ batch = _conv_split(batch, dim=2, kernel_size=1)
+ return batch
+
+ return batch[self.input_key]
+
+ def apply_ckpt(self, ckpt: Union[None, str, dict]):
+ if ckpt is None:
+ return
+ self.init_from_ckpt(ckpt)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
+ print("Missing keys: ", missing_keys)
+ print("Unexpected keys: ", unexpected_keys)
+ print(f"Restored from {path}")
+
+
+class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
+ def __init__(
+ self,
+ cp_size=0,
+ *args,
+ **kwargs,
+ ):
+ self.cp_size = cp_size
+ super().__init__(*args, **kwargs)
+
+ def encode(
+ self,
+ x: torch.Tensor,
+ return_reg_log: bool = False,
+ unregularized: bool = False,
+ input_cp: bool = False,
+ output_cp: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+ if self.cp_size > 0 and not input_cp:
+ if not is_context_parallel_initialized():
+ initialize_context_parallel(self.cp_size)
+
+ global_src_rank = get_context_parallel_group_rank() * self.cp_size
+ torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
+
+ x = _conv_split(x, dim=2, kernel_size=1)
+
+ if return_reg_log:
+ z, reg_log = super().encode(x, return_reg_log, unregularized)
+ else:
+ z = super().encode(x, return_reg_log, unregularized)
+
+ if self.cp_size > 0 and not output_cp:
+ z = _conv_gather(z, dim=2, kernel_size=1)
+
+ if return_reg_log:
+ return z, reg_log
+ return z
+
+ def decode(
+ self,
+ z: torch.Tensor,
+ input_cp: bool = False,
+ output_cp: bool = False,
+ split_kernel_size: int = 1,
+ **kwargs,
+ ):
+ if self.cp_size > 0 and not input_cp:
+ if not is_context_parallel_initialized():
+ initialize_context_parallel(self.cp_size)
+
+ global_src_rank = get_context_parallel_group_rank() * self.cp_size
+ torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
+
+ z = _conv_split(z, dim=2, kernel_size=split_kernel_size)
+
+ x = super().decode(z, **kwargs)
+
+ if self.cp_size > 0 and not output_cp:
+ x = _conv_gather(x, dim=2, kernel_size=split_kernel_size)
+
+ return x
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ input_cp: bool = False,
+ latent_cp: bool = False,
+ output_cp: bool = False,
+ **additional_decode_kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+ z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
+ dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
+ return z, dec, reg_log
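+
+# Hedged inference sketch for VideoAutoencoderInferenceWrapper (shapes and cp_size are
+# illustrative assumptions; with cp_size > 0 a context-parallel group must be initialized):
+#
+#     vae = VideoAutoencoderInferenceWrapper(cp_size=0, encoder_config=..., decoder_config=...,
+#                                            loss_config=..., regularizer_config=...)
+#     z = vae.encode(video)   # video: (b, c, t, h, w) scaled to [-1, 1]
+#     rec = vae.decode(z)     # latents are split/gathered along dim=2 when cp_size > 0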
diff --git a/sat/vae_modules/cp_enc_dec.py b/sat/vae_modules/cp_enc_dec.py
new file mode 100644
index 0000000000000000000000000000000000000000..28d97385249c345db283dc1d5f20aaab83d7787c
--- /dev/null
+++ b/sat/vae_modules/cp_enc_dec.py
@@ -0,0 +1,987 @@
+import math
+import torch
+import torch.distributed
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from beartype import beartype
+from beartype.typing import Union, Tuple, Optional, List
+from einops import rearrange
+
+from sgm.util import (
+ get_context_parallel_group,
+ get_context_parallel_rank,
+ get_context_parallel_world_size,
+ get_context_parallel_group_rank,
+)
+
+# try:
+from vae_modules.utils import SafeConv3d as Conv3d
+# except:
+# # Degrade to normal Conv3d if SafeConv3d is not available
+# from torch.nn import Conv3d
+
+
+def cast_tuple(t, length=1):
+ return t if isinstance(t, tuple) else ((t,) * length)
+
+
+def divisible_by(num, den):
+ return (num % den) == 0
+
+
+def is_odd(n):
+ return not divisible_by(n, 2)
+
+
+def exists(v):
+ return v is not None
+
+
+def pair(t):
+ return t if isinstance(t, tuple) else (t, t)
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ Build sinusoidal timestep embeddings.
+ This matches the implementation in Denoising Diffusion Probabilistic Models
+ (originally from Fairseq / tensor2tensor), but differs slightly from the
+ description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
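+
+# For reference (derived from the code above): the frequencies are
+# w_i = 10000 ** (-i / (half_dim - 1)), and the output concatenates the sin half followed
+# by the cos half (not interleaved), e.g.
+#
+#     get_timestep_embedding(torch.tensor([0, 1, 2]), 128)  # -> shape (3, 128)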
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def leaky_relu(p=0.1):
+ return nn.LeakyReLU(p)
+
+
+def _split(input_, dim):
+ cp_world_size = get_context_parallel_world_size()
+
+ if cp_world_size == 1:
+ return input_
+
+ cp_rank = get_context_parallel_rank()
+
+ # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
+
+ input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
+ input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
+ dim_size = input_.size()[dim] // cp_world_size
+
+ input_list = torch.split(input_, dim_size, dim=dim)
+ output = input_list[cp_rank]
+
+ if cp_rank == 0:
+ output = torch.cat([input_first_frame_, output], dim=dim)
+ output = output.contiguous()
+
+ # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
+
+ return output
+
+
+def _gather(input_, dim):
+ cp_world_size = get_context_parallel_world_size()
+
+ # Bypass the function if context parallel is 1
+ if cp_world_size == 1:
+ return input_
+
+ group = get_context_parallel_group()
+ cp_rank = get_context_parallel_rank()
+
+ # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
+
+ input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
+ if cp_rank == 0:
+ input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
+
+ tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [
+ torch.empty_like(input_) for _ in range(cp_world_size - 1)
+ ]
+
+ if cp_rank == 0:
+ input_ = torch.cat([input_first_frame_, input_], dim=dim)
+
+ tensor_list[cp_rank] = input_
+ torch.distributed.all_gather(tensor_list, input_, group=group)
+
+ output = torch.cat(tensor_list, dim=dim).contiguous()
+
+ # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape)
+
+ return output
+
+
+def _conv_split(input_, dim, kernel_size):
+ cp_world_size = get_context_parallel_world_size()
+
+ # Bypass the function if context parallel is 1
+ if cp_world_size == 1:
+ return input_
+
+ # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
+
+ cp_rank = get_context_parallel_rank()
+
+ dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
+
+ if cp_rank == 0:
+ output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
+ else:
+ # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
+ output = input_.transpose(dim, 0)[
+ cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
+ ].transpose(dim, 0)
+ output = output.contiguous()
+
+ # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
+
+ return output
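+
+# Illustrative shape walk-through (assumed example, cp world size 2, kernel_size 1, T=17 on `dim`):
+# dim_size = (17 - 1) // 2 = 8, so rank 0 keeps frames [0:9) and rank 1 keeps frames [9:17);
+# _conv_gather below reassembles the full 17 frames in rank order.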
+
+
+def _conv_gather(input_, dim, kernel_size):
+ cp_world_size = get_context_parallel_world_size()
+
+ # Bypass the function if context parallel is 1
+ if cp_world_size == 1:
+ return input_
+
+ group = get_context_parallel_group()
+ cp_rank = get_context_parallel_rank()
+
+ # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
+
+ input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
+ if cp_rank == 0:
+ input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
+ else:
+ input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous()
+
+ tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
+ torch.empty_like(input_) for _ in range(cp_world_size - 1)
+ ]
+ if cp_rank == 0:
+ input_ = torch.cat([input_first_kernel_, input_], dim=dim)
+
+ tensor_list[cp_rank] = input_
+ torch.distributed.all_gather(tensor_list, input_, group=group)
+
+ # Note: torch.cat already creates a contiguous tensor.
+ output = torch.cat(tensor_list, dim=dim).contiguous()
+
+ # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
+
+ return output
+
+
+def _pass_from_previous_rank(input_, dim, kernel_size):
+ # Bypass the function if kernel size is 1
+ if kernel_size == 1:
+ return input_
+
+ group = get_context_parallel_group()
+ cp_rank = get_context_parallel_rank()
+ cp_group_rank = get_context_parallel_group_rank()
+ cp_world_size = get_context_parallel_world_size()
+
+ # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
+
+ global_rank = torch.distributed.get_rank()
+ global_world_size = torch.distributed.get_world_size()
+
+ input_ = input_.transpose(0, dim)
+
+ # pass from last rank
+ send_rank = global_rank + 1
+ recv_rank = global_rank - 1
+ if send_rank % cp_world_size == 0:
+ send_rank -= cp_world_size
+ if recv_rank % cp_world_size == cp_world_size - 1:
+ recv_rank += cp_world_size
+
+ if cp_rank < cp_world_size - 1:
+ req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
+ if cp_rank > 0:
+ recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
+ req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
+
+ if cp_rank == 0:
+ input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
+ else:
+ req_recv.wait()
+ input_ = torch.cat([recv_buffer, input_], dim=0)
+
+ input_ = input_.transpose(0, dim).contiguous()
+
+ # print('out _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
+
+ return input_
+
+
+def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None):
+ # Bypass the function if kernel size is 1
+ if kernel_size == 1:
+ return input_
+
+ group = get_context_parallel_group()
+ cp_rank = get_context_parallel_rank()
+ cp_group_rank = get_context_parallel_group_rank()
+ cp_world_size = get_context_parallel_world_size()
+
+ # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
+
+ global_rank = torch.distributed.get_rank()
+ global_world_size = torch.distributed.get_world_size()
+
+ input_ = input_.transpose(0, dim)
+
+ # pass from last rank
+ send_rank = global_rank + 1
+ recv_rank = global_rank - 1
+ if send_rank % cp_world_size == 0:
+ send_rank -= cp_world_size
+ if recv_rank % cp_world_size == cp_world_size - 1:
+ recv_rank += cp_world_size
+
+ # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
+ # recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
+ # req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
+ # req_recv.wait()
+ recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
+ if cp_rank < cp_world_size - 1:
+ req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
+ if cp_rank > 0:
+ req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
+ # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
+ # req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
+
+ if cp_rank == 0:
+ if cache_padding is not None:
+ input_ = torch.cat([cache_padding.transpose(0, dim).to(input_.device), input_], dim=0)
+ else:
+ input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
+ else:
+ req_recv.wait()
+ input_ = torch.cat([recv_buffer, input_], dim=0)
+
+ input_ = input_.transpose(0, dim).contiguous()
+ return input_
+
+
+def _drop_from_previous_rank(input_, dim, kernel_size):
+ input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
+ return input_
+
+
+class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, dim, kernel_size):
+ ctx.dim = dim
+ ctx.kernel_size = kernel_size
+ return _conv_split(input_, dim, kernel_size)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
+
+
+class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, dim, kernel_size):
+ ctx.dim = dim
+ ctx.kernel_size = kernel_size
+ return _conv_gather(input_, dim, kernel_size)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
+
+
+class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, dim, kernel_size):
+ ctx.dim = dim
+ ctx.kernel_size = kernel_size
+ return _pass_from_previous_rank(input_, dim, kernel_size)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None
+
+
+class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, dim, kernel_size, cache_padding):
+ ctx.dim = dim
+ ctx.kernel_size = kernel_size
+ return _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None, None
+
+
+def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
+ return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)
+
+
+def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
+ return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)
+
+
+def conv_pass_from_last_rank(input_, dim, kernel_size):
+ return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
+
+
+def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding):
+ return _FakeCPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size, cache_padding)
+
+
+class ContextParallelCausalConv3d(nn.Module):
+ def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs):
+ super().__init__()
+ kernel_size = cast_tuple(kernel_size, 3)
+
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
+
+ assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
+
+ time_pad = time_kernel_size - 1
+ height_pad = height_kernel_size // 2
+ width_pad = width_kernel_size // 2
+
+ self.height_pad = height_pad
+ self.width_pad = width_pad
+ self.time_pad = time_pad
+ self.time_kernel_size = time_kernel_size
+ self.temporal_dim = 2
+
+ stride = (stride, stride, stride)
+ dilation = (1, 1, 1)
+ self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+ self.cache_padding = None
+
+ def forward(self, input_, clear_cache=True):
+ # if input_.shape[2] == 1: # handle image
+ # # first frame padding
+ # input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2)
+ # else:
+ # input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
+
+ # padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
+ # input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0)
+
+ # output_parallel = self.conv(input_parallel)
+ # output = output_parallel
+ # return output
+
+ input_parallel = fake_cp_pass_from_previous_rank(
+ input_, self.temporal_dim, self.time_kernel_size, self.cache_padding
+ )
+
+ del self.cache_padding
+ self.cache_padding = None
+ if not clear_cache:
+ cp_rank, cp_world_size = get_context_parallel_rank(), get_context_parallel_world_size()
+ global_rank = torch.distributed.get_rank()
+ if cp_world_size == 1:
+ self.cache_padding = (
+ input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
+ )
+ else:
+ if cp_rank == cp_world_size - 1:
+ torch.distributed.isend(
+ input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous(),
+ global_rank + 1 - cp_world_size,
+ group=get_context_parallel_group(),
+ )
+ if cp_rank == 0:
+ recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1 :]).contiguous()
+ torch.distributed.recv(
+ recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group()
+ )
+ self.cache_padding = recv_buffer.contiguous().detach().clone().cpu()
+
+ padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
+ input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
+
+ output_parallel = self.conv(input_parallel)
+ output = output_parallel
+ return output
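+
+# Hedged usage sketch (sizes are illustrative; torch.distributed and the context-parallel
+# group are assumed to be initialized, even for a group of size 1):
+#
+#     conv = ContextParallelCausalConv3d(chan_in=16, chan_out=32, kernel_size=3)
+#     y = conv(torch.randn(1, 16, 9, 32, 32))  # causal temporal padding keeps T=9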
+
+
+class ContextParallelGroupNorm(torch.nn.GroupNorm):
+ def forward(self, input_):
+ gather_flag = input_.shape[2] > 1
+ if gather_flag:
+ input_ = conv_gather_from_context_parallel_region(input_, dim=2, kernel_size=1)
+ output = super().forward(input_)
+ if gather_flag:
+ output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1)
+ return output
+
+
+def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
+ if gather:
+ return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ else:
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class SpatialNorm3D(nn.Module):
+ def __init__(
+ self,
+ f_channels,
+ zq_channels,
+ freeze_norm_layer=False,
+ add_conv=False,
+ pad_mode="constant",
+ gather=False,
+ **norm_layer_params,
+ ):
+ super().__init__()
+ if gather:
+ self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
+ else:
+ self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
+ # self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
+ if freeze_norm_layer:
+ for p in self.norm_layer.parameters():
+ p.requires_grad = False
+
+ self.add_conv = add_conv
+ if add_conv:
+ self.conv = ContextParallelCausalConv3d(
+ chan_in=zq_channels,
+ chan_out=zq_channels,
+ kernel_size=3,
+ )
+
+ self.conv_y = ContextParallelCausalConv3d(
+ chan_in=zq_channels,
+ chan_out=f_channels,
+ kernel_size=1,
+ )
+ self.conv_b = ContextParallelCausalConv3d(
+ chan_in=zq_channels,
+ chan_out=f_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, f, zq, clear_fake_cp_cache=True):
+ if f.shape[2] > 1 and f.shape[2] % 2 == 1:
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
+ f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
+ zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
+ zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
+ zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
+ zq = torch.cat([zq_first, zq_rest], dim=2)
+ else:
+ zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
+
+ if self.add_conv:
+ zq = self.conv(zq, clear_cache=clear_fake_cp_cache)
+
+ # f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
+ norm_f = self.norm_layer(f)
+ # norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
+
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
+
+def Normalize3D(
+ in_channels,
+ zq_ch,
+ add_conv,
+ gather=False,
+):
+ return SpatialNorm3D(
+ in_channels,
+ zq_ch,
+ gather=gather,
+ freeze_norm_layer=False,
+ add_conv=add_conv,
+ num_groups=32,
+ eps=1e-6,
+ affine=True,
+ )
+
+
+class Upsample3D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ with_conv,
+ compress_time=False,
+ ):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+ self.compress_time = compress_time
+
+ def forward(self, x):
+ if self.compress_time and x.shape[2] > 1:
+ if x.shape[2] % 2 == 1:
+ # split first frame
+ x_first, x_rest = x[:, :, 0], x[:, :, 1:]
+
+ x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
+ x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
+ x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
+ else:
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+
+ else:
+ # only interpolate 2D
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+
+ if self.with_conv:
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = self.conv(x)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ return x
+
+
+class DownSample3D(nn.Module):
+ def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
+ super().__init__()
+ self.with_conv = with_conv
+ if out_channels is None:
+ out_channels = in_channels
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
+ self.compress_time = compress_time
+
+ def forward(self, x):
+ if self.compress_time and x.shape[2] > 1:
+ h, w = x.shape[-2:]
+ x = rearrange(x, "b c t h w -> (b h w) c t")
+
+ if x.shape[-1] % 2 == 1:
+ # split first frame
+ x_first, x_rest = x[..., 0], x[..., 1:]
+
+ if x_rest.shape[-1] > 0:
+ x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
+ x = torch.cat([x_first[..., None], x_rest], dim=-1)
+ x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
+ else:
+ x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
+ x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
+
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = self.conv(x)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ else:
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ return x
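+
+# Illustrative shapes (assumed example): with compress_time=True and T=9, the temporal
+# avg_pool1d keeps frame 0 and averages frames 1..8 pairwise, giving T=5; spatially,
+# (h, w) -> (h // 2, w // 2) via the strided conv (with_conv=True) or avg_pool2d.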
+
+
+class ContextParallelResnetBlock3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ zq_ch=None,
+ add_conv=False,
+ gather_norm=False,
+ normalization=Normalize,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = normalization(
+ in_channels,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ gather=gather_norm,
+ )
+
+ self.conv1 = ContextParallelCausalConv3d(
+ chan_in=in_channels,
+ chan_out=out_channels,
+ kernel_size=3,
+ )
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = normalization(
+ out_channels,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ gather=gather_norm,
+ )
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = ContextParallelCausalConv3d(
+ chan_in=out_channels,
+ chan_out=out_channels,
+ kernel_size=3,
+ )
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = ContextParallelCausalConv3d(
+ chan_in=in_channels,
+ chan_out=out_channels,
+ kernel_size=3,
+ )
+ else:
+ self.nin_shortcut = Conv3d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
+ h = x
+
+ # if isinstance(self.norm1, torch.nn.GroupNorm):
+ # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
+ if zq is not None:
+ h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+ else:
+ h = self.norm1(h)
+ # if isinstance(self.norm1, torch.nn.GroupNorm):
+ # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
+
+ h = nonlinearity(h)
+ h = self.conv1(h, clear_cache=clear_fake_cp_cache)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
+
+ # if isinstance(self.norm2, torch.nn.GroupNorm):
+ # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
+ if zq is not None:
+ h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+ else:
+ h = self.norm2(h)
+ # if isinstance(self.norm2, torch.nn.GroupNorm):
+ # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
+
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h, clear_cache=clear_fake_cp_cache)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x, clear_cache=clear_fake_cp_cache)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class ContextParallelEncoder3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ pad_mode="first",
+ temporal_compress_times=4,
+ gather_norm=False,
+ **ignore_kwargs,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # log2 of temporal_compress_times
+ self.temporal_compress_level = int(np.log2(temporal_compress_times))
+
+ self.conv_in = ContextParallelCausalConv3d(
+ chan_in=in_channels,
+ chan_out=self.ch,
+ kernel_size=3,
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ContextParallelResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_out,
+ dropout=dropout,
+ temb_channels=self.temb_ch,
+ gather_norm=gather_norm,
+ )
+ )
+ block_in = block_out
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
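+            # Levels with i_level < temporal_compress_level downsample time as well as space; the remaining levels downsample space only.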
+ if i_level != self.num_resolutions - 1:
+ if i_level < self.temporal_compress_level:
+ down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
+ else:
+ down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ContextParallelResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ gather_norm=gather_norm,
+ )
+
+ self.mid.block_2 = ContextParallelResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ gather_norm=gather_norm,
+ )
+
+ # end
+ self.norm_out = Normalize(block_in, gather=gather_norm)
+
+ self.conv_out = ContextParallelCausalConv3d(
+ chan_in=block_in,
+ chan_out=2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ )
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ h = self.conv_in(x)
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](h, temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ if i_level != self.num_resolutions - 1:
+ h = self.down[i_level].downsample(h)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
+ h = self.norm_out(h)
+ # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
+
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+
+ return h
+
+
+class ContextParallelDecoder3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ zq_ch=None,
+ add_conv=False,
+ pad_mode="first",
+ temporal_compress_times=4,
+ gather_norm=False,
+ **ignorekwargs,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # log2 of temporal_compress_times
+ self.temporal_compress_level = int(np.log2(temporal_compress_times))
+
+ if zq_ch is None:
+ zq_ch = z_channels
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+
+ self.conv_in = ContextParallelCausalConv3d(
+ chan_in=z_channels,
+ chan_out=block_in,
+ kernel_size=3,
+ )
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ContextParallelResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ normalization=Normalize3D,
+ gather_norm=gather_norm,
+ )
+
+ self.mid.block_2 = ContextParallelResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ normalization=Normalize3D,
+ gather_norm=gather_norm,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ContextParallelResnetBlock3D(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ zq_ch=zq_ch,
+ add_conv=add_conv,
+ normalization=Normalize3D,
+ gather_norm=gather_norm,
+ )
+ )
+ block_in = block_out
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
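+            # Levels with i_level >= num_resolutions - temporal_compress_level upsample time as well as space; lower levels upsample space only.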
+ if i_level != 0:
+ if i_level < self.num_resolutions - self.temporal_compress_level:
+ up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
+ else:
+ up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
+ self.up.insert(0, up)
+
+ self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
+
+ self.conv_out = ContextParallelCausalConv3d(
+ chan_in=block_in,
+ chan_out=out_ch,
+ kernel_size=3,
+ )
+
+ def forward(self, z, clear_fake_cp_cache=True):
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ t = z.shape[2]
+ # z to block_in
+
+ zq = z
+ h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
+
+ # middle
+ h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+ h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, zq)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+ h = nonlinearity(h)
+ h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
+
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.conv.weight
diff --git a/sat/vae_modules/ema.py b/sat/vae_modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f1f7606c2c9b68ebd2302215a9e08f9f31ed8ab
--- /dev/null
+++ b/sat/vae_modules/ema.py
@@ -0,0 +1,82 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+
+ self.m_name2s_name = {}
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer(
+ "num_updates",
+ torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
+ )
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ # remove as '.'-character is not allowed in buffers
+ s_name = name.replace(".", "")
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def reset_num_updates(self):
+ del self.num_updates
+ self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
+
+ def forward(self, model):
+ decay = self.decay
+
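+        # Warm-up: while num_updates is tracked (>= 0), cap the decay at (1 + n) / (10 + n) so early EMA updates follow the raw weights more closely.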
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+                    assert key not in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+                assert key not in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/sat/vae_modules/regularizers.py b/sat/vae_modules/regularizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..205bd4a9415f0eb350b04508545a25362c6d0449
--- /dev/null
+++ b/sat/vae_modules/regularizers.py
@@ -0,0 +1,108 @@
+from abc import abstractmethod
+from typing import Any, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ # x = self.mean + self.std * torch.randn(self.mean.shape).to(
+ # device=self.parameters.device
+ # )
+ x = self.mean + self.std * torch.randn_like(self.mean)
+ return x
+
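+    # Closed-form KL divergence for diagonal Gaussians; against the standard normal it reduces to 0.5 * sum(mean^2 + var - 1 - logvar) over non-batch dims.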
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self):
+ return self.mean
+
+
+class AbstractRegularizer(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_trainable_parameters(self) -> Any:
+ raise NotImplementedError()
+
+
+class IdentityRegularizer(AbstractRegularizer):
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ return z, dict()
+
+ def get_trainable_parameters(self) -> Any:
+ yield from ()
+
+
+def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
+ encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
+ avg_probs = encodings.mean(0)
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+ cluster_use = torch.sum(avg_probs > 0)
+ return perplexity, cluster_use
+
+
+class DiagonalGaussianRegularizer(AbstractRegularizer):
+ def __init__(self, sample: bool = True):
+ super().__init__()
+ self.sample = sample
+
+ def get_trainable_parameters(self) -> Any:
+ yield from ()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ log = dict()
+ posterior = DiagonalGaussianDistribution(z)
+ if self.sample:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ kl_loss = posterior.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+ log["kl_loss"] = kl_loss
+ return z, log
diff --git a/sat/vae_modules/utils.py b/sat/vae_modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c8dba626f250a69c8c28e67f0c7c1c822bc6bc2
--- /dev/null
+++ b/sat/vae_modules/utils.py
@@ -0,0 +1,404 @@
+import functools
+import importlib
+import os
+from functools import partial
+from inspect import isfunction
+
+import fsspec
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from safetensors.torch import load_file as load_safetensors
+import torch.distributed
+
+_CONTEXT_PARALLEL_GROUP = None
+_CONTEXT_PARALLEL_SIZE = None
+
+
+def is_context_parallel_initialized():
+ if _CONTEXT_PARALLEL_GROUP is None:
+ return False
+ else:
+ return True
+
+
+def initialize_context_parallel(context_parallel_size):
+ global _CONTEXT_PARALLEL_GROUP
+ global _CONTEXT_PARALLEL_SIZE
+
+ assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
+ _CONTEXT_PARALLEL_SIZE = context_parallel_size
+
+ rank = torch.distributed.get_rank()
+ world_size = torch.distributed.get_world_size()
+
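+    # Split the world into contiguous groups of size context_parallel_size, e.g. world_size=8 with context_parallel_size=2 gives groups (0,1), (2,3), (4,5), (6,7).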
+ for i in range(0, world_size, context_parallel_size):
+ ranks = range(i, i + context_parallel_size)
+ group = torch.distributed.new_group(ranks)
+ if rank in ranks:
+ _CONTEXT_PARALLEL_GROUP = group
+ break
+
+
+def get_context_parallel_group():
+ assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
+
+ return _CONTEXT_PARALLEL_GROUP
+
+
+def get_context_parallel_world_size():
+ assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
+
+ return _CONTEXT_PARALLEL_SIZE
+
+
+def get_context_parallel_rank():
+ assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
+
+ rank = torch.distributed.get_rank()
+ cp_rank = rank % _CONTEXT_PARALLEL_SIZE
+ return cp_rank
+
+
+def get_context_parallel_group_rank():
+ assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
+
+ rank = torch.distributed.get_rank()
+ cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE
+
+ return cp_group_rank
+
+
+class SafeConv3d(torch.nn.Conv3d):
+ def forward(self, input):
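+        # Estimate the input size in GiB assuming 2 bytes per element (fp16/bf16). Above ~2 GiB the input is
+        # chunked along the time axis, overlapping chunks by kernel_size - 1 frames so the output is unchanged.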
+ memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+ if memory_count > 2:
+ kernel_size = self.kernel_size[0]
+ part_num = int(memory_count / 2) + 1
+ input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW
+ if kernel_size > 1:
+ input_chunks = [input_chunks[0]] + [
+ torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
+ for i in range(1, len(input_chunks))
+ ]
+
+ output_chunks = []
+ for input_chunk in input_chunks:
+ output_chunks.append(super(SafeConv3d, self).forward(input_chunk))
+ output = torch.cat(output_chunks, dim=2)
+ return output
+ else:
+ return super(SafeConv3d, self).forward(input)
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def get_string_from_tuple(s):
+    try:
+        # Check if the string starts and ends with parentheses
+        if s[0] == "(" and s[-1] == ")":
+            # Convert the string to a tuple
+            t = eval(s)
+            # If it really is a tuple, return its first element
+            if isinstance(t, tuple):
+                return t[0]
+    except Exception:
+        pass
+    return s
+
+
+def is_power_of_two(n):
+    """
+    Return True if n is a power of 2, otherwise return False.
+
+    A positive power of 2 has exactly one bit set, so n & (n - 1) == 0 exactly when
+    n is a power of 2; values <= 0 are never powers of 2.
+    """
+ if n <= 0:
+ return False
+ return (n & (n - 1)) == 0
+
+
+def autocast(f, enabled=True):
+ def do_autocast(*args, **kwargs):
+ with torch.cuda.amp.autocast(
+ enabled=enabled,
+ dtype=torch.get_autocast_gpu_dtype(),
+ cache_enabled=torch.is_autocast_cache_enabled(),
+ ):
+ return f(*args, **kwargs)
+
+ return do_autocast
+
+
+def load_partial_from_config(config):
+ return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
+ nc = int(40 * (wh[0] / 256))
+ if isinstance(xc[bi], list):
+ text_seq = xc[bi][0]
+ else:
+ text_seq = xc[bi]
+ lines = "\n".join(text_seq[start : start + nc] for start in range(0, len(text_seq), nc))
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+            print("Can't encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def partialclass(cls, *args, **kwargs):
+ class NewCls(cls):
+ __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
+
+ return NewCls
+
+
+def make_path_absolute(path):
+ fs, p = fsspec.core.url_to_fs(path)
+ if fs.protocol == "file":
+ return os.path.abspath(p)
+ return path
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def isheatmap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+
+ return x.ndim == 2
+
+
+def isneighbors(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def expand_dims_like(x, y):
+ while x.dim() != y.dim():
+ x = x.unsqueeze(-1)
+ return x
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
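+    # Expects a config mapping with a "target" import path and optional "params", e.g. {"target": "pkg.module.Class", "params": {...}}, or one of the sentinel strings handled below.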
+    if "target" not in config:
+ if config == "__is_first_stage__":
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False, invalidate_cache=True):
+ module, cls = string.rsplit(".", 1)
+ if invalidate_cache:
+ importlib.invalidate_caches()
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def append_zero(x):
+ return torch.cat([x, x.new_zeros([1])])
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def load_model_from_config(config, ckpt, verbose=True, freeze=True):
+ print(f"Loading model from {ckpt}")
+ if ckpt.endswith("ckpt"):
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ elif ckpt.endswith("safetensors"):
+ sd = load_safetensors(ckpt)
+ else:
+ raise NotImplementedError
+
+ model = instantiate_from_config(config.model)
+
+ m, u = model.load_state_dict(sd, strict=False)
+
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ if freeze:
+ for param in model.parameters():
+ param.requires_grad = False
+
+ model.eval()
+ return model
+
+
+def get_configs_path() -> str:
+ """
+ Get the `configs` directory.
+ For a working copy, this is the one in the root of the repository,
+ but for an installed copy, it's in the `sgm` package (see pyproject.toml).
+ """
+ this_dir = os.path.dirname(__file__)
+ candidates = (
+ os.path.join(this_dir, "configs"),
+ os.path.join(this_dir, "..", "configs"),
+ )
+ for candidate in candidates:
+ candidate = os.path.abspath(candidate)
+ if os.path.isdir(candidate):
+ return candidate
+ raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
+
+
+def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
+ """
+ Will return the result of a recursive get attribute call.
+ E.g.:
+ a.b.c
+ = getattr(getattr(a, "b"), "c")
+ = get_nested_attribute(a, "b.c")
+ If any part of the attribute call is an integer x with current obj a, will
+ try to call a[x] instead of a.x first.
+ """
+ attributes = attribute_path.split(".")
+ if depth is not None and depth > 0:
+ attributes = attributes[:depth]
+ assert len(attributes) > 0, "At least one attribute should be selected"
+ current_attribute = obj
+ current_key = None
+ for level, attribute in enumerate(attributes):
+ current_key = ".".join(attributes[: level + 1])
+ try:
+ id_ = int(attribute)
+ current_attribute = current_attribute[id_]
+ except ValueError:
+ current_attribute = getattr(current_attribute, attribute)
+
+ return (current_attribute, current_key) if return_key else current_attribute
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
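+    # Illustrative call (hypothetical names): checkpoint(self.block, (x, emb), self.block.parameters(), self.use_checkpoint)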
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ ctx.gpu_autocast_kwargs = {
+ "enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
diff --git a/tools/caption/README.md b/tools/caption/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bf02225ccfce051bc08cac2c56f316323a031970
--- /dev/null
+++ b/tools/caption/README.md
@@ -0,0 +1,18 @@
+# Video Caption
+
+Most video data does not come with descriptive text, so the videos must be captioned to produce the text-video
+pairs that text-to-video models are trained on.
+
+## Video Caption via CogVLM2-Video
+
+
+
+CogVLM2-Video is a versatile video understanding model that supports timestamp-based question answering.
+Prompt it with, for example, `Please describe this video in detail.` to obtain a detailed video caption:
+
+
+
+
+You can use the provided [code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) to load the model locally, or serve it behind a RESTful API, to generate video captions.
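+
+If you serve the model behind your own RESTful endpoint, batch captioning can be scripted in a few lines. The sketch below is illustrative only: the endpoint URL, request fields, and response field are assumptions about such a deployment, not part of the CogVLM2-Video demo code.
+
+```python
+import base64
+from pathlib import Path
+
+import requests
+
+API_URL = "http://localhost:8000/video_qa"  # hypothetical endpoint of your own deployment
+PROMPT = "Please describe this video in detail."
+
+
+def caption_video(path: Path) -> str:
+    # Send the base64-encoded clip together with the prompt; adjust the field names to match your server.
+    payload = {"video": base64.b64encode(path.read_bytes()).decode("utf-8"), "prompt": PROMPT}
+    response = requests.post(API_URL, json=payload, timeout=300)
+    response.raise_for_status()
+    return response.json()["caption"]  # assumed response field
+
+
+for clip in sorted(Path("videos").glob("*.mp4")):
+    print(clip.name, caption_video(clip))
+```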
\ No newline at end of file
diff --git a/tools/caption/README_zh.md b/tools/caption/README_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..dc905e4668da4780d6a45dac28bbc1c33636afb1
--- /dev/null
+++ b/tools/caption/README_zh.md
@@ -0,0 +1,16 @@
+# 视频Caption
+
+通常,大多数视频数据不带有相应的描述性文本,因此需要将视频数据转换为文本描述,以提供必要的训练数据用于文本到视频模型。
+
+## 通过 CogVLM2-Video 模型生成视频Caption
+
+🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) | [💬 Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
+
+CogVLM2-Video 是一个多功能的视频理解模型,具备基于时间戳的问题回答能力。用户可以输入诸如 `请详细描述这个视频` 的提示语给模型,以获得详细的视频Caption:
+
+
+
+
+
+
+用户可以使用提供的[代码](https://github.com/THUDM/CogVLM2/tree/main/video_demo)加载模型或配置 RESTful API 来生成视频Caption。
\ No newline at end of file
diff --git a/tools/caption/assests/cogvlm2-video-example.png b/tools/caption/assests/cogvlm2-video-example.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0c2e6bae875945ca6db64d64edea2eefa4914c9
--- /dev/null
+++ b/tools/caption/assests/cogvlm2-video-example.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93ffadabb7b0b32fbdce9c3bfdff68e2b1fe9af2277708828e58757ea81a568b
+size 1419122
diff --git a/tools/convert_weight_sat2hf.py b/tools/convert_weight_sat2hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cec3f53741774410009fccf9d0f3858eaea548a
--- /dev/null
+++ b/tools/convert_weight_sat2hf.py
@@ -0,0 +1,268 @@
+"""
+This script converts CogVideoX checkpoints trained with SAT into the 🤗 Hugging Face Diffusers format and assembles a CogVideoXPipeline from the converted weights.
+
+Note:
+ This script requires the `diffusers>=0.30.0` library to be installed.
+
+Run the script:
+    $ python tools/convert_weight_sat2hf.py --transformer_ckpt_path <ckpt> --vae_ckpt_path <ckpt> --output_path <dir> --text_encoder_path <t5-path-or-id>
+
+Functions:
+ - reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
+ - reassign_query_key_layernorm_inplace: Reassigns layer normalization for query and key in-place.
+ - reassign_adaln_norm_inplace: Reassigns adaptive layer normalization in-place.
+ - remove_keys_inplace: Removes specified keys from the state_dict in-place.
+ - replace_up_keys_inplace: Replaces keys in the "up" block in-place.
+ - get_state_dict: Extracts the state_dict from a saved checkpoint.
+ - update_state_dict_inplace: Updates the state_dict with new key assignments in-place.
+ - convert_transformer: Converts a transformer checkpoint to the CogVideoX format.
+ - convert_vae: Converts a VAE checkpoint to the CogVideoX format.
+ - get_args: Parses command-line arguments for the script.
+"""
+
+import argparse
+from typing import Any, Dict
+
+import torch
+from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from transformers import T5EncoderModel, T5Tokenizer
+
+
+# Function to reassign the query, key, and value weights in-place
+def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
+ to_q_key = key.replace("query_key_value", "to_q")
+ to_k_key = key.replace("query_key_value", "to_k")
+ to_v_key = key.replace("query_key_value", "to_v")
+ to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
+ state_dict[to_q_key] = to_q
+ state_dict[to_k_key] = to_k
+ state_dict[to_v_key] = to_v
+ state_dict.pop(key)
+
+
+# Function to reassign layer normalization for query and key in-place
+def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
+ layer_id, weight_or_bias = key.split(".")[-2:]
+
+ if "query" in key:
+ new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
+ elif "key" in key:
+ new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"
+
+ state_dict[new_key] = state_dict.pop(key)
+
+
+# Function to reassign adaptive layer normalization in-place
+def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
+ layer_id, _, weight_or_bias = key.split(".")[-3:]
+
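+    # The SAT checkpoint packs 12 adaLN modulation chunks per layer: chunks 0-2 and 6-8 go to norm1, chunks 3-5 and 9-11 go to norm2.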
+ weights_or_biases = state_dict[key].chunk(12, dim=0)
+ norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
+ norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])
+
+ norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
+ state_dict[norm1_key] = norm1_weights_or_biases
+
+ norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
+ state_dict[norm2_key] = norm2_weights_or_biases
+
+ state_dict.pop(key)
+
+
+# Function to remove keys from state_dict in-place
+def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
+ state_dict.pop(key)
+
+
+# Function to replace keys in the "up" block in-place
+def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
+ key_split = key.split(".")
+ layer_index = int(key_split[2])
+ replace_layer_index = 4 - 1 - layer_index
+
+ key_split[1] = "up_blocks"
+ key_split[2] = str(replace_layer_index)
+ new_key = ".".join(key_split)
+
+ state_dict[new_key] = state_dict.pop(key)
+
+
+# Dictionary for renaming transformer keys
+TRANSFORMER_KEYS_RENAME_DICT = {
+ "transformer.final_layernorm": "norm_final",
+ "transformer": "transformer_blocks",
+ "attention": "attn1",
+ "mlp": "ff.net",
+ "dense_h_to_4h": "0.proj",
+ "dense_4h_to_h": "2",
+ ".layers": "",
+ "dense": "to_out.0",
+ "input_layernorm": "norm1.norm",
+ "post_attn1_layernorm": "norm2.norm",
+ "time_embed.0": "time_embedding.linear_1",
+ "time_embed.2": "time_embedding.linear_2",
+ "mixins.patch_embed": "patch_embed",
+ "mixins.final_layer.norm_final": "norm_out.norm",
+ "mixins.final_layer.linear": "proj_out",
+ "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
+}
+
+# Dictionary for handling special keys in transformer
+TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "query_key_value": reassign_query_key_value_inplace,
+ "query_layernorm_list": reassign_query_key_layernorm_inplace,
+ "key_layernorm_list": reassign_query_key_layernorm_inplace,
+ "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
+ "embed_tokens": remove_keys_inplace,
+}
+
+# Dictionary for renaming VAE keys
+VAE_KEYS_RENAME_DICT = {
+ "block.": "resnets.",
+ "down.": "down_blocks.",
+ "downsample": "downsamplers.0",
+ "upsample": "upsamplers.0",
+ "nin_shortcut": "conv_shortcut",
+ "encoder.mid.block_1": "encoder.mid_block.resnets.0",
+ "encoder.mid.block_2": "encoder.mid_block.resnets.1",
+ "decoder.mid.block_1": "decoder.mid_block.resnets.0",
+ "decoder.mid.block_2": "decoder.mid_block.resnets.1",
+}
+
+# Dictionary for handling special keys in VAE
+VAE_SPECIAL_KEYS_REMAP = {
+ "loss": remove_keys_inplace,
+ "up.": replace_up_keys_inplace,
+}
+
+# Maximum length of the tokenizer (Must be 226)
+TOKENIZER_MAX_LENGTH = 226
+
+
+# Function to extract the state_dict from a saved checkpoint
+def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
+ state_dict = saved_dict
+ if "model" in saved_dict.keys():
+ state_dict = state_dict["model"]
+ if "module" in saved_dict.keys():
+ state_dict = state_dict["module"]
+ if "state_dict" in saved_dict.keys():
+ state_dict = state_dict["state_dict"]
+ return state_dict
+
+
+# Function to update the state_dict with new key assignments in-place
+def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+# Function to convert a transformer checkpoint to the CogVideoX format
+def convert_transformer(ckpt_path: str):
+ PREFIX_KEY = "model.diffusion_model."
+
+ original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
+ transformer = CogVideoXTransformer3DModel()
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[len(PREFIX_KEY) :]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_inplace(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ transformer.load_state_dict(original_state_dict, strict=True)
+ return transformer
+
+
+# Function to convert a VAE checkpoint to the CogVideoX format
+def convert_vae(ckpt_path: str):
+ original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
+ vae = AutoencoderKLCogVideoX()
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_inplace(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ vae.load_state_dict(original_state_dict, strict=True)
+ return vae
+
+
+# Function to parse command-line arguments for the script
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
+ )
+ parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
+ parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
+ parser.add_argument(
+ "--text_encoder_path",
+ type=str,
+ required=True,
+ default="google/t5-v1_1-xxl",
+        help="Path or Hugging Face model ID of the T5 text encoder (e.g. google/t5-v1_1-xxl)",
+ )
+ parser.add_argument(
+ "--text_encoder_cache_dir",
+ type=str,
+ default=None,
+        help="Path to the text encoder cache directory. Not needed if text_encoder_path points to a local directory.",
+ )
+    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
+ parser.add_argument(
+ "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
+ )
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ transformer = None
+ vae = None
+
+ if args.transformer_ckpt_path is not None:
+ transformer = convert_transformer(args.transformer_ckpt_path)
+ if args.vae_ckpt_path is not None:
+ vae = convert_vae(args.vae_ckpt_path)
+
+ tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_path, model_max_length=TOKENIZER_MAX_LENGTH)
+ text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, cache_dir=args.text_encoder_cache_dir)
+
+ scheduler = CogVideoXDDIMScheduler.from_config(
+ {
+ "snr_shift_scale": 3.0,
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": False,
+ "num_train_timesteps": 1000,
+ "prediction_type": "v_prediction",
+ "rescale_betas_zero_snr": True,
+ "set_alpha_to_one": True,
+ "timestep_spacing": "linspace",
+ }
+ )
+
+ pipe = CogVideoXPipeline(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ if args.fp16:
+ pipe = pipe.to(dtype=torch.float16)
+
+ pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)