Commit
·
22ee710
verified
·
0
Parent(s):
Duplicate from tencent/HunyuanImage-3.0
Browse filesCo-authored-by: TencentOpen <TencentOpen@users.noreply.huggingface.co>
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +40 -0
- LICENSE +78 -0
- README.md +502 -0
- __init__.py +0 -0
- assets/WECHAT.md +6 -0
- assets/banner.png +3 -0
- assets/banner_all.jpg +3 -0
- assets/framework.png +3 -0
- assets/gsb.png +3 -0
- assets/logo.png +3 -0
- assets/pg_imgs/image1.png +3 -0
- assets/pg_imgs/image2.png +3 -0
- assets/pg_imgs/image3.png +3 -0
- assets/pg_imgs/image4.png +3 -0
- assets/pg_imgs/image5.png +3 -0
- assets/pg_imgs/image6.png +3 -0
- assets/pg_imgs/image7.png +3 -0
- assets/pg_imgs/image8.png +3 -0
- assets/robot.png +3 -0
- assets/ssae_side_by_side_comparison.png +3 -0
- assets/ssae_side_by_side_heatmap.png +3 -0
- assets/user.png +3 -0
- assets/wechat.png +3 -0
- autoencoder_kl_3d.py +793 -0
- config.json +273 -0
- configuration_hunyuan.py +285 -0
- generation_config.json +20 -0
- hunyuan.py +0 -0
- hunyuan_image_3_pipeline.py +879 -0
- image_processor.py +125 -0
- model-0001-of-0032.safetensors +3 -0
- model-0002-of-0032.safetensors +3 -0
- model-0003-of-0032.safetensors +3 -0
- model-0004-of-0032.safetensors +3 -0
- model-0005-of-0032.safetensors +3 -0
- model-0006-of-0032.safetensors +3 -0
- model-0007-of-0032.safetensors +3 -0
- model-0008-of-0032.safetensors +3 -0
- model-0009-of-0032.safetensors +3 -0
- model-0010-of-0032.safetensors +3 -0
- model-0011-of-0032.safetensors +3 -0
- model-0012-of-0032.safetensors +3 -0
- model-0013-of-0032.safetensors +3 -0
- model-0014-of-0032.safetensors +3 -0
- model-0015-of-0032.safetensors +3 -0
- model-0016-of-0032.safetensors +3 -0
- model-0017-of-0032.safetensors +3 -0
- model-0018-of-0032.safetensors +3 -0
- model-0019-of-0032.safetensors +3 -0
- model-0020-of-0032.safetensors +3 -0
.gitattributes
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/banner_all.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/**/*.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
|
| 2 |
+
Tencent Hunyuan Image 3.0 Release Date: September 28, 2025
|
| 3 |
+
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
| 4 |
+
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
| 5 |
+
1. DEFINITIONS.
|
| 6 |
+
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
| 7 |
+
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
|
| 8 |
+
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
|
| 9 |
+
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
| 10 |
+
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
|
| 11 |
+
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
| 12 |
+
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
| 13 |
+
h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
|
| 14 |
+
i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
|
| 15 |
+
j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan Image 2.1 released at [
|
| 16 |
+
https://github.com/Tencent-Hunyuan/HunyuanImage-3.0;https://huggingface.co/tencent/HunyuanImage-3.0;https://huggingface.co/tencent/HunyuanImage-3.0-Instruct;https://modelscope.cn/models/Tencent-Hunyuan HunyuanImage-3.0/;https://ai.gitcode.com/tencent_hunyuan/HunyuanImage-3.0].
|
| 17 |
+
k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
| 18 |
+
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
|
| 19 |
+
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
| 20 |
+
n. “including” shall mean including but not limited to.
|
| 21 |
+
2. GRANT OF RIGHTS.
|
| 22 |
+
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
| 23 |
+
3. DISTRIBUTION.
|
| 24 |
+
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
| 25 |
+
a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
|
| 26 |
+
b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
| 27 |
+
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
|
| 28 |
+
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
|
| 29 |
+
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
| 30 |
+
4. ADDITIONAL COMMERCIAL TERMS.
|
| 31 |
+
If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
| 32 |
+
5. RULES OF USE.
|
| 33 |
+
a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
| 34 |
+
b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
|
| 35 |
+
c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
| 36 |
+
6. INTELLECTUAL PROPERTY.
|
| 37 |
+
a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
| 38 |
+
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
| 39 |
+
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
|
| 40 |
+
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
| 41 |
+
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
| 42 |
+
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
|
| 43 |
+
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
| 44 |
+
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 45 |
+
8. SURVIVAL AND TERMINATION.
|
| 46 |
+
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
| 47 |
+
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
| 48 |
+
9. GOVERNING LAW AND JURISDICTION.
|
| 49 |
+
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
| 50 |
+
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
| 51 |
+
|
| 52 |
+
EXHIBIT A
|
| 53 |
+
ACCEPTABLE USE POLICY
|
| 54 |
+
|
| 55 |
+
Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
| 56 |
+
Last modified: November 5, 2024
|
| 57 |
+
|
| 58 |
+
Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
|
| 59 |
+
1. Outside the Territory;
|
| 60 |
+
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
| 61 |
+
3. To harm Yourself or others;
|
| 62 |
+
4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
|
| 63 |
+
5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
| 64 |
+
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
| 65 |
+
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
| 66 |
+
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
| 67 |
+
9. To intentionally defame, disparage or otherwise harass others;
|
| 68 |
+
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
| 69 |
+
11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
| 70 |
+
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
| 71 |
+
13. To impersonate another individual without consent, authorization, or legal right;
|
| 72 |
+
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
| 73 |
+
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
| 74 |
+
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
| 75 |
+
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
| 76 |
+
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
| 77 |
+
19. For military purposes;
|
| 78 |
+
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
README.md
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: other
|
| 3 |
+
license_name: tencent-hunyuan-community
|
| 4 |
+
license_link: LICENSE
|
| 5 |
+
pipeline_tag: text-to-image
|
| 6 |
+
library_name: transformers
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
<div align="center">
|
| 10 |
+
|
| 11 |
+
<img src="./assets/logo.png" alt="HunyuanImage-3.0 Logo" width="600">
|
| 12 |
+
|
| 13 |
+
# 🎨 HunyuanImage-3.0: A Powerful Native Multimodal Model for Image Generation
|
| 14 |
+
|
| 15 |
+
</div>
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
<div align="center">
|
| 19 |
+
<img src="./assets/banner.png" alt="HunyuanImage-3.0 Banner" width="800">
|
| 20 |
+
|
| 21 |
+
</div>
|
| 22 |
+
|
| 23 |
+
<div align="center">
|
| 24 |
+
<a href=https://hunyuan.tencent.com/image target="_blank"><img src=https://img.shields.io/badge/Official%20Site-333399.svg?logo=homepage height=22px></a>
|
| 25 |
+
<a href=https://huggingface.co/tencent/HunyuanImage-3.0 target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20Models-d96902.svg height=22px></a>
|
| 26 |
+
<a href=https://github.com/Tencent-Hunyuan/HunyuanImage-3.0 target="_blank"><img src= https://img.shields.io/badge/Page-bb8a2e.svg?logo=github height=22px></a>
|
| 27 |
+
<a href=https://arxiv.org/pdf/2509.23951 target="_blank"><img src=https://img.shields.io/badge/Report-b5212f.svg?logo=arxiv height=22px></a>
|
| 28 |
+
<a href=https://x.com/TencentHunyuan target="_blank"><img src=https://img.shields.io/badge/Hunyuan-black.svg?logo=x height=22px></a>
|
| 29 |
+
<a href=https://docs.qq.com/doc/DUVVadmhCdG9qRXBU target="_blank"><img src=https://img.shields.io/badge/📚-PromptHandBook-blue.svg?logo=book height=22px></a>
|
| 30 |
+
</div>
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
<p align="center">
|
| 34 |
+
👏 Join our <a href="./assets/WECHAT.md" target="_blank">WeChat</a> and <a href="https://discord.gg/ehjWMqF5wY">Discord</a> |
|
| 35 |
+
💻 <a href="https://hunyuan.tencent.com/modelSquare/home/play?modelId=289&from=/visual">Official website(官网) Try our model!</a>  
|
| 36 |
+
</p>
|
| 37 |
+
|
| 38 |
+
## 🔥🔥🔥 News
|
| 39 |
+
- **September 28, 2025**: 📖 **HunyuanImage-3.0 Technical Report Released** - Comprehensive technical documentation now available
|
| 40 |
+
- **September 28, 2025**: 🚀 **HunyuanImage-3.0 Open Source Release** - Inference code and model weights publicly available
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## 🧩 Community Contributions
|
| 44 |
+
|
| 45 |
+
If you develop/use HunyuanImage-3.0 in your projects, welcome to let us know.
|
| 46 |
+
|
| 47 |
+
## 📑 Open-source Plan
|
| 48 |
+
|
| 49 |
+
- HunyuanImage-3.0 (Image Generation Model)
|
| 50 |
+
- [x] Inference
|
| 51 |
+
- [x] HunyuanImage-3.0 Checkpoints
|
| 52 |
+
- [ ] HunyuanImage-3.0-Instruct Checkpoints (with reasoning)
|
| 53 |
+
- [ ] VLLM Support
|
| 54 |
+
- [ ] Distilled Checkpoints
|
| 55 |
+
- [ ] Image-to-Image Generation
|
| 56 |
+
- [ ] Multi-turn Interaction
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
## 🗂️ Contents
|
| 60 |
+
- [🔥🔥🔥 News](#-news)
|
| 61 |
+
- [🧩 Community Contributions](#-community-contributions)
|
| 62 |
+
- [📑 Open-source Plan](#-open-source-plan)
|
| 63 |
+
- [📖 Introduction](#-introduction)
|
| 64 |
+
- [✨ Key Features](#-key-features)
|
| 65 |
+
- [🛠️ Dependencies and Installation](#-dependencies-and-installation)
|
| 66 |
+
- [💻 System Requirements](#-system-requirements)
|
| 67 |
+
- [📦 Environment Setup](#-environment-setup)
|
| 68 |
+
- [📥 Install Dependencies](#-install-dependencies)
|
| 69 |
+
- [Performance Optimizations](#performance-optimizations)
|
| 70 |
+
- [🚀 Usage](#-usage)
|
| 71 |
+
- [🔥 Quick Start with Transformers](#-quick-start-with-transformers)
|
| 72 |
+
- [🏠 Local Installation & Usage](#-local-installation--usage)
|
| 73 |
+
- [🎨 Interactive Gradio Demo](#-interactive-gradio-demo)
|
| 74 |
+
- [🧱 Models Cards](#-models-cards)
|
| 75 |
+
- [📝 Prompt Guide](#-prompt-guide)
|
| 76 |
+
- [Manually Writing Prompts](#manually-writing-prompts)
|
| 77 |
+
- [System Prompt For Automatic Rewriting the Prompt](#system-prompt-for-automatic-rewriting-the-prompt)
|
| 78 |
+
- [Advanced Tips](#advanced-tips)
|
| 79 |
+
- [More Cases](#more-cases)
|
| 80 |
+
- [📊 Evaluation](#-evaluation)
|
| 81 |
+
- [📚 Citation](#-citation)
|
| 82 |
+
- [🙏 Acknowledgements](#-acknowledgements)
|
| 83 |
+
- [🌟🚀 Github Star History](#-github-star-history)
|
| 84 |
+
|
| 85 |
+
---
|
| 86 |
+
|
| 87 |
+
## 📖 Introduction
|
| 88 |
+
|
| 89 |
+
**HunyuanImage-3.0** is a groundbreaking native multimodal model that unifies multimodal understanding and generation within an autoregressive framework. Our text-to-image module achieves performance **comparable to or surpassing** leading closed-source models.
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
<div align="center">
|
| 93 |
+
<img src="./assets/framework.png" alt="HunyuanImage-3.0 Framework" width="90%">
|
| 94 |
+
</div>
|
| 95 |
+
|
| 96 |
+
## ✨ Key Features
|
| 97 |
+
|
| 98 |
+
* 🧠 **Unified Multimodal Architecture:** Moving beyond the prevalent DiT-based architectures, HunyuanImage-3.0 employs a unified autoregressive framework. This design enables a more direct and integrated modeling of text and image modalities, leading to surprisingly effective and contextually rich image generation.
|
| 99 |
+
|
| 100 |
+
* 🏆 **The Largest Image Generation MoE Model:** This is the largest open-source image generation Mixture of Experts (MoE) model to date. It features 64 experts and a total of 80 billion parameters, with 13 billion activated per token, significantly enhancing its capacity and performance.
|
| 101 |
+
|
| 102 |
+
* 🎨 **Superior Image Generation Performance:** Through rigorous dataset curation and advanced reinforcement learning post-training, we've achieved an optimal balance between semantic accuracy and visual excellence. The model demonstrates exceptional prompt adherence while delivering photorealistic imagery with stunning aesthetic quality and fine-grained details.
|
| 103 |
+
|
| 104 |
+
* 💭 **Intelligent World-Knowledge Reasoning:** The unified multimodal architecture endows HunyuanImage-3.0 with powerful reasoning capabilities. It leverages its extensive world knowledge to intelligently interpret user intent, automatically elaborating on sparse prompts with contextually appropriate details to produce superior, more complete visual outputs.
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
## 🛠️ Dependencies and Installation
|
| 108 |
+
|
| 109 |
+
### 💻 System Requirements
|
| 110 |
+
|
| 111 |
+
* 🖥️ **Operating System:** Linux
|
| 112 |
+
* 🎮 **GPU:** NVIDIA GPU with CUDA support
|
| 113 |
+
* 💾 **Disk Space:** 170GB for model weights
|
| 114 |
+
* 🧠 **GPU Memory:** ≥3×80GB (4×80GB recommended for better performance)
|
| 115 |
+
|
| 116 |
+
### 📦 Environment Setup
|
| 117 |
+
|
| 118 |
+
* 🐍 **Python:** 3.12+ (recommended and tested)
|
| 119 |
+
* 🔥 **PyTorch:** 2.7.1
|
| 120 |
+
* ⚡ **CUDA:** 12.8
|
| 121 |
+
|
| 122 |
+
### 📥 Install Dependencies
|
| 123 |
+
|
| 124 |
+
```bash
|
| 125 |
+
# 1. First install PyTorch (CUDA 12.8 Version)
|
| 126 |
+
pip install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu128
|
| 127 |
+
|
| 128 |
+
# 2. Then install tencentcloud-sdk
|
| 129 |
+
pip install -i https://mirrors.tencent.com/pypi/simple/ --upgrade tencentcloud-sdk-python
|
| 130 |
+
|
| 131 |
+
# 3. Then install other dependencies
|
| 132 |
+
pip install -r requirements.txt
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
#### Performance Optimizations
|
| 136 |
+
|
| 137 |
+
For **up to 3x faster inference**, install these optimizations:
|
| 138 |
+
|
| 139 |
+
```bash
|
| 140 |
+
# FlashAttention for faster attention computation
|
| 141 |
+
pip install flash-attn==2.8.3 --no-build-isolation
|
| 142 |
+
|
| 143 |
+
# FlashInfer for optimized moe inference. v0.3.1 is tested.
|
| 144 |
+
pip install flashinfer-python
|
| 145 |
+
```
|
| 146 |
+
> 💡**Installation Tips:** It is critical that the CUDA version used by PyTorch matches the system's CUDA version.
|
| 147 |
+
> FlashInfer relies on this compatibility when compiling kernels at runtime. Pytorch 2.7.1+cu128 is tested.
|
| 148 |
+
> GCC version >=9 is recommended for compiling FlashAttention and FlashInfer.
|
| 149 |
+
|
| 150 |
+
> ⚡ **Performance Tips:** These optimizations can significantly speed up your inference!
|
| 151 |
+
|
| 152 |
+
> 💡**Notation:** When FlashInfer is enabled, the first inference may be slower (about 10 minutes) due to kernel compilation. Subsequent inferences on the same machine will be much faster.
|
| 153 |
+
|
| 154 |
+
## 🚀 Usage
|
| 155 |
+
|
| 156 |
+
### 🔥 Quick Start with Transformers
|
| 157 |
+
|
| 158 |
+
#### 1️⃣ Download model weights
|
| 159 |
+
|
| 160 |
+
```bash
|
| 161 |
+
# Download from HuggingFace and rename the directory.
|
| 162 |
+
# Notice that the directory name should not contain dots, which may cause issues when loading using Transformers.
|
| 163 |
+
hf download tencent/HunyuanImage-3.0 --local-dir ./HunyuanImage-3
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
#### 2️⃣ Run with Transformers
|
| 167 |
+
|
| 168 |
+
```python
|
| 169 |
+
from transformers import AutoModelForCausalLM
|
| 170 |
+
|
| 171 |
+
# Load the model
|
| 172 |
+
model_id = "./HunyuanImage-3"
|
| 173 |
+
# Currently we can not load the model using HF model_id `tencent/HunyuanImage-3.0` directly
|
| 174 |
+
# due to the dot in the name.
|
| 175 |
+
|
| 176 |
+
kwargs = dict(
|
| 177 |
+
attn_implementation="sdpa", # Use "flash_attention_2" if FlashAttention is installed
|
| 178 |
+
trust_remote_code=True,
|
| 179 |
+
torch_dtype="auto",
|
| 180 |
+
device_map="auto",
|
| 181 |
+
moe_impl="eager", # Use "flashinfer" if FlashInfer is installed
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
|
| 185 |
+
model.load_tokenizer(model_id)
|
| 186 |
+
|
| 187 |
+
# generate the image
|
| 188 |
+
prompt = "A brown and white dog is running on the grass"
|
| 189 |
+
image = model.generate_image(prompt=prompt, stream=True)
|
| 190 |
+
image.save("image.png")
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
### 🏠 Local Installation & Usage
|
| 194 |
+
|
| 195 |
+
#### 1️⃣ Clone the Repository
|
| 196 |
+
|
| 197 |
+
```bash
|
| 198 |
+
git clone https://github.com/Tencent-Hunyuan/HunyuanImage-3.0.git
|
| 199 |
+
cd HunyuanImage-3.0/
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
#### 2️⃣ Download Model Weights
|
| 203 |
+
|
| 204 |
+
```bash
|
| 205 |
+
# Download from HuggingFace
|
| 206 |
+
hf download tencent/HunyuanImage-3.0 --local-dir ./HunyuanImage-3
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
#### 3️⃣ Run the Demo
|
| 210 |
+
The Pretrain Checkpoint does not automatically rewrite or enhance input prompts, for optimal results currently, we recommend community partners to use deepseek to rewrite the prompts. You can go to [Tencent Cloud](https://cloud.tencent.com/document/product/1772/115963#.E5.BF.AB.E9.80.9F.E6.8E.A5.E5.85.A5) to apply for an API Key.
|
| 211 |
+
|
| 212 |
+
```bash
|
| 213 |
+
# set env
|
| 214 |
+
export DEEPSEEK_KEY_ID="your_deepseek_key_id"
|
| 215 |
+
export DEEPSEEK_KEY_SECRET="your_deepseek_key_secret"
|
| 216 |
+
|
| 217 |
+
python3 run_image_gen.py --model-id ./HunyuanImage-3 --verbose 1 --sys-deepseek-prompt "universal" --prompt "A brown and white dog is running on the grass"
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
+
#### 4️⃣ Command Line Arguments
|
| 221 |
+
|
| 222 |
+
| Arguments | Description | Default |
|
| 223 |
+
| ----------------------- | ------------------------------------------------------------ | ----------- |
|
| 224 |
+
| `--prompt` | Input prompt | (Required) |
|
| 225 |
+
| `--model-id` | Model path | (Required) |
|
| 226 |
+
| `--attn-impl` | Attention implementation. Either `sdpa` or `flash_attention_2`. | `sdpa` |
|
| 227 |
+
| `--moe-impl` | MoE implementation. Either `eager` or `flashinfer` | `eager` |
|
| 228 |
+
| `--seed` | Random seed for image generation | `None` |
|
| 229 |
+
| `--diff-infer-steps` | Diffusion infer steps | `50` |
|
| 230 |
+
| `--image-size` | Image resolution. Can be `auto`, like `1280x768` or `16:9` | `auto` |
|
| 231 |
+
| `--save` | Image save path. | `image.png` |
|
| 232 |
+
| `--verbose` | Verbose level. 0: No log; 1: log inference information. | `0` |
|
| 233 |
+
| `--rewrite` | Whether to enable rewriting | `1` |
|
| 234 |
+
| `--sys-deepseek-prompt` | Select sys-prompt from `universal` or `text_rendering` | `universal` |
|
| 235 |
+
|
| 236 |
+
### 🎨 Interactive Gradio Demo
|
| 237 |
+
|
| 238 |
+
Launch an interactive web interface for easy text-to-image generation.
|
| 239 |
+
|
| 240 |
+
#### 1️⃣ Install Gradio
|
| 241 |
+
|
| 242 |
+
```bash
|
| 243 |
+
pip install gradio>=4.21.0
|
| 244 |
+
```
|
| 245 |
+
|
| 246 |
+
#### 2️⃣ Configure Environment
|
| 247 |
+
|
| 248 |
+
```bash
|
| 249 |
+
# Set your model path
|
| 250 |
+
export MODEL_ID="path/to/your/model"
|
| 251 |
+
|
| 252 |
+
# Optional: Configure GPU usage (default: 0,1,2,3)
|
| 253 |
+
export GPUS="0,1,2,3"
|
| 254 |
+
|
| 255 |
+
# Optional: Configure host and port (default: 0.0.0.0:443)
|
| 256 |
+
export HOST="0.0.0.0"
|
| 257 |
+
export PORT="443"
|
| 258 |
+
```
|
| 259 |
+
|
| 260 |
+
#### 3️⃣ Launch the Web Interface
|
| 261 |
+
|
| 262 |
+
**Basic Launch:**
|
| 263 |
+
```bash
|
| 264 |
+
sh run_app.sh
|
| 265 |
+
```
|
| 266 |
+
|
| 267 |
+
**With Performance Optimizations:**
|
| 268 |
+
```bash
|
| 269 |
+
# Use both optimizations for maximum performance
|
| 270 |
+
sh run_app.sh --moe-impl flashinfer --attn-impl flash_attention_2
|
| 271 |
+
```
|
| 272 |
+
|
| 273 |
+
#### 4️⃣ Access the Interface
|
| 274 |
+
|
| 275 |
+
> 🌐 **Web Interface:** Open your browser and navigate to `http://localhost:443` (or your configured port)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
## 🧱 Models Cards
|
| 279 |
+
|
| 280 |
+
| Model | Params | Download | Recommended VRAM | Supported |
|
| 281 |
+
|---------------------------| --- | --- | --- | --- |
|
| 282 |
+
| HunyuanImage-3.0 | 80B total (13B active) | [HuggingFace](https://huggingface.co/tencent/HunyuanImage-3.0) | ≥ 3 × 80 GB | ✅ Text-to-Image
|
| 283 |
+
| HunyuanImage-3.0-Instruct | 80B total (13B active) | [HuggingFace](https://huggingface.co/tencent/HunyuanImage-3.0-Instruct) | ≥ 3 × 80 GB | ✅ Text-to-Image<br>✅ Prompt Self-Rewrite <br>✅ CoT Think
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
Notes:
|
| 288 |
+
- Install performance extras (FlashAttention, FlashInfer) for faster inference.
|
| 289 |
+
- Multi‑GPU inference is recommended for the Base model.
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
## 📝 Prompt Guide
|
| 293 |
+
|
| 294 |
+
### Manually Writing Prompts.
|
| 295 |
+
The Pretrain Checkpoint does not automatically rewrite or enhance input prompts, Instruct Checkpoint can rewrite or enhance input prompts with thinking . For optimal results currently, we recommend community partners consulting our official guide on how to write effective prompts.
|
| 296 |
+
|
| 297 |
+
Reference: [HunyuanImage 3.0 Prompt Handbook](
|
| 298 |
+
https://docs.qq.com/doc/DUVVadmhCdG9qRXBU)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
### System Prompt For Automatic Rewriting the Prompt.
|
| 302 |
+
|
| 303 |
+
We've included two system prompts in the PE folder of this repository that leverage DeepSeek to automatically enhance user inputs:
|
| 304 |
+
|
| 305 |
+
* **system_prompt_universal**: This system prompt converts photographic style, artistic prompts into a detailed one.
|
| 306 |
+
* **system_prompt_text_rendering**: This system prompt converts UI/Poster/Text Rending prompts to a deailed on that suits the model.
|
| 307 |
+
|
| 308 |
+
Note that these system prompts are in Chinese because Deepseek works better with Chinese system prompts. If you want to use it for English oriented model, you may translate it into English or refer to the comments in the PE file as a guide.
|
| 309 |
+
|
| 310 |
+
We also create a [Yuanqi workflow](https://yuanqi.tencent.com/agent/H69VgtJdj3Dz) to implement the universal one, you can directly try it.
|
| 311 |
+
|
| 312 |
+
### Advanced Tips
|
| 313 |
+
- **Content Priority**: Focus on describing the main subject and action first, followed by details about the environment and style. A more general description framework is: **Main subject and scene + Image quality and style + Composition and perspective + Lighting and atmosphere + Technical parameters**. Keywords can be added both before and after this structure.
|
| 314 |
+
|
| 315 |
+
- **Image resolution**: Our model not only supports multiple resolutions but also offers both **automatic and specified resolution** options. In auto mode, the model automatically predicts the image resolution based on the input prompt. In specified mode (like traditional DiT), the model outputs an image resolution that strictly aligns with the user's chosen resolution.
|
| 316 |
+
|
| 317 |
+
### More Cases
|
| 318 |
+
Our model can follow complex instructions to generate high‑quality, creative images.
|
| 319 |
+
|
| 320 |
+
<div align="center">
|
| 321 |
+
<img src="./assets/banner_all.jpg" width=100% alt="HunyuanImage 3.0 Demo">
|
| 322 |
+
</div>
|
| 323 |
+
|
| 324 |
+
Our model can effectively process very long text inputs, enabling users to precisely control the finer details of generated images. Extended prompts allow for intricate elements to be accurately captured, making it ideal for complex projects requiring precision and creativity.
|
| 325 |
+
|
| 326 |
+
<p align="center">
|
| 327 |
+
<table>
|
| 328 |
+
<thead>
|
| 329 |
+
</thead>
|
| 330 |
+
<tbody>
|
| 331 |
+
<tr>
|
| 332 |
+
<td>
|
| 333 |
+
<img src="./assets/pg_imgs/image1.png" width=100%><details>
|
| 334 |
+
<summary>Show prompt</summary>
|
| 335 |
+
A cinematic medium shot captures a single Asian woman seated on a chair within a dimly lit room, creating an intimate and theatrical atmosphere. The composition is focused on the subject, rendered with rich colors and intricate textures that evoke a nostalgic and moody feeling.
|
| 336 |
+
|
| 337 |
+
The primary subject is a young Asian woman with a thoughtful and expressive countenance, her gaze directed slightly away from the camera. She is seated in a relaxed yet elegant posture on an ornate, vintage armchair. The chair is upholstered in a deep red velvet, its fabric showing detailed, intricate textures and slight signs of wear. She wears a simple, elegant dress in a dark teal hue, the material catching the light in a way that reveals its fine-woven texture. Her skin has a soft, matte quality, and the light delicately models the contours of her face and arms.
|
| 338 |
+
|
| 339 |
+
The surrounding room is characterized by its vintage decor, which contributes to the historic and evocative mood. In the immediate background, partially blurred due to a shallow depth of field consistent with a f/2.8 aperture, the wall is covered with wallpaper featuring a subtle, damask pattern. The overall color palette is a carefully balanced interplay of deep teal and rich red hues, creating a visually compelling and cohesive environment. The entire scene is detailed, from the fibers of the upholstery to the subtle patterns on the wall.
|
| 340 |
+
|
| 341 |
+
The lighting is highly dramatic and artistic, defined by high contrast and pronounced shadow play. A single key light source, positioned off-camera, projects gobo lighting patterns onto the scene, casting intricate shapes of light and shadow across the woman and the back wall. These dramatic shadows create a strong sense of depth and a theatrical quality. While some shadows are deep and defined, others remain soft, gently wrapping around the subject and preventing the loss of detail in darker areas. The soft focus on the background enhances the intimate feeling, drawing all attention to the expressive subject. The overall image presents a cinematic, photorealistic photography style.
|
| 342 |
+
</details>
|
| 343 |
+
</td>
|
| 344 |
+
<td><img src="./assets/pg_imgs/image2.png" width=100%><details>
|
| 345 |
+
<summary>Show prompt</summary>
|
| 346 |
+
A cinematic, photorealistic medium shot captures a high-contrast urban street corner, defined by the sharp intersection of light and shadow. The primary subject is the exterior corner of a building, rendered in a low-saturation, realistic style.
|
| 347 |
+
|
| 348 |
+
The building wall, which occupies the majority of the frame, is painted a warm orange with a finely detailed, rough stucco texture. Horizontal white stripes run across its surface. The base of the building is constructed from large, rough-hewn stone blocks, showing visible particles and texture. On the left, illuminated side of the building, there is a single window with closed, dark-colored shutters. Adjacent to the window, a simple black pendant lamp hangs from a thin, taut rope, casting a distinct, sharp-edged shadow onto the sunlit orange wall. The composition is split diagonally, with the right side of the building enveloped in a deep brown shadow. At the bottom of the frame, a smooth concrete sidewalk is visible, upon which the dynamic silhouette of a person is captured mid-stride, walking from right to left.
|
| 349 |
+
|
| 350 |
+
In the shallow background, the faint, out-of-focus outlines of another building and the bare, skeletal branches of trees are softly visible, contributing to the quiet urban atmosphere and adding a sense of depth to the scene. These elements are rendered with minimal detail to keep the focus on the foreground architecture.
|
| 351 |
+
|
| 352 |
+
The scene is illuminated by strong, natural sunlight originating from the upper left, creating a dramatic chiaroscuro effect. This hard light source casts deep, well-defined shadows, producing a sharp contrast between the brightly lit warm orange surfaces and the deep brown shadow areas. The lighting highlights the fine details in the wall texture and stone particles, emphasizing the photorealistic quality. The overall presentation reflects a high-quality photorealistic photography style, infused with a cinematic film noir aesthetic.
|
| 353 |
+
</details>
|
| 354 |
+
</td>
|
| 355 |
+
</tr>
|
| 356 |
+
<tr>
|
| 357 |
+
<td>
|
| 358 |
+
<img src="./assets/pg_imgs/image3.png" width=100%><details>
|
| 359 |
+
<summary>Show prompt</summary>
|
| 360 |
+
一幅极具视觉张力的杂志封面风格人像特写。画面主体是一个身着古风汉服的人物,构图采用了从肩部以上的超级近距离特写,人物占据了画面的绝大部分,形成了强烈的视觉冲击力。
|
| 361 |
+
|
| 362 |
+
画面中的人物以一种慵懒的姿态出现,微微倾斜着头部,裸露的一侧肩膀线条流畅。她正用一种妩媚而直接的眼神凝视着镜头,双眼微张,眼神深邃,传递出一种神秘而勾人的气质。人物的面部特征精致,皮肤质感细腻,在特定的光线下,面部轮廓清晰分明,展现出一种古典与现代融合的时尚美感。
|
| 363 |
+
|
| 364 |
+
整个画面的背景被设定为一种简约而高级的纯红色。这种红色色调深沉,呈现出哑光质感,既纯粹又无任何杂质,为整个暗黑神秘的氛围奠定了沉稳而富有张力的基调。这个纯色的背景有效地突出了前景中的人物主体,使得所有视觉焦点都集中在其身上。
|
| 365 |
+
|
| 366 |
+
光线和氛围的营造是这幅杂志风海报的关键。一束暗橘色的柔和光线作为主光源,从人���的一侧斜上方投射下来,精准地勾勒出人物的脸颊、鼻梁和肩膀的轮廓,在皮肤上形成微妙的光影过渡。同时,人物的周身萦绕着一层暗淡且低饱和度的银白色辉光,如同清冷的月光,形成一道朦胧的轮廓光。这道银辉为人物增添了几分疏离的幽灵感,强化了整体暗黑风格的神秘气质。光影的强烈对比与色彩的独特搭配,共同塑造了这张充满故事感的特写画面。整体图像呈现出一种融合了古典元素的现代时尚摄影风格。
|
| 367 |
+
</details>
|
| 368 |
+
</td>
|
| 369 |
+
<td>
|
| 370 |
+
<img src="./assets/pg_imgs/image4.png" width=100%><details>
|
| 371 |
+
<summary>Show prompt</summary>
|
| 372 |
+
一幅采用极简俯视视角的油画作品,画面主体由一道居中斜向的红色笔触构成。
|
| 373 |
+
|
| 374 |
+
这道醒目的红色笔触运用了厚涂技法,颜料堆叠形成了强烈的物理厚度和三维立体感。它从画面的左上角附近延伸至右下角附近,构成一个动态的对角线。颜料表面可以清晰地看到画刀刮擦和笔刷拖曳留下的痕迹,边缘处的颜料层相对较薄,而中央部分则高高隆起,形成了不规则的起伏。
|
| 375 |
+
|
| 376 |
+
在这道立体的红色颜料之上,巧妙地构建了一处精致的微缩景观。景观的核心是一片模拟红海滩的区域,由细腻的深红色颜料点缀而成,与下方基底的鲜红色形成丰富的层次对比。紧邻着“红海滩”的是一小片湖泊,由一层平滑且带有光泽的蓝色与白色混合颜料构成,质感如同平静无波的水面。湖泊边缘,一小撮芦苇丛生,由几根纤细挺拔的、用淡黄色和棕色颜料勾勒出的线条来表现。一只小巧的白鹭立于芦苇旁,其形态由一小块纯白色的厚涂颜料塑造,仅用一抹精炼的黑色颜料点出其尖喙,姿态优雅宁静。
|
| 377 |
+
|
| 378 |
+
整个构图的背景是大面积的留白,呈现为一张带有细微凹凸纹理的白色纸质基底,这种极简处理极大地突出了中央的红色笔触及其上的微缩景观。
|
| 379 |
+
|
| 380 |
+
光线从画面一侧柔和地照射下来,在厚涂的颜料堆叠处投下淡淡的、轮廓分明的阴影,进一步增强了画面的三维立体感和油画质感。整幅画面呈现出一种结合了厚涂技法的现代极简主义油画风格。
|
| 381 |
+
</details>
|
| 382 |
+
</td>
|
| 383 |
+
</tr>
|
| 384 |
+
<tr>
|
| 385 |
+
<td>
|
| 386 |
+
<img src="./assets/pg_imgs/image5.png" width=100%><details>
|
| 387 |
+
<summary>Show prompt</summary>
|
| 388 |
+
整体画面采用一个二乘二的四宫格布局,以产品可视化的风格,展示了一只兔子在四种不同材质下的渲染效果。每个宫格内都有一只姿态完全相同的兔子模型,它呈坐姿,双耳竖立,面朝前方。所有宫格的背景均是统一的中性深灰色,这种简约背景旨在最大限度地突出每种材质的独特质感。
|
| 389 |
+
|
| 390 |
+
左上角的宫格中,兔子模型由哑光白色石膏材质构成。其表面平滑、均匀且无反射,在模型的耳朵根部、四肢交接处等凹陷区域呈现出柔和的环境光遮蔽阴影,这种微妙的阴影变化凸显了其纯粹的几何形态,整体感觉像一个用于美术研究的基础模型。
|
| 391 |
+
|
| 392 |
+
右上角的宫格中,兔子模型由晶莹剔透的无瑕疵玻璃制成。它展现了逼真的物理折射效果,透过其透明的身体看到的背景呈现出轻微的扭曲。清晰的镜面高光沿着其身体的曲线轮廓流动,表面上还能看到微弱而清晰的环境反射,赋予其一种精致而易碎的质感。
|
| 393 |
+
|
| 394 |
+
左下角的宫格中,兔子模型呈现为带有拉丝纹理的钛金属材质。金属表面具有明显的各向异性反射效果,呈现出冷峻的灰调金属光泽。锐利明亮的高光和深邃的阴影形成了强烈对比,精确地定义了其坚固的三维形态,展现了工业设计般的美感。
|
| 395 |
+
|
| 396 |
+
右下角的宫格中,兔子模型覆盖着一层柔软浓密的灰色毛绒。根根分明的绒毛清晰可见,创造出一种温暖、可触摸的质地。光线照射在绒毛的末梢,形成柔和的光晕效果,而毛绒内部的阴影则显得深邃而柔软,展现了高度写实的毛发渲染效果。
|
| 397 |
+
|
| 398 |
+
整个四宫格由来自多个方向的、柔和均匀的影棚灯光照亮,确保了每种材质的细节和特性都得到清晰的展现,没有任何刺眼的阴影或过曝的高光。这张图像以一种高度写实的3D渲染风格呈现,完美地诠释了产品可视化的精髓
|
| 399 |
+
</details>
|
| 400 |
+
</td>
|
| 401 |
+
<td>
|
| 402 |
+
<img src="./assets/pg_imgs/image6.png" width=100%><details>
|
| 403 |
+
<summary>Show prompt</summary>
|
| 404 |
+
由一个两行两列的网格构成,共包含四个独立的场景,每个场景都以不同的艺术风格描绘了一个小男孩(小明)一天中的不同活动。
|
| 405 |
+
|
| 406 |
+
左上角的第一个场景,以超写实摄影风格呈现。画面主体是一个大约8岁的东亚小男孩,他穿着整洁的小学制服——一件白色短袖衬衫和蓝色短裤,脖子上系着红领巾。他背着一个蓝色的双肩书包,正走在去上学的路上。他位于画面的前景偏右侧,面带微笑,步伐轻快。场景设定���清晨,柔和的阳光从左上方照射下来,在人行道上投下清晰而柔和的影子。背景是绿树成荫的街道和模糊可见的学校铁艺大门,营造出宁静的早晨氛围。这张图片的细节表现极为丰富,可以清晰地看到男孩头发的光泽、衣服的褶皱纹理以及书包的帆布材质,完全展现了专业摄影的质感。
|
| 407 |
+
|
| 408 |
+
右上角的第二个场景,采用日式赛璐璐动漫风格绘制。画面中,小男孩坐在家中的木质餐桌旁吃午饭。他的形象被动漫化,拥有大而明亮的眼睛和简洁的五官线条。他身穿一件简单的黄色T恤,正用筷子夹起碗里的米饭。桌上摆放着一碗汤和两盘家常菜。背景是一个温馨的室内环境,一扇明亮的窗户透进正午的阳光,窗外是蓝天白云。整个画面色彩鲜艳、饱和度高,角色轮廓线清晰明确,阴影部分采用平涂的色块处理,是典型的赛璐璐动漫风格。
|
| 409 |
+
|
| 410 |
+
左下角的第三个场景,以细腻的铅笔素描风格呈现。画面描绘了下午在操场上踢足球的小男孩。整个图像由不同灰度的石墨色调构成,没有其他颜色。小男孩身穿运动短袖和短裤,身体呈前倾姿态,右脚正要踢向一个足球,动作充满动感。背景是空旷的操场和远处的球门,用简练的线条和排线勾勒。艺术家通过交叉排线和涂抹技巧来表现光影和体积感,足球上的阴影、人物身上的肌肉线条以及地面粗糙的质感都通过铅笔的笔触得到了充分的展现。这张铅笔画突出了素描的光影关系和线条美感。
|
| 411 |
+
|
| 412 |
+
右下角的第四个场景,以文森特·梵高的后印象派油画风格进行诠释。画面描绘了夜晚时分,小男孩独自在河边钓鱼的景象。他坐在一块岩石上,手持一根简易的钓鱼竿,身影在深蓝色的夜幕下显得很渺小。整个画面的视觉焦点是天空和水面,天空布满了旋转、卷曲的星云,星星和月亮被描绘成巨大、发光的光团,使用了厚涂的油画颜料(Impasto),笔触粗犷而充满能量。深蓝、亮黄和白色的颜料在画布上相互交织,形成强烈的视觉冲击力。水面倒映着天空中扭曲的光影,整个场景充满了梵高作品中特有的强烈情感和动荡不安的美感。这幅画作是对梵高风格的深度致敬。
|
| 413 |
+
</details>
|
| 414 |
+
</td>
|
| 415 |
+
</tr>
|
| 416 |
+
<tr>
|
| 417 |
+
<td>
|
| 418 |
+
<img src="./assets/pg_imgs/image7.png" width=100%><details>
|
| 419 |
+
<summary>Show prompt</summary>
|
| 420 |
+
以平视视角,呈现了一幅关于如何用素描技法绘制鹦鹉的九宫格教学图。整体构图规整,九个大小一致的方形画框以三行三列的形式均匀分布在浅灰色背景上,清晰地展示了从基本形状到最终成品的全过程。
|
| 421 |
+
|
| 422 |
+
第一行从左至右展示了绘画的初始步骤。左上角的第一个画框中,用简洁的铅笔线条勾勒出鹦鹉的基本几何形态:一个圆形代表头部,一个稍大的椭圆形代表身体。右上角有一个小号的无衬线字体数字“1”。中间的第二个画框中,在基础形态上添加了三角形的鸟喙轮廓和一条长长的弧线作为尾巴的雏形,头部和身体的连接处线条变得更加流畅;右上角标有数字“2”。右侧的第三个画框中,进一步精确了鹦鹉的整体轮廓,勾勒出头部顶端的羽冠和清晰的眼部圆形轮廓;右上角标有数字“3”。
|
| 423 |
+
|
| 424 |
+
第二行专注于结构与细节的添加,描绘了绘画的中期阶段。左侧的第四个画框里,鹦鹉的身体上添加了翅膀的基本形状,同时在身体下方画出了一根作为栖木的横向树枝,鹦鹉的爪子初步搭在树枝上;右上角标有数字“4”。中间的第五个画框中,开始细化翅膀和尾部的羽毛分组,用短促的线条表现出层次感,并清晰地画出爪子紧握树枝的细节;右上角标有数字“5”。右侧的第六个画框里,开始为鹦鹉添加初步的阴影,使用交叉排线的素描技法在腹部、翅膀下方和颈部制造出体积感;右上角标有数字“6”。
|
| 425 |
+
|
| 426 |
+
第三行则展示了最终的润色与完成阶段。左下角的第七个画框中,素描的排线更加密集,阴影层次更加丰富,羽毛的纹理细节被仔细刻画出来,眼珠也添加了高光点缀,显得炯炯有神;右上角标有数字“7”。中间的第八个画框里,描绘的重点转移到栖木上,增加了树枝的纹理和节疤细节,同时整体调整了鹦鹉身上的光影关系,使立体感更为突出;右上角标有数字“8”。右下角的第九个画框是最终完成图,所有线条都经过了精炼,光影对比强烈,鹦鹉的羽毛质感、木质栖木的粗糙感都表现得淋漓尽致,呈现出一幅完整且细节丰富的素描作品;右上角标有数字“9”。
|
| 427 |
+
|
| 428 |
+
整个画面的光线均匀而明亮,没有任何特定的光源方向,确保了每个教学步骤的视觉清晰度。整体呈现出一种清晰、有条理的数字插画教程风格。
|
| 429 |
+
</details>
|
| 430 |
+
</td>
|
| 431 |
+
<td>
|
| 432 |
+
<img src="./assets/pg_imgs/image8.png" width=100%><details>
|
| 433 |
+
<summary>Show prompt</summary>
|
| 434 |
+
一张现代平面设计风格的海报占据了整个画面,构图简洁且中心突出。
|
| 435 |
+
|
| 436 |
+
海报的主体是位于画面正中央的一只腾讯QQ企鹅。这只企鹅采用了圆润可爱的3D卡通渲染风格,身体主要为饱满的黑色,腹部为纯白色。它的眼睛大而圆,眼神好奇地直视前方。黄色的嘴巴小巧而立体,双脚同样为鲜明的黄色,稳稳地站立着。一条标志性的红色围巾整齐地系在它的脖子上,围巾的材质带有轻微的布料质感,末端自然下垂。企鹅的整体造型干净利落,边缘光滑,呈现出一种精致的数字插画质感。
|
| 437 |
+
|
| 438 |
+
海报的背景是一种从上到下由浅蓝色平滑过渡到白色的柔和渐变,营造出一种开阔、明亮的空间感。在企鹅的身后,散布着一些淡淡的、模糊的圆形光斑和几道柔和的抽象光束,为这个简约的平面设计海报增添了微妙的深度和科技感。
|
| 439 |
+
|
| 440 |
+
画面的底部区域是文字部分,排版居中对齐。上半部分是一行稍大的黑色黑体字,内容为“Hunyuan Image 3.0”。紧随其下的是一行字号略小的深灰色黑体字,内容为“原生多模态大模型”。两行文字清晰易读,与整体的现代平面设计风格保持一致。
|
| 441 |
+
|
| 442 |
+
整体光线明亮、均匀,没有明显的阴影,突出了企鹅和文字信息,符合现代设计海报的视觉要求。这张图像呈现了现代、简洁的平面设计海报风格。
|
| 443 |
+
</details>
|
| 444 |
+
</td>
|
| 445 |
+
</tr>
|
| 446 |
+
</tbody>
|
| 447 |
+
</table>
|
| 448 |
+
</p>
|
| 449 |
+
|
| 450 |
+
## 📊 Evaluation
|
| 451 |
+
|
| 452 |
+
* 🤖 **SSAE (Machine Evaluation)**
|
| 453 |
+
SSAE (Structured Semantic Alignment Evaluation) is an intelligent evaluation metric for image-text alignment based on advanced multimodal large language models (MLLMs). We extracted 3500 key points across 12 categories, then used multimodal large language models to automatically evaluate and score by comparing the generated images with these key points based on the visual content of the images. Mean Image Accuracy represents the image-wise average score across all key points, while Global Accuracy directly calculates the average score across all key points.
|
| 454 |
+
|
| 455 |
+
<p align="center">
|
| 456 |
+
<img src="./assets/ssae_side_by_side_comparison.png" width=98% alt="Human Evaluation with Other Models">
|
| 457 |
+
</p>
|
| 458 |
+
|
| 459 |
+
<p align="center">
|
| 460 |
+
<img src="./assets/ssae_side_by_side_heatmap.png" width=98% alt="Human Evaluation with Other Models">
|
| 461 |
+
</p>
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
* 👥 **GSB (Human Evaluation)**
|
| 465 |
+
|
| 466 |
+
We adopted the GSB (Good/Same/Bad) evaluation method commonly used to assess the relative performance between two models from an overall image perception perspective. In total, we utilized 1,000 text prompts, generating an equal number of image samples for all compared models in a single run. For a fair comparison, we conducted inference only once for each prompt, avoiding any cherry-picking of results. When comparing with the baseline methods, we maintained the default settings for all selected models. The evaluation was performed by more than 100 professional evaluators.
|
| 467 |
+
|
| 468 |
+
<p align="center">
|
| 469 |
+
<img src="./assets/gsb.png" width=98% alt="Human Evaluation with Other Models">
|
| 470 |
+
</p>
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
## 📚 Citation
|
| 474 |
+
|
| 475 |
+
If you find HunyuanImage-3.0 useful in your research, please cite our work:
|
| 476 |
+
|
| 477 |
+
```bibtex
|
| 478 |
+
@article{cao2025hunyuanimage,
|
| 479 |
+
title={HunyuanImage 3.0 Technical Report},
|
| 480 |
+
author={Cao, Siyu and Chen, Hangting and Chen, Peng and Cheng, Yiji and Cui, Yutao and Deng, Xinchi and Dong, Ying and Gong, Kipper and Gu, Tianpeng and Gu, Xiusen and others},
|
| 481 |
+
journal={arXiv preprint arXiv:2509.23951},
|
| 482 |
+
year={2025}
|
| 483 |
+
}
|
| 484 |
+
```
|
| 485 |
+
|
| 486 |
+
## 🙏 Acknowledgements
|
| 487 |
+
|
| 488 |
+
We extend our heartfelt gratitude to the following open-source projects and communities for their invaluable contributions:
|
| 489 |
+
|
| 490 |
+
* 🤗 [Transformers](https://github.com/huggingface/transformers) - State-of-the-art NLP library
|
| 491 |
+
* 🎨 [Diffusers](https://github.com/huggingface/diffusers) - Diffusion models library
|
| 492 |
+
* 🌐 [HuggingFace](https://huggingface.co/) - AI model hub and community
|
| 493 |
+
* ⚡ [FlashAttention](https://github.com/Dao-AILab/flash-attention) - Memory-efficient attention
|
| 494 |
+
* 🚀 [FlashInfer](https://github.com/flashinfer-ai/flashinfer) - Optimized inference engine
|
| 495 |
+
|
| 496 |
+
## 🌟🚀 Github Star History
|
| 497 |
+
|
| 498 |
+
[](https://github.com/Tencent-Hunyuan/HunyuanImage-3.0)
|
| 499 |
+
[](https://github.com/Tencent-Hunyuan/HunyuanImage-3.0)
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
[](https://www.star-history.com/#Tencent-Hunyuan/HunyuanImage-3.0&Date)
|
__init__.py
ADDED
|
File without changes
|
assets/WECHAT.md
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
<img src=wechat.png width="60%"/>
|
| 3 |
+
|
| 4 |
+
<p> 扫码关注混元图像系列工作,加入「 腾讯混元生图交流群 」 </p>
|
| 5 |
+
<p> Scan the QR code to join the "Tencent Hunyuan Image Generation Discussion Group" </p>
|
| 6 |
+
</div>
|
assets/banner.png
ADDED
|
Git LFS Details
|
assets/banner_all.jpg
ADDED
|
Git LFS Details
|
assets/framework.png
ADDED
|
Git LFS Details
|
assets/gsb.png
ADDED
|
Git LFS Details
|
assets/logo.png
ADDED
|
Git LFS Details
|
assets/pg_imgs/image1.png
ADDED
|
Git LFS Details
|
assets/pg_imgs/image2.png
ADDED
|
Git LFS Details
|
assets/pg_imgs/image3.png
ADDED
|
Git LFS Details
|
assets/pg_imgs/image4.png
ADDED
|
Git LFS Details
|
assets/pg_imgs/image5.png
ADDED
|
Git LFS Details
|
assets/pg_imgs/image6.png
ADDED
|
Git LFS Details
|
assets/pg_imgs/image7.png
ADDED
|
Git LFS Details
|
assets/pg_imgs/image8.png
ADDED
|
Git LFS Details
|
assets/robot.png
ADDED
|
Git LFS Details
|
assets/ssae_side_by_side_comparison.png
ADDED
|
Git LFS Details
|
assets/ssae_side_by_side_heatmap.png
ADDED
|
Git LFS Details
|
assets/user.png
ADDED
|
Git LFS Details
|
assets/wechat.png
ADDED
|
Git LFS Details
|
autoencoder_kl_3d.py
ADDED
|
@@ -0,0 +1,793 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
|
| 2 |
+
# you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at
|
| 4 |
+
#
|
| 5 |
+
# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
|
| 6 |
+
#
|
| 7 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 10 |
+
# See the License for the specific language governing permissions and
|
| 11 |
+
# limitations under the License.
|
| 12 |
+
# ==============================================================================
|
| 13 |
+
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import Tuple, Optional
|
| 16 |
+
import math
|
| 17 |
+
import random
|
| 18 |
+
import numpy as np
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
import torch
|
| 21 |
+
from torch import Tensor, nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 25 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 26 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 27 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 28 |
+
from diffusers.utils import BaseOutput
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DiagonalGaussianDistribution(object):
|
| 32 |
+
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
| 33 |
+
if parameters.ndim == 3:
|
| 34 |
+
dim = 2 # (B, L, C)
|
| 35 |
+
elif parameters.ndim == 5 or parameters.ndim == 4:
|
| 36 |
+
dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
|
| 37 |
+
else:
|
| 38 |
+
raise NotImplementedError
|
| 39 |
+
self.parameters = parameters
|
| 40 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
|
| 41 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
| 42 |
+
self.deterministic = deterministic
|
| 43 |
+
self.std = torch.exp(0.5 * self.logvar)
|
| 44 |
+
self.var = torch.exp(self.logvar)
|
| 45 |
+
if self.deterministic:
|
| 46 |
+
self.var = self.std = torch.zeros_like(
|
| 47 |
+
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
| 51 |
+
# make sure sample is on the same device as the parameters and has same dtype
|
| 52 |
+
sample = randn_tensor(
|
| 53 |
+
self.mean.shape,
|
| 54 |
+
generator=generator,
|
| 55 |
+
device=self.parameters.device,
|
| 56 |
+
dtype=self.parameters.dtype,
|
| 57 |
+
)
|
| 58 |
+
x = self.mean + self.std * sample
|
| 59 |
+
return x
|
| 60 |
+
|
| 61 |
+
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
|
| 62 |
+
if self.deterministic:
|
| 63 |
+
return torch.Tensor([0.0])
|
| 64 |
+
else:
|
| 65 |
+
reduce_dim = list(range(1, self.mean.ndim))
|
| 66 |
+
if other is None:
|
| 67 |
+
return 0.5 * torch.sum(
|
| 68 |
+
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
| 69 |
+
dim=reduce_dim,
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
return 0.5 * torch.sum(
|
| 73 |
+
torch.pow(self.mean - other.mean, 2) / other.var +
|
| 74 |
+
self.var / other.var -
|
| 75 |
+
1.0 -
|
| 76 |
+
self.logvar +
|
| 77 |
+
other.logvar,
|
| 78 |
+
dim=reduce_dim,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
|
| 82 |
+
if self.deterministic:
|
| 83 |
+
return torch.Tensor([0.0])
|
| 84 |
+
logtwopi = np.log(2.0 * np.pi)
|
| 85 |
+
return 0.5 * torch.sum(
|
| 86 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
| 87 |
+
dim=dims,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
def mode(self) -> torch.Tensor:
|
| 91 |
+
return self.mean
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@dataclass
|
| 95 |
+
class DecoderOutput(BaseOutput):
|
| 96 |
+
sample: torch.FloatTensor
|
| 97 |
+
posterior: Optional[DiagonalGaussianDistribution] = None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def swish(x: Tensor) -> Tensor:
|
| 101 |
+
return x * torch.sigmoid(x)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def forward_with_checkpointing(module, *inputs, use_checkpointing=False):
|
| 105 |
+
def create_custom_forward(module):
|
| 106 |
+
def custom_forward(*inputs):
|
| 107 |
+
return module(*inputs)
|
| 108 |
+
return custom_forward
|
| 109 |
+
|
| 110 |
+
if use_checkpointing:
|
| 111 |
+
return torch.utils.checkpoint.checkpoint(create_custom_forward(module), *inputs, use_reentrant=False)
|
| 112 |
+
else:
|
| 113 |
+
return module(*inputs)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class Conv3d(nn.Conv3d):
|
| 117 |
+
"""
|
| 118 |
+
Perform Conv3d on patches with numerical differences from nn.Conv3d within 1e-5.
|
| 119 |
+
Only symmetric padding is supported.
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
def forward(self, input):
|
| 123 |
+
B, C, T, H, W = input.shape
|
| 124 |
+
memory_count = (C * T * H * W) * 2 / 1024**3
|
| 125 |
+
if memory_count > 2:
|
| 126 |
+
n_split = math.ceil(memory_count / 2)
|
| 127 |
+
assert n_split >= 2
|
| 128 |
+
chunks = torch.chunk(input, chunks=n_split, dim=-3)
|
| 129 |
+
padded_chunks = []
|
| 130 |
+
for i in range(len(chunks)):
|
| 131 |
+
if self.padding[0] > 0:
|
| 132 |
+
padded_chunk = F.pad(
|
| 133 |
+
chunks[i],
|
| 134 |
+
(0, 0, 0, 0, self.padding[0], self.padding[0]),
|
| 135 |
+
mode="constant" if self.padding_mode == "zeros" else self.padding_mode,
|
| 136 |
+
value=0,
|
| 137 |
+
)
|
| 138 |
+
if i > 0:
|
| 139 |
+
padded_chunk[:, :, :self.padding[0]] = chunks[i - 1][:, :, -self.padding[0]:]
|
| 140 |
+
if i < len(chunks) - 1:
|
| 141 |
+
padded_chunk[:, :, -self.padding[0]:] = chunks[i + 1][:, :, :self.padding[0]]
|
| 142 |
+
else:
|
| 143 |
+
padded_chunk = chunks[i]
|
| 144 |
+
padded_chunks.append(padded_chunk)
|
| 145 |
+
padding_bak = self.padding
|
| 146 |
+
self.padding = (0, self.padding[1], self.padding[2])
|
| 147 |
+
outputs = []
|
| 148 |
+
for i in range(len(padded_chunks)):
|
| 149 |
+
outputs.append(super().forward(padded_chunks[i]))
|
| 150 |
+
self.padding = padding_bak
|
| 151 |
+
return torch.cat(outputs, dim=-3)
|
| 152 |
+
else:
|
| 153 |
+
return super().forward(input)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class AttnBlock(nn.Module):
|
| 157 |
+
""" Attention with torch sdpa implementation. """
|
| 158 |
+
def __init__(self, in_channels: int):
|
| 159 |
+
super().__init__()
|
| 160 |
+
self.in_channels = in_channels
|
| 161 |
+
|
| 162 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 163 |
+
|
| 164 |
+
self.q = Conv3d(in_channels, in_channels, kernel_size=1)
|
| 165 |
+
self.k = Conv3d(in_channels, in_channels, kernel_size=1)
|
| 166 |
+
self.v = Conv3d(in_channels, in_channels, kernel_size=1)
|
| 167 |
+
self.proj_out = Conv3d(in_channels, in_channels, kernel_size=1)
|
| 168 |
+
|
| 169 |
+
def attention(self, h_: Tensor) -> Tensor:
|
| 170 |
+
h_ = self.norm(h_)
|
| 171 |
+
q = self.q(h_)
|
| 172 |
+
k = self.k(h_)
|
| 173 |
+
v = self.v(h_)
|
| 174 |
+
|
| 175 |
+
b, c, f, h, w = q.shape
|
| 176 |
+
q = rearrange(q, "b c f h w -> b 1 (f h w) c").contiguous()
|
| 177 |
+
k = rearrange(k, "b c f h w -> b 1 (f h w) c").contiguous()
|
| 178 |
+
v = rearrange(v, "b c f h w -> b 1 (f h w) c").contiguous()
|
| 179 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
| 180 |
+
|
| 181 |
+
return rearrange(h_, "b 1 (f h w) c -> b c f h w", f=f, h=h, w=w, c=c, b=b)
|
| 182 |
+
|
| 183 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 184 |
+
return x + self.proj_out(self.attention(x))
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class ResnetBlock(nn.Module):
|
| 188 |
+
def __init__(self, in_channels: int, out_channels: int):
|
| 189 |
+
super().__init__()
|
| 190 |
+
self.in_channels = in_channels
|
| 191 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 192 |
+
self.out_channels = out_channels
|
| 193 |
+
|
| 194 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 195 |
+
self.conv1 = Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 196 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
| 197 |
+
self.conv2 = Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 198 |
+
if self.in_channels != self.out_channels:
|
| 199 |
+
self.nin_shortcut = Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 200 |
+
|
| 201 |
+
def forward(self, x):
|
| 202 |
+
h = x
|
| 203 |
+
h = self.norm1(h)
|
| 204 |
+
h = swish(h)
|
| 205 |
+
h = self.conv1(h)
|
| 206 |
+
|
| 207 |
+
h = self.norm2(h)
|
| 208 |
+
h = swish(h)
|
| 209 |
+
h = self.conv2(h)
|
| 210 |
+
|
| 211 |
+
if self.in_channels != self.out_channels:
|
| 212 |
+
x = self.nin_shortcut(x)
|
| 213 |
+
return x + h
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class Downsample(nn.Module):
|
| 217 |
+
def __init__(self, in_channels: int, add_temporal_downsample: bool = True):
|
| 218 |
+
super().__init__()
|
| 219 |
+
self.add_temporal_downsample = add_temporal_downsample
|
| 220 |
+
stride = (2, 2, 2) if add_temporal_downsample else (1, 2, 2) # THW
|
| 221 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 222 |
+
self.conv = Conv3d(in_channels, in_channels, kernel_size=3, stride=stride, padding=0)
|
| 223 |
+
|
| 224 |
+
def forward(self, x: Tensor):
|
| 225 |
+
spatial_pad = (0, 1, 0, 1, 0, 0) # WHT
|
| 226 |
+
x = nn.functional.pad(x, spatial_pad, mode="constant", value=0)
|
| 227 |
+
|
| 228 |
+
temporal_pad = (0, 0, 0, 0, 0, 1) if self.add_temporal_downsample else (0, 0, 0, 0, 1, 1)
|
| 229 |
+
x = nn.functional.pad(x, temporal_pad, mode="replicate")
|
| 230 |
+
|
| 231 |
+
x = self.conv(x)
|
| 232 |
+
return x
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class DownsampleDCAE(nn.Module):
|
| 236 |
+
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
|
| 237 |
+
super().__init__()
|
| 238 |
+
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
|
| 239 |
+
assert out_channels % factor == 0
|
| 240 |
+
self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
|
| 241 |
+
|
| 242 |
+
self.add_temporal_downsample = add_temporal_downsample
|
| 243 |
+
self.group_size = factor * in_channels // out_channels
|
| 244 |
+
|
| 245 |
+
def forward(self, x: Tensor):
|
| 246 |
+
r1 = 2 if self.add_temporal_downsample else 1
|
| 247 |
+
h = self.conv(x)
|
| 248 |
+
h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
|
| 249 |
+
shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
|
| 250 |
+
|
| 251 |
+
B, C, T, H, W = shortcut.shape
|
| 252 |
+
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
|
| 253 |
+
return h + shortcut
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class Upsample(nn.Module):
|
| 257 |
+
def __init__(self, in_channels: int, add_temporal_upsample: bool = True):
|
| 258 |
+
super().__init__()
|
| 259 |
+
self.add_temporal_upsample = add_temporal_upsample
|
| 260 |
+
self.scale_factor = (2, 2, 2) if add_temporal_upsample else (1, 2, 2) # THW
|
| 261 |
+
self.conv = Conv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
| 262 |
+
|
| 263 |
+
def forward(self, x: Tensor):
|
| 264 |
+
x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
|
| 265 |
+
x = self.conv(x)
|
| 266 |
+
return x
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class UpsampleDCAE(nn.Module):
|
| 270 |
+
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
|
| 271 |
+
super().__init__()
|
| 272 |
+
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
|
| 273 |
+
self.conv = Conv3d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
|
| 274 |
+
|
| 275 |
+
self.add_temporal_upsample = add_temporal_upsample
|
| 276 |
+
self.repeats = factor * out_channels // in_channels
|
| 277 |
+
|
| 278 |
+
def forward(self, x: Tensor):
|
| 279 |
+
r1 = 2 if self.add_temporal_upsample else 1
|
| 280 |
+
h = self.conv(x)
|
| 281 |
+
h = rearrange(h, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
|
| 282 |
+
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
|
| 283 |
+
shortcut = rearrange(shortcut, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
|
| 284 |
+
return h + shortcut
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class Encoder(nn.Module):
|
| 288 |
+
"""
|
| 289 |
+
The encoder network of AutoencoderKLConv3D.
|
| 290 |
+
"""
|
| 291 |
+
def __init__(
|
| 292 |
+
self,
|
| 293 |
+
in_channels: int,
|
| 294 |
+
z_channels: int,
|
| 295 |
+
block_out_channels: Tuple[int, ...],
|
| 296 |
+
num_res_blocks: int,
|
| 297 |
+
ffactor_spatial: int,
|
| 298 |
+
ffactor_temporal: int,
|
| 299 |
+
downsample_match_channel: bool = True,
|
| 300 |
+
):
|
| 301 |
+
super().__init__()
|
| 302 |
+
assert block_out_channels[-1] % (2 * z_channels) == 0
|
| 303 |
+
|
| 304 |
+
self.z_channels = z_channels
|
| 305 |
+
self.block_out_channels = block_out_channels
|
| 306 |
+
self.num_res_blocks = num_res_blocks
|
| 307 |
+
|
| 308 |
+
# downsampling
|
| 309 |
+
self.conv_in = Conv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
| 310 |
+
|
| 311 |
+
self.down = nn.ModuleList()
|
| 312 |
+
block_in = block_out_channels[0]
|
| 313 |
+
for i_level, ch in enumerate(block_out_channels):
|
| 314 |
+
block = nn.ModuleList()
|
| 315 |
+
block_out = ch
|
| 316 |
+
for _ in range(self.num_res_blocks):
|
| 317 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 318 |
+
block_in = block_out
|
| 319 |
+
down = nn.Module()
|
| 320 |
+
down.block = block
|
| 321 |
+
|
| 322 |
+
add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial))
|
| 323 |
+
add_temporal_downsample = (add_spatial_downsample and
|
| 324 |
+
bool(i_level >= np.log2(ffactor_spatial // ffactor_temporal)))
|
| 325 |
+
if add_spatial_downsample or add_temporal_downsample:
|
| 326 |
+
assert i_level < len(block_out_channels) - 1
|
| 327 |
+
block_out = block_out_channels[i_level + 1] if downsample_match_channel else block_in
|
| 328 |
+
down.downsample = DownsampleDCAE(block_in, block_out, add_temporal_downsample)
|
| 329 |
+
block_in = block_out
|
| 330 |
+
self.down.append(down)
|
| 331 |
+
|
| 332 |
+
# middle
|
| 333 |
+
self.mid = nn.Module()
|
| 334 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 335 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 336 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 337 |
+
|
| 338 |
+
# end
|
| 339 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 340 |
+
self.conv_out = Conv3d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
| 341 |
+
|
| 342 |
+
self.gradient_checkpointing = False
|
| 343 |
+
|
| 344 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 345 |
+
use_checkpointing = bool(self.training and self.gradient_checkpointing)
|
| 346 |
+
|
| 347 |
+
# downsampling
|
| 348 |
+
h = self.conv_in(x)
|
| 349 |
+
for i_level in range(len(self.block_out_channels)):
|
| 350 |
+
for i_block in range(self.num_res_blocks):
|
| 351 |
+
h = forward_with_checkpointing(
|
| 352 |
+
self.down[i_level].block[i_block], h, use_checkpointing=use_checkpointing)
|
| 353 |
+
if hasattr(self.down[i_level], "downsample"):
|
| 354 |
+
h = forward_with_checkpointing(self.down[i_level].downsample, h, use_checkpointing=use_checkpointing)
|
| 355 |
+
|
| 356 |
+
# middle
|
| 357 |
+
h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
|
| 358 |
+
h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
|
| 359 |
+
h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
|
| 360 |
+
|
| 361 |
+
# end
|
| 362 |
+
group_size = self.block_out_channels[-1] // (2 * self.z_channels)
|
| 363 |
+
shortcut = rearrange(h, "b (c r) f h w -> b c r f h w", r=group_size).mean(dim=2)
|
| 364 |
+
h = self.norm_out(h)
|
| 365 |
+
h = swish(h)
|
| 366 |
+
h = self.conv_out(h)
|
| 367 |
+
h += shortcut
|
| 368 |
+
return h
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
class Decoder(nn.Module):
|
| 372 |
+
"""
|
| 373 |
+
The decoder network of AutoencoderKLConv3D.
|
| 374 |
+
"""
|
| 375 |
+
def __init__(
|
| 376 |
+
self,
|
| 377 |
+
z_channels: int,
|
| 378 |
+
out_channels: int,
|
| 379 |
+
block_out_channels: Tuple[int, ...],
|
| 380 |
+
num_res_blocks: int,
|
| 381 |
+
ffactor_spatial: int,
|
| 382 |
+
ffactor_temporal: int,
|
| 383 |
+
upsample_match_channel: bool = True,
|
| 384 |
+
):
|
| 385 |
+
super().__init__()
|
| 386 |
+
assert block_out_channels[0] % z_channels == 0
|
| 387 |
+
|
| 388 |
+
self.z_channels = z_channels
|
| 389 |
+
self.block_out_channels = block_out_channels
|
| 390 |
+
self.num_res_blocks = num_res_blocks
|
| 391 |
+
|
| 392 |
+
# z to block_in
|
| 393 |
+
block_in = block_out_channels[0]
|
| 394 |
+
self.conv_in = Conv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 395 |
+
|
| 396 |
+
# middle
|
| 397 |
+
self.mid = nn.Module()
|
| 398 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 399 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 400 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 401 |
+
|
| 402 |
+
# upsampling
|
| 403 |
+
self.up = nn.ModuleList()
|
| 404 |
+
for i_level, ch in enumerate(block_out_channels):
|
| 405 |
+
block = nn.ModuleList()
|
| 406 |
+
block_out = ch
|
| 407 |
+
for _ in range(self.num_res_blocks + 1):
|
| 408 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 409 |
+
block_in = block_out
|
| 410 |
+
up = nn.Module()
|
| 411 |
+
up.block = block
|
| 412 |
+
|
| 413 |
+
add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial))
|
| 414 |
+
add_temporal_upsample = bool(i_level < np.log2(ffactor_temporal))
|
| 415 |
+
if add_spatial_upsample or add_temporal_upsample:
|
| 416 |
+
assert i_level < len(block_out_channels) - 1
|
| 417 |
+
block_out = block_out_channels[i_level + 1] if upsample_match_channel else block_in
|
| 418 |
+
up.upsample = UpsampleDCAE(block_in, block_out, add_temporal_upsample)
|
| 419 |
+
block_in = block_out
|
| 420 |
+
self.up.append(up)
|
| 421 |
+
|
| 422 |
+
# end
|
| 423 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 424 |
+
self.conv_out = Conv3d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
|
| 425 |
+
|
| 426 |
+
self.gradient_checkpointing = False
|
| 427 |
+
|
| 428 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 429 |
+
use_checkpointing = bool(self.training and self.gradient_checkpointing)
|
| 430 |
+
|
| 431 |
+
# z to block_in
|
| 432 |
+
repeats = self.block_out_channels[0] // (self.z_channels)
|
| 433 |
+
h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
|
| 434 |
+
|
| 435 |
+
# middle
|
| 436 |
+
h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
|
| 437 |
+
h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
|
| 438 |
+
h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
|
| 439 |
+
|
| 440 |
+
# upsampling
|
| 441 |
+
for i_level in range(len(self.block_out_channels)):
|
| 442 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 443 |
+
h = forward_with_checkpointing(self.up[i_level].block[i_block], h, use_checkpointing=use_checkpointing)
|
| 444 |
+
if hasattr(self.up[i_level], "upsample"):
|
| 445 |
+
h = forward_with_checkpointing(self.up[i_level].upsample, h, use_checkpointing=use_checkpointing)
|
| 446 |
+
|
| 447 |
+
# end
|
| 448 |
+
h = self.norm_out(h)
|
| 449 |
+
h = swish(h)
|
| 450 |
+
h = self.conv_out(h)
|
| 451 |
+
return h
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
class AutoencoderKLConv3D(ModelMixin, ConfigMixin):
|
| 455 |
+
"""
|
| 456 |
+
Autoencoder model with KL-regularized latent space based on 3D convolutions.
|
| 457 |
+
"""
|
| 458 |
+
_supports_gradient_checkpointing = True
|
| 459 |
+
|
| 460 |
+
@register_to_config
|
| 461 |
+
def __init__(
|
| 462 |
+
self,
|
| 463 |
+
in_channels: int,
|
| 464 |
+
out_channels: int,
|
| 465 |
+
latent_channels: int,
|
| 466 |
+
block_out_channels: Tuple[int, ...],
|
| 467 |
+
layers_per_block: int,
|
| 468 |
+
ffactor_spatial: int,
|
| 469 |
+
ffactor_temporal: int,
|
| 470 |
+
sample_size: int,
|
| 471 |
+
sample_tsize: int,
|
| 472 |
+
scaling_factor: float = None,
|
| 473 |
+
shift_factor: Optional[float] = None,
|
| 474 |
+
downsample_match_channel: bool = True,
|
| 475 |
+
upsample_match_channel: bool = True,
|
| 476 |
+
only_encoder: bool = False, # only build encoder for saving memory
|
| 477 |
+
only_decoder: bool = False, # only build decoder for saving memory
|
| 478 |
+
):
|
| 479 |
+
super().__init__()
|
| 480 |
+
self.ffactor_spatial = ffactor_spatial
|
| 481 |
+
self.ffactor_temporal = ffactor_temporal
|
| 482 |
+
self.scaling_factor = scaling_factor
|
| 483 |
+
self.shift_factor = shift_factor
|
| 484 |
+
|
| 485 |
+
# build model
|
| 486 |
+
if not only_decoder:
|
| 487 |
+
self.encoder = Encoder(
|
| 488 |
+
in_channels=in_channels,
|
| 489 |
+
z_channels=latent_channels,
|
| 490 |
+
block_out_channels=block_out_channels,
|
| 491 |
+
num_res_blocks=layers_per_block,
|
| 492 |
+
ffactor_spatial=ffactor_spatial,
|
| 493 |
+
ffactor_temporal=ffactor_temporal,
|
| 494 |
+
downsample_match_channel=downsample_match_channel,
|
| 495 |
+
)
|
| 496 |
+
if not only_encoder:
|
| 497 |
+
self.decoder = Decoder(
|
| 498 |
+
z_channels=latent_channels,
|
| 499 |
+
out_channels=out_channels,
|
| 500 |
+
block_out_channels=list(reversed(block_out_channels)),
|
| 501 |
+
num_res_blocks=layers_per_block,
|
| 502 |
+
ffactor_spatial=ffactor_spatial,
|
| 503 |
+
ffactor_temporal=ffactor_temporal,
|
| 504 |
+
upsample_match_channel=upsample_match_channel,
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
# slicing and tiling related
|
| 508 |
+
self.use_slicing = False
|
| 509 |
+
self.slicing_bsz = 1
|
| 510 |
+
self.use_spatial_tiling = False
|
| 511 |
+
self.use_temporal_tiling = False
|
| 512 |
+
self.use_tiling_during_training = False
|
| 513 |
+
|
| 514 |
+
# only relevant if vae tiling is enabled
|
| 515 |
+
self.tile_sample_min_size = sample_size
|
| 516 |
+
self.tile_latent_min_size = sample_size // ffactor_spatial
|
| 517 |
+
self.tile_sample_min_tsize = sample_tsize
|
| 518 |
+
self.tile_latent_min_tsize = sample_tsize // ffactor_temporal
|
| 519 |
+
self.tile_overlap_factor = 0.25
|
| 520 |
+
|
| 521 |
+
# use torch.compile for faster encode speed
|
| 522 |
+
self.use_compile = False
|
| 523 |
+
|
| 524 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 525 |
+
if isinstance(module, (Encoder, Decoder)):
|
| 526 |
+
module.gradient_checkpointing = value
|
| 527 |
+
|
| 528 |
+
def enable_tiling_during_training(self, use_tiling: bool = True):
|
| 529 |
+
self.use_tiling_during_training = use_tiling
|
| 530 |
+
|
| 531 |
+
def disable_tiling_during_training(self):
|
| 532 |
+
self.enable_tiling_during_training(False)
|
| 533 |
+
|
| 534 |
+
def enable_temporal_tiling(self, use_tiling: bool = True):
|
| 535 |
+
self.use_temporal_tiling = use_tiling
|
| 536 |
+
|
| 537 |
+
def disable_temporal_tiling(self):
|
| 538 |
+
self.enable_temporal_tiling(False)
|
| 539 |
+
|
| 540 |
+
def enable_spatial_tiling(self, use_tiling: bool = True):
|
| 541 |
+
self.use_spatial_tiling = use_tiling
|
| 542 |
+
|
| 543 |
+
def disable_spatial_tiling(self):
|
| 544 |
+
self.enable_spatial_tiling(False)
|
| 545 |
+
|
| 546 |
+
def enable_tiling(self, use_tiling: bool = True):
|
| 547 |
+
self.enable_spatial_tiling(use_tiling)
|
| 548 |
+
|
| 549 |
+
def disable_tiling(self):
|
| 550 |
+
self.disable_spatial_tiling()
|
| 551 |
+
|
| 552 |
+
def enable_slicing(self):
|
| 553 |
+
self.use_slicing = True
|
| 554 |
+
|
| 555 |
+
def disable_slicing(self):
|
| 556 |
+
self.use_slicing = False
|
| 557 |
+
|
| 558 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
|
| 559 |
+
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
| 560 |
+
for x in range(blend_extent):
|
| 561 |
+
b[:, :, :, :, x] = \
|
| 562 |
+
a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
|
| 563 |
+
return b
|
| 564 |
+
|
| 565 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
|
| 566 |
+
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
| 567 |
+
for y in range(blend_extent):
|
| 568 |
+
b[:, :, :, y, :] = \
|
| 569 |
+
a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
|
| 570 |
+
return b
|
| 571 |
+
|
| 572 |
+
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
|
| 573 |
+
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
|
| 574 |
+
for x in range(blend_extent):
|
| 575 |
+
b[:, :, x, :, :] = \
|
| 576 |
+
a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
|
| 577 |
+
return b
|
| 578 |
+
|
| 579 |
+
def spatial_tiled_encode(self, x: torch.Tensor):
|
| 580 |
+
""" spatial tailing for frames """
|
| 581 |
+
B, C, T, H, W = x.shape
|
| 582 |
+
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
|
| 583 |
+
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) # 8 * 0.25 = 2
|
| 584 |
+
row_limit = self.tile_latent_min_size - blend_extent # 8 - 2 = 6
|
| 585 |
+
|
| 586 |
+
rows = []
|
| 587 |
+
for i in range(0, H, overlap_size):
|
| 588 |
+
row = []
|
| 589 |
+
for j in range(0, W, overlap_size):
|
| 590 |
+
tile = x[:, :, :, i: i + self.tile_sample_min_size, j: j + self.tile_sample_min_size]
|
| 591 |
+
tile = self.encoder(tile)
|
| 592 |
+
row.append(tile)
|
| 593 |
+
rows.append(row)
|
| 594 |
+
result_rows = []
|
| 595 |
+
for i, row in enumerate(rows):
|
| 596 |
+
result_row = []
|
| 597 |
+
for j, tile in enumerate(row):
|
| 598 |
+
if i > 0:
|
| 599 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
| 600 |
+
if j > 0:
|
| 601 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
| 602 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
| 603 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
| 604 |
+
moments = torch.cat(result_rows, dim=-2)
|
| 605 |
+
return moments
|
| 606 |
+
|
| 607 |
+
def temporal_tiled_encode(self, x: torch.Tensor):
|
| 608 |
+
""" temporal tailing for frames """
|
| 609 |
+
B, C, T, H, W = x.shape
|
| 610 |
+
overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor)) # 64 * (1 - 0.25) = 48
|
| 611 |
+
blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor) # 8 * 0.25 = 2
|
| 612 |
+
t_limit = self.tile_latent_min_tsize - blend_extent # 8 - 2 = 6
|
| 613 |
+
|
| 614 |
+
row = []
|
| 615 |
+
for i in range(0, T, overlap_size):
|
| 616 |
+
tile = x[:, :, i: i + self.tile_sample_min_tsize, :, :]
|
| 617 |
+
if self.use_spatial_tiling and (
|
| 618 |
+
tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
|
| 619 |
+
tile = self.spatial_tiled_encode(tile)
|
| 620 |
+
else:
|
| 621 |
+
tile = self.encoder(tile)
|
| 622 |
+
row.append(tile)
|
| 623 |
+
result_row = []
|
| 624 |
+
for i, tile in enumerate(row):
|
| 625 |
+
if i > 0:
|
| 626 |
+
tile = self.blend_t(row[i - 1], tile, blend_extent)
|
| 627 |
+
result_row.append(tile[:, :, :t_limit, :, :])
|
| 628 |
+
moments = torch.cat(result_row, dim=-3)
|
| 629 |
+
return moments
|
| 630 |
+
|
| 631 |
+
def spatial_tiled_decode(self, z: torch.Tensor):
|
| 632 |
+
""" spatial tailing for frames """
|
| 633 |
+
B, C, T, H, W = z.shape
|
| 634 |
+
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
|
| 635 |
+
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) # 256 * 0.25 = 64
|
| 636 |
+
row_limit = self.tile_sample_min_size - blend_extent # 256 - 64 = 192
|
| 637 |
+
|
| 638 |
+
rows = []
|
| 639 |
+
for i in range(0, H, overlap_size):
|
| 640 |
+
row = []
|
| 641 |
+
for j in range(0, W, overlap_size):
|
| 642 |
+
tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size]
|
| 643 |
+
decoded = self.decoder(tile)
|
| 644 |
+
row.append(decoded)
|
| 645 |
+
rows.append(row)
|
| 646 |
+
|
| 647 |
+
result_rows = []
|
| 648 |
+
for i, row in enumerate(rows):
|
| 649 |
+
result_row = []
|
| 650 |
+
for j, tile in enumerate(row):
|
| 651 |
+
if i > 0:
|
| 652 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
| 653 |
+
if j > 0:
|
| 654 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
| 655 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
| 656 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
| 657 |
+
dec = torch.cat(result_rows, dim=-2)
|
| 658 |
+
return dec
|
| 659 |
+
|
| 660 |
+
def temporal_tiled_decode(self, z: torch.Tensor):
|
| 661 |
+
""" temporal tailing for frames """
|
| 662 |
+
B, C, T, H, W = z.shape
|
| 663 |
+
overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
|
| 664 |
+
blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) # 64 * 0.25 = 16
|
| 665 |
+
t_limit = self.tile_sample_min_tsize - blend_extent # 64 - 16 = 48
|
| 666 |
+
assert 0 < overlap_size < self.tile_latent_min_tsize
|
| 667 |
+
|
| 668 |
+
row = []
|
| 669 |
+
for i in range(0, T, overlap_size):
|
| 670 |
+
tile = z[:, :, i: i + self.tile_latent_min_tsize, :, :]
|
| 671 |
+
if self.use_spatial_tiling and (
|
| 672 |
+
tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
|
| 673 |
+
decoded = self.spatial_tiled_decode(tile)
|
| 674 |
+
else:
|
| 675 |
+
decoded = self.decoder(tile)
|
| 676 |
+
row.append(decoded)
|
| 677 |
+
|
| 678 |
+
result_row = []
|
| 679 |
+
for i, tile in enumerate(row):
|
| 680 |
+
if i > 0:
|
| 681 |
+
tile = self.blend_t(row[i - 1], tile, blend_extent)
|
| 682 |
+
result_row.append(tile[:, :, :t_limit, :, :])
|
| 683 |
+
dec = torch.cat(result_row, dim=-3)
|
| 684 |
+
return dec
|
| 685 |
+
|
| 686 |
+
def encode(self, x: Tensor, return_dict: bool = True):
|
| 687 |
+
"""
|
| 688 |
+
Encodes the input by passing through the encoder network.
|
| 689 |
+
Support slicing and tiling for memory efficiency.
|
| 690 |
+
"""
|
| 691 |
+
def _encode(x):
|
| 692 |
+
if self.use_temporal_tiling and x.shape[-3] > self.tile_sample_min_tsize:
|
| 693 |
+
return self.temporal_tiled_encode(x)
|
| 694 |
+
if self.use_spatial_tiling and (
|
| 695 |
+
x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
| 696 |
+
return self.spatial_tiled_encode(x)
|
| 697 |
+
|
| 698 |
+
if self.use_compile:
|
| 699 |
+
@torch.compile
|
| 700 |
+
def encoder(x):
|
| 701 |
+
return self.encoder(x)
|
| 702 |
+
return encoder(x)
|
| 703 |
+
return self.encoder(x)
|
| 704 |
+
|
| 705 |
+
if len(x.shape) != 5: # (B, C, T, H, W)
|
| 706 |
+
x = x[:, :, None]
|
| 707 |
+
assert len(x.shape) == 5 # (B, C, T, H, W)
|
| 708 |
+
if x.shape[2] == 1:
|
| 709 |
+
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
|
| 710 |
+
else:
|
| 711 |
+
assert x.shape[2] != self.ffactor_temporal and x.shape[2] % self.ffactor_temporal == 0
|
| 712 |
+
|
| 713 |
+
if self.use_slicing and x.shape[0] > 1:
|
| 714 |
+
if self.slicing_bsz == 1:
|
| 715 |
+
encoded_slices = [_encode(x_slice) for x_slice in x.split(1)]
|
| 716 |
+
else:
|
| 717 |
+
sections = [self.slicing_bsz] * (x.shape[0] // self.slicing_bsz)
|
| 718 |
+
if x.shape[0] % self.slicing_bsz != 0:
|
| 719 |
+
sections.append(x.shape[0] % self.slicing_bsz)
|
| 720 |
+
encoded_slices = [_encode(x_slice) for x_slice in x.split(sections)]
|
| 721 |
+
h = torch.cat(encoded_slices)
|
| 722 |
+
else:
|
| 723 |
+
h = _encode(x)
|
| 724 |
+
posterior = DiagonalGaussianDistribution(h)
|
| 725 |
+
|
| 726 |
+
if not return_dict:
|
| 727 |
+
return (posterior,)
|
| 728 |
+
|
| 729 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
| 730 |
+
|
| 731 |
+
def decode(self, z: Tensor, return_dict: bool = True, generator=None):
|
| 732 |
+
"""
|
| 733 |
+
Decodes the input by passing through the decoder network.
|
| 734 |
+
Support slicing and tiling for memory efficiency.
|
| 735 |
+
"""
|
| 736 |
+
def _decode(z):
|
| 737 |
+
if self.use_temporal_tiling and z.shape[-3] > self.tile_latent_min_tsize:
|
| 738 |
+
return self.temporal_tiled_decode(z)
|
| 739 |
+
if self.use_spatial_tiling and (
|
| 740 |
+
z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
| 741 |
+
return self.spatial_tiled_decode(z)
|
| 742 |
+
return self.decoder(z)
|
| 743 |
+
|
| 744 |
+
if self.use_slicing and z.shape[0] > 1:
|
| 745 |
+
decoded_slices = [_decode(z_slice) for z_slice in z.split(1)]
|
| 746 |
+
decoded = torch.cat(decoded_slices)
|
| 747 |
+
else:
|
| 748 |
+
decoded = _decode(z)
|
| 749 |
+
|
| 750 |
+
if z.shape[-3] == 1:
|
| 751 |
+
decoded = decoded[:, :, -1:]
|
| 752 |
+
|
| 753 |
+
if not return_dict:
|
| 754 |
+
return (decoded,)
|
| 755 |
+
|
| 756 |
+
return DecoderOutput(sample=decoded)
|
| 757 |
+
|
| 758 |
+
def forward(
|
| 759 |
+
self,
|
| 760 |
+
sample: torch.Tensor,
|
| 761 |
+
sample_posterior: bool = False,
|
| 762 |
+
return_posterior: bool = True,
|
| 763 |
+
return_dict: bool = True
|
| 764 |
+
):
|
| 765 |
+
posterior = self.encode(sample).latent_dist
|
| 766 |
+
z = posterior.sample() if sample_posterior else posterior.mode()
|
| 767 |
+
dec = self.decode(z).sample
|
| 768 |
+
return DecoderOutput(sample=dec, posterior=posterior) if return_dict else (dec, posterior)
|
| 769 |
+
|
| 770 |
+
def random_reset_tiling(self, x: torch.Tensor):
|
| 771 |
+
if x.shape[-3] == 1:
|
| 772 |
+
self.disable_spatial_tiling()
|
| 773 |
+
self.disable_temporal_tiling()
|
| 774 |
+
return
|
| 775 |
+
|
| 776 |
+
# Use fixed shape here
|
| 777 |
+
min_sample_size = int(1 / self.tile_overlap_factor) * self.ffactor_spatial
|
| 778 |
+
min_sample_tsize = int(1 / self.tile_overlap_factor) * self.ffactor_temporal
|
| 779 |
+
sample_size = random.choice([None, 1 * min_sample_size, 2 * min_sample_size, 3 * min_sample_size])
|
| 780 |
+
if sample_size is None:
|
| 781 |
+
self.disable_spatial_tiling()
|
| 782 |
+
else:
|
| 783 |
+
self.tile_sample_min_size = sample_size
|
| 784 |
+
self.tile_latent_min_size = sample_size // self.ffactor_spatial
|
| 785 |
+
self.enable_spatial_tiling()
|
| 786 |
+
|
| 787 |
+
sample_tsize = random.choice([None, 1 * min_sample_tsize, 2 * min_sample_tsize, 3 * min_sample_tsize])
|
| 788 |
+
if sample_tsize is None:
|
| 789 |
+
self.disable_temporal_tiling()
|
| 790 |
+
else:
|
| 791 |
+
self.tile_sample_min_tsize = sample_tsize
|
| 792 |
+
self.tile_latent_min_tsize = sample_tsize // self.ffactor_temporal
|
| 793 |
+
self.enable_temporal_tiling()
|
config.json
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_classification_head": false,
|
| 3 |
+
"anyres_pooling_size": 2,
|
| 4 |
+
"anyres_vit_max_image_size": null,
|
| 5 |
+
"anyres_vit_two_views": false,
|
| 6 |
+
"architectures": [
|
| 7 |
+
"HunyuanImage3ForCausalMM"
|
| 8 |
+
],
|
| 9 |
+
"attention_bias": false,
|
| 10 |
+
"attention_dropout": 0.0,
|
| 11 |
+
"attention_head_dim": 128,
|
| 12 |
+
"auto_map": {
|
| 13 |
+
"AutoConfig": "configuration_hunyuan.HunyuanImage3Config",
|
| 14 |
+
"AutoModel": "hunyuan.HunyuanImage3Model",
|
| 15 |
+
"AutoModelForCausalLM": "hunyuan.HunyuanImage3ForCausalMM"
|
| 16 |
+
},
|
| 17 |
+
"bos_token_id": 127958,
|
| 18 |
+
"cla_share_factor": 2,
|
| 19 |
+
"class_num": 0,
|
| 20 |
+
"dense_list": [
|
| 21 |
+
4096,
|
| 22 |
+
0
|
| 23 |
+
],
|
| 24 |
+
"eod_token_id": 3,
|
| 25 |
+
"eos_token_id": 127957,
|
| 26 |
+
"group_limited_greedy": false,
|
| 27 |
+
"hidden_act": "silu",
|
| 28 |
+
"hidden_size": 4096,
|
| 29 |
+
"im_end_id": 128001,
|
| 30 |
+
"im_newline_id": 11,
|
| 31 |
+
"im_start_id": 128000,
|
| 32 |
+
"image_token_id": 128006,
|
| 33 |
+
"initializer_range": 0.02,
|
| 34 |
+
"intermediate_size": 3072,
|
| 35 |
+
"kv_lora_rank": null,
|
| 36 |
+
"mask_init_id": 12,
|
| 37 |
+
"max_position_embeddings": 12800,
|
| 38 |
+
"mlp_bias": false,
|
| 39 |
+
"model_type": "hunyuan_image_3_moe",
|
| 40 |
+
"moe_drop_tokens": false,
|
| 41 |
+
"moe_intermediate_size": [
|
| 42 |
+
3072,
|
| 43 |
+
3072,
|
| 44 |
+
3072,
|
| 45 |
+
3072,
|
| 46 |
+
3072,
|
| 47 |
+
3072,
|
| 48 |
+
3072,
|
| 49 |
+
3072,
|
| 50 |
+
3072,
|
| 51 |
+
3072,
|
| 52 |
+
3072,
|
| 53 |
+
3072,
|
| 54 |
+
3072,
|
| 55 |
+
3072,
|
| 56 |
+
3072,
|
| 57 |
+
3072,
|
| 58 |
+
3072,
|
| 59 |
+
3072,
|
| 60 |
+
3072,
|
| 61 |
+
3072,
|
| 62 |
+
3072,
|
| 63 |
+
3072,
|
| 64 |
+
3072,
|
| 65 |
+
3072,
|
| 66 |
+
3072,
|
| 67 |
+
3072,
|
| 68 |
+
3072,
|
| 69 |
+
3072,
|
| 70 |
+
3072,
|
| 71 |
+
3072,
|
| 72 |
+
3072,
|
| 73 |
+
3072
|
| 74 |
+
],
|
| 75 |
+
"moe_layer_num_skipped": 0,
|
| 76 |
+
"moe_random_routing_dropped_token": false,
|
| 77 |
+
"moe_topk": [
|
| 78 |
+
8,
|
| 79 |
+
8,
|
| 80 |
+
8,
|
| 81 |
+
8,
|
| 82 |
+
8,
|
| 83 |
+
8,
|
| 84 |
+
8,
|
| 85 |
+
8,
|
| 86 |
+
8,
|
| 87 |
+
8,
|
| 88 |
+
8,
|
| 89 |
+
8,
|
| 90 |
+
8,
|
| 91 |
+
8,
|
| 92 |
+
8,
|
| 93 |
+
8,
|
| 94 |
+
8,
|
| 95 |
+
8,
|
| 96 |
+
8,
|
| 97 |
+
8,
|
| 98 |
+
8,
|
| 99 |
+
8,
|
| 100 |
+
8,
|
| 101 |
+
8,
|
| 102 |
+
8,
|
| 103 |
+
8,
|
| 104 |
+
8,
|
| 105 |
+
8,
|
| 106 |
+
8,
|
| 107 |
+
8,
|
| 108 |
+
8,
|
| 109 |
+
8
|
| 110 |
+
],
|
| 111 |
+
"n_group": false,
|
| 112 |
+
"norm_topk_prob": true,
|
| 113 |
+
"norm_type": "rms",
|
| 114 |
+
"num_attention_heads": 32,
|
| 115 |
+
"num_experts": 64,
|
| 116 |
+
"num_hidden_layers": 32,
|
| 117 |
+
"num_key_value_heads": 8,
|
| 118 |
+
"num_media_embeds": 257,
|
| 119 |
+
"num_shared_expert": [
|
| 120 |
+
1,
|
| 121 |
+
1,
|
| 122 |
+
1,
|
| 123 |
+
1,
|
| 124 |
+
1,
|
| 125 |
+
1,
|
| 126 |
+
1,
|
| 127 |
+
1,
|
| 128 |
+
1,
|
| 129 |
+
1,
|
| 130 |
+
1,
|
| 131 |
+
1,
|
| 132 |
+
1,
|
| 133 |
+
1,
|
| 134 |
+
1,
|
| 135 |
+
1,
|
| 136 |
+
1,
|
| 137 |
+
1,
|
| 138 |
+
1,
|
| 139 |
+
1,
|
| 140 |
+
1,
|
| 141 |
+
1,
|
| 142 |
+
1,
|
| 143 |
+
1,
|
| 144 |
+
1,
|
| 145 |
+
1,
|
| 146 |
+
1,
|
| 147 |
+
1,
|
| 148 |
+
1,
|
| 149 |
+
1,
|
| 150 |
+
1,
|
| 151 |
+
1
|
| 152 |
+
],
|
| 153 |
+
"pad_id": 128009,
|
| 154 |
+
"pad_token_id": 128009,
|
| 155 |
+
"pool_type": "last",
|
| 156 |
+
"position_embedding_xdrope": false,
|
| 157 |
+
"pretraining_tp": 1,
|
| 158 |
+
"q_lora_rank": null,
|
| 159 |
+
"qk_nope_head_dim": null,
|
| 160 |
+
"qk_rope_head_dim": null,
|
| 161 |
+
"rms_norm_eps": 1e-05,
|
| 162 |
+
"rope_scaling": {
|
| 163 |
+
"alpha": 1.0,
|
| 164 |
+
"beta_fast": 32,
|
| 165 |
+
"beta_slow": 1,
|
| 166 |
+
"factor": 1.0,
|
| 167 |
+
"mscale": 1.0,
|
| 168 |
+
"mscale_all_dim": 1.0,
|
| 169 |
+
"type": "custom"
|
| 170 |
+
},
|
| 171 |
+
"rope_theta": 10000.0,
|
| 172 |
+
"routed_scaling_factor": false,
|
| 173 |
+
"skip_cls_token": false,
|
| 174 |
+
"text_end_id": 7,
|
| 175 |
+
"text_start_id": 6,
|
| 176 |
+
"tie_word_embeddings": false,
|
| 177 |
+
"topk_group": false,
|
| 178 |
+
"torch_dtype": "bfloat16",
|
| 179 |
+
"transformers_version": "4.50.0",
|
| 180 |
+
"use_cache": true,
|
| 181 |
+
"use_cla": false,
|
| 182 |
+
"use_mixed_mlp_moe": true,
|
| 183 |
+
"use_mla": false,
|
| 184 |
+
"use_qk_norm": true,
|
| 185 |
+
"use_rotary_pos_emb": true,
|
| 186 |
+
"v_head_dim": null,
|
| 187 |
+
"video_end_id": 10,
|
| 188 |
+
"video_start_id": 9,
|
| 189 |
+
"vit_add_patchemb_bias": false,
|
| 190 |
+
"vit_input_resolution": 224,
|
| 191 |
+
"vit_mapping_type": "resampler",
|
| 192 |
+
"vit_norm_type": "fused",
|
| 193 |
+
"vit_patch": 1,
|
| 194 |
+
"vit_path": null,
|
| 195 |
+
"vit_remove_prenorm": false,
|
| 196 |
+
"vit_token": 64,
|
| 197 |
+
"vit_type": null,
|
| 198 |
+
"vit_used_rms_norm": false,
|
| 199 |
+
"vocab_size": 133120,
|
| 200 |
+
"xdrope_section": null,
|
| 201 |
+
"head_dim": 128,
|
| 202 |
+
"vae_downsample_factor": [
|
| 203 |
+
16,
|
| 204 |
+
16
|
| 205 |
+
],
|
| 206 |
+
"vae": {
|
| 207 |
+
"_class_name": "AutoencoderKLConv3D",
|
| 208 |
+
"block_out_channels": [
|
| 209 |
+
128,
|
| 210 |
+
256,
|
| 211 |
+
512,
|
| 212 |
+
1024,
|
| 213 |
+
1024
|
| 214 |
+
],
|
| 215 |
+
"in_channels": 3,
|
| 216 |
+
"out_channels": 3,
|
| 217 |
+
"latent_channels": 32,
|
| 218 |
+
"layers_per_block": 2,
|
| 219 |
+
"ffactor_spatial": 16,
|
| 220 |
+
"ffactor_temporal": 4,
|
| 221 |
+
"sample_size": 384,
|
| 222 |
+
"sample_tsize": 96,
|
| 223 |
+
"downsample_match_channel": true,
|
| 224 |
+
"upsample_match_channel": true,
|
| 225 |
+
"scaling_factor": 0.562679178327931
|
| 226 |
+
},
|
| 227 |
+
"vit": {
|
| 228 |
+
"_attn_implementation": "sdpa",
|
| 229 |
+
"attention_dropout": 0.0,
|
| 230 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 231 |
+
"hidden_size": 1152,
|
| 232 |
+
"intermediate_size": 4304,
|
| 233 |
+
"layer_norm_eps": 1e-06,
|
| 234 |
+
"num_attention_heads": 16,
|
| 235 |
+
"num_channels": 3,
|
| 236 |
+
"num_hidden_layers": 27,
|
| 237 |
+
"num_patches": 256,
|
| 238 |
+
"patch_size": 16,
|
| 239 |
+
"torch_dtype": "float32",
|
| 240 |
+
"output_attentions": false,
|
| 241 |
+
"output_hidden_states": false,
|
| 242 |
+
"use_return_dict": true
|
| 243 |
+
},
|
| 244 |
+
"vit_processor": {
|
| 245 |
+
"do_convert_rgb": null,
|
| 246 |
+
"do_normalize": true,
|
| 247 |
+
"do_rescale": true,
|
| 248 |
+
"do_resize": true,
|
| 249 |
+
"image_mean": [
|
| 250 |
+
0.5,
|
| 251 |
+
0.5,
|
| 252 |
+
0.5
|
| 253 |
+
],
|
| 254 |
+
"image_processor_type": "Siglip2ImageProcessorFast",
|
| 255 |
+
"image_std": [
|
| 256 |
+
0.5,
|
| 257 |
+
0.5,
|
| 258 |
+
0.5
|
| 259 |
+
],
|
| 260 |
+
"max_num_patches": 1024,
|
| 261 |
+
"patch_size": 16,
|
| 262 |
+
"processor_class": "Siglip2Processor",
|
| 263 |
+
"resample": 2,
|
| 264 |
+
"rescale_factor": 0.00392156862745098
|
| 265 |
+
},
|
| 266 |
+
"vit_aligner": {
|
| 267 |
+
"projector_type": "mlp_gelu",
|
| 268 |
+
"input_dim": 1152,
|
| 269 |
+
"n_embed": 4096,
|
| 270 |
+
"depth": 2,
|
| 271 |
+
"torch_dtype": "float32"
|
| 272 |
+
}
|
| 273 |
+
}
|
configuration_hunyuan.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
|
| 2 |
+
# you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at
|
| 4 |
+
#
|
| 5 |
+
# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
|
| 6 |
+
#
|
| 7 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 10 |
+
# See the License for the specific language governing permissions and
|
| 11 |
+
# limitations under the License.
|
| 12 |
+
# ==============================================================================
|
| 13 |
+
|
| 14 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 15 |
+
from transformers.utils import logging
|
| 16 |
+
from typing import List, Union
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logger = logging.get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class HunyuanImage3Config(PretrainedConfig):
|
| 23 |
+
r"""
|
| 24 |
+
This is the configuration class to store the configuration of a [`HunyuanImage3Model`]. It is used to instantiate
|
| 25 |
+
an Hunyuan model according to the specified arguments, defining the model architecture. Instantiating a
|
| 26 |
+
configuration with the defaults will yield a similar configuration to that of the Hunyuan-7B.
|
| 27 |
+
|
| 28 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 29 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
| 34 |
+
Vocabulary size of the Hunyuan Image 3 model. Defines the number of different tokens that can be
|
| 35 |
+
represented by the `inputs_ids` passed when calling [`HunyuanImage3Model`]
|
| 36 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 37 |
+
Dimension of the hidden representations.
|
| 38 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
| 39 |
+
Dimension of the MLP representations or shared MLP representations.
|
| 40 |
+
moe_intermediate_size (`int` or `List`, *optional*, defaults to 11008):
|
| 41 |
+
Dimension of the MLP representations in MoE. Use a list if you want a different size per layer.
|
| 42 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 43 |
+
Number of hidden layers in the Transformer decoder.
|
| 44 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 45 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 46 |
+
num_key_value_heads (`int`, *optional*):
|
| 47 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 48 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 49 |
+
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 50 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 51 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
| 52 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
| 53 |
+
`num_attention_heads`.
|
| 54 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 55 |
+
The non-linear activation function (function or string) in the decoder.
|
| 56 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
| 57 |
+
The maximum sequence length that this model might ever be used with.
|
| 58 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 59 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 60 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 61 |
+
The epsilon used by the rms normalization layers.
|
| 62 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 63 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 64 |
+
relevant if `config.is_decoder=True`.
|
| 65 |
+
pad_token_id (`int`, *optional*):
|
| 66 |
+
Padding token id.
|
| 67 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
| 68 |
+
Beginning of stream token id.
|
| 69 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
| 70 |
+
End of stream token id.
|
| 71 |
+
pretraining_tp (`int`, *optional*, defaults to 1):
|
| 72 |
+
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
| 73 |
+
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
|
| 74 |
+
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
|
| 75 |
+
issue](https://github.com/pytorch/pytorch/issues/76232).
|
| 76 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 77 |
+
Whether to tie weight embeddings
|
| 78 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 79 |
+
The base period of the RoPE embeddings.
|
| 80 |
+
rope_scaling (`Dict`, *optional*):
|
| 81 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
| 82 |
+
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
| 83 |
+
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
| 84 |
+
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
| 85 |
+
these scaling strategies behave:
|
| 86 |
+
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
| 87 |
+
experimental feature, subject to breaking API changes in future versions.
|
| 88 |
+
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
| 89 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 90 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 91 |
+
The dropout ratio for the attention probabilities.
|
| 92 |
+
use_qk_norm (`bool`, *optional*, defaults to `False`):
|
| 93 |
+
Whether query and key in attention use norm
|
| 94 |
+
use_cla (`bool`, *optional*, defaults to `False`):
|
| 95 |
+
Whether to use CLA in attention
|
| 96 |
+
cla_share_factor (`int`, *optional*, defaults to 1):
|
| 97 |
+
The share factor of CLA
|
| 98 |
+
num_experts (`int` or `List`, *optional*, defaults to 1):
|
| 99 |
+
The number of experts for moe. If it is a list, it will be used as the number of experts for each layer.
|
| 100 |
+
num_shared_expert (`int` or `List`, *optional*, defaults to 1):
|
| 101 |
+
The number of shared experts for moe. If it is a list, it will be used as the number of shared experts
|
| 102 |
+
for each layer.
|
| 103 |
+
moe_topk (`int` or `List`, *optional*, defaults to 1):
|
| 104 |
+
The topk value for moe. If it is a list, it will be used as the topk value for each layer.
|
| 105 |
+
capacity_factor (Not used) (`float` or `List`, *optional*, defaults to 1.0):
|
| 106 |
+
The capacity factor for moe. If it is a list, it will be used as the capacity factor for each layer.
|
| 107 |
+
moe_layer_num_skipped (`int`, *optional*, defaults to 0):
|
| 108 |
+
First moe_layer_num_skipped layers do not use MoE.
|
| 109 |
+
"""
|
| 110 |
+
|
| 111 |
+
model_type = "Hunyuan"
|
| 112 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 113 |
+
|
| 114 |
+
def __init__(
|
| 115 |
+
self,
|
| 116 |
+
vocab_size=290943,
|
| 117 |
+
hidden_size=4096,
|
| 118 |
+
intermediate_size: int=11008,
|
| 119 |
+
moe_intermediate_size: Union[int, List]=None,
|
| 120 |
+
num_hidden_layers=32,
|
| 121 |
+
num_attention_heads=32,
|
| 122 |
+
num_key_value_heads=None,
|
| 123 |
+
attention_head_dim=None,
|
| 124 |
+
hidden_act="silu",
|
| 125 |
+
max_position_embeddings=2048,
|
| 126 |
+
initializer_range=0.02,
|
| 127 |
+
rms_norm_eps=1e-5,
|
| 128 |
+
use_cache=True,
|
| 129 |
+
pad_token_id=0,
|
| 130 |
+
bos_token_id=1,
|
| 131 |
+
eos_token_id=2,
|
| 132 |
+
eod_token_id=3,
|
| 133 |
+
im_start_id=4,
|
| 134 |
+
im_end_id=5,
|
| 135 |
+
text_start_id=6,
|
| 136 |
+
text_end_id=7,
|
| 137 |
+
image_token_id=8,
|
| 138 |
+
video_start_id=9,
|
| 139 |
+
video_end_id=10,
|
| 140 |
+
im_newline_id=11,
|
| 141 |
+
mask_init_id=12,
|
| 142 |
+
pretraining_tp=1,
|
| 143 |
+
tie_word_embeddings=False,
|
| 144 |
+
rope_theta=10000.0,
|
| 145 |
+
rope_scaling=None,
|
| 146 |
+
attention_bias=False,
|
| 147 |
+
mlp_bias=False,
|
| 148 |
+
attention_dropout=0.0,
|
| 149 |
+
use_qk_norm=False,
|
| 150 |
+
use_rotary_pos_emb=True,
|
| 151 |
+
use_cla=False,
|
| 152 |
+
cla_share_factor=1,
|
| 153 |
+
norm_type="hf_rms",
|
| 154 |
+
num_experts: Union[int, List] = 1,
|
| 155 |
+
use_mixed_mlp_moe=False,
|
| 156 |
+
num_shared_expert: Union[int, List] = 1,
|
| 157 |
+
moe_topk: Union[int, List] = 1,
|
| 158 |
+
capacity_factor: int = 1.0,
|
| 159 |
+
moe_drop_tokens=False,
|
| 160 |
+
moe_random_routing_dropped_token=False,
|
| 161 |
+
use_mla=False,
|
| 162 |
+
kv_lora_rank=512,
|
| 163 |
+
q_lora_rank=1536,
|
| 164 |
+
qk_rope_head_dim=64,
|
| 165 |
+
v_head_dim=128,
|
| 166 |
+
qk_nope_head_dim=128,
|
| 167 |
+
moe_layer_num_skipped=0,
|
| 168 |
+
norm_topk_prob=True,
|
| 169 |
+
routed_scaling_factor=1.0,
|
| 170 |
+
group_limited_greedy=False,
|
| 171 |
+
n_group=None,
|
| 172 |
+
topk_group=None,
|
| 173 |
+
add_classification_head=False,
|
| 174 |
+
class_num=0,
|
| 175 |
+
pool_type="last",
|
| 176 |
+
pad_id=-1,
|
| 177 |
+
# Added
|
| 178 |
+
moe_impl="eager",
|
| 179 |
+
vae_downsample_factor=(16, 16), # (h, w)
|
| 180 |
+
img_proj_type="unet",
|
| 181 |
+
patch_size=1,
|
| 182 |
+
patch_embed_hidden_dim=1024,
|
| 183 |
+
image_base_size=1024,
|
| 184 |
+
vae=None,
|
| 185 |
+
vit=None,
|
| 186 |
+
vit_processor=None,
|
| 187 |
+
vit_aligner=None,
|
| 188 |
+
**kwargs,
|
| 189 |
+
):
|
| 190 |
+
self.vocab_size = vocab_size
|
| 191 |
+
self.max_position_embeddings = max_position_embeddings
|
| 192 |
+
self.hidden_size = hidden_size
|
| 193 |
+
self.intermediate_size = intermediate_size
|
| 194 |
+
self.moe_intermediate_size = moe_intermediate_size
|
| 195 |
+
self.num_hidden_layers = num_hidden_layers
|
| 196 |
+
self.num_attention_heads = num_attention_heads
|
| 197 |
+
self.moe_impl = moe_impl
|
| 198 |
+
self.num_experts = num_experts
|
| 199 |
+
self.use_mixed_mlp_moe = use_mixed_mlp_moe
|
| 200 |
+
self.num_shared_expert = num_shared_expert
|
| 201 |
+
self.moe_topk = moe_topk
|
| 202 |
+
self.capacity_factor = capacity_factor
|
| 203 |
+
self.moe_drop_tokens = moe_drop_tokens
|
| 204 |
+
self.moe_random_routing_dropped_token = moe_random_routing_dropped_token
|
| 205 |
+
|
| 206 |
+
if attention_head_dim is not None:
|
| 207 |
+
self.attention_head_dim = attention_head_dim
|
| 208 |
+
else:
|
| 209 |
+
self.attention_head_dim = self.hidden_size // num_attention_heads
|
| 210 |
+
|
| 211 |
+
# for backward compatibility
|
| 212 |
+
if num_key_value_heads is None:
|
| 213 |
+
num_key_value_heads = num_attention_heads
|
| 214 |
+
|
| 215 |
+
self.num_key_value_heads = num_key_value_heads
|
| 216 |
+
self.hidden_act = hidden_act
|
| 217 |
+
self.initializer_range = initializer_range
|
| 218 |
+
self.rms_norm_eps = rms_norm_eps
|
| 219 |
+
self.pretraining_tp = pretraining_tp
|
| 220 |
+
self.use_cache = use_cache
|
| 221 |
+
self.rope_theta = rope_theta
|
| 222 |
+
self.rope_scaling = rope_scaling
|
| 223 |
+
self.attention_bias = attention_bias
|
| 224 |
+
self.mlp_bias = mlp_bias
|
| 225 |
+
self.attention_dropout = attention_dropout
|
| 226 |
+
self.use_qk_norm = use_qk_norm
|
| 227 |
+
self.use_rotary_pos_emb = use_rotary_pos_emb
|
| 228 |
+
self.use_cla = use_cla
|
| 229 |
+
self.cla_share_factor = cla_share_factor
|
| 230 |
+
self.norm_type = norm_type
|
| 231 |
+
# MLA args
|
| 232 |
+
self.use_mla = use_mla
|
| 233 |
+
self.kv_lora_rank = kv_lora_rank
|
| 234 |
+
self.q_lora_rank = q_lora_rank
|
| 235 |
+
self.qk_rope_head_dim = qk_rope_head_dim
|
| 236 |
+
self.qk_nope_head_dim = qk_nope_head_dim
|
| 237 |
+
self.v_head_dim = v_head_dim
|
| 238 |
+
|
| 239 |
+
# DeepSeek related args
|
| 240 |
+
self.moe_layer_num_skipped = moe_layer_num_skipped
|
| 241 |
+
self.norm_topk_prob = norm_topk_prob
|
| 242 |
+
self.routed_scaling_factor = routed_scaling_factor
|
| 243 |
+
self.group_limited_greedy = group_limited_greedy
|
| 244 |
+
self.n_group = n_group
|
| 245 |
+
self.topk_group = topk_group
|
| 246 |
+
self.add_classification_head = add_classification_head
|
| 247 |
+
self.class_num = class_num
|
| 248 |
+
self.pool_type = pool_type
|
| 249 |
+
self.pad_id = pad_id
|
| 250 |
+
|
| 251 |
+
if self.class_num is not None:
|
| 252 |
+
self.dense_list = [self.hidden_size, self.class_num]
|
| 253 |
+
|
| 254 |
+
# ViT args
|
| 255 |
+
self.vit = vit
|
| 256 |
+
self.vit_processor = vit_processor
|
| 257 |
+
self.vit_aligner = vit_aligner
|
| 258 |
+
|
| 259 |
+
# Image Gen args
|
| 260 |
+
self.vae = vae
|
| 261 |
+
self.vae_downsample_factor = vae_downsample_factor
|
| 262 |
+
self.img_proj_type = img_proj_type
|
| 263 |
+
self.patch_size = patch_size
|
| 264 |
+
self.patch_embed_hidden_dim = patch_embed_hidden_dim
|
| 265 |
+
self.image_base_size = image_base_size
|
| 266 |
+
|
| 267 |
+
# token id
|
| 268 |
+
self.eod_token_id = eod_token_id
|
| 269 |
+
self.im_start_id = im_start_id
|
| 270 |
+
self.im_end_id = im_end_id
|
| 271 |
+
self.text_start_id = text_start_id
|
| 272 |
+
self.text_end_id = text_end_id
|
| 273 |
+
self.image_token_id = image_token_id
|
| 274 |
+
self.video_start_id = video_start_id
|
| 275 |
+
self.video_end_id = video_end_id
|
| 276 |
+
self.im_newline_id = im_newline_id
|
| 277 |
+
self.mask_init_id = mask_init_id
|
| 278 |
+
|
| 279 |
+
super().__init__(
|
| 280 |
+
pad_token_id=pad_token_id,
|
| 281 |
+
bos_token_id=bos_token_id,
|
| 282 |
+
eos_token_id=eos_token_id,
|
| 283 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 284 |
+
**kwargs,
|
| 285 |
+
)
|
generation_config.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"disable_compile": true,
|
| 3 |
+
"eos_token_id": [
|
| 4 |
+
127957
|
| 5 |
+
],
|
| 6 |
+
"pad_token_id": 128009,
|
| 7 |
+
"do_sample": true,
|
| 8 |
+
"top_k": 1024,
|
| 9 |
+
"top_p": 0.95,
|
| 10 |
+
"temperature": 0.6,
|
| 11 |
+
"max_length": 12800,
|
| 12 |
+
"sequence_template": "pretrain",
|
| 13 |
+
"diff_infer_steps": 50,
|
| 14 |
+
"diff_guidance_scale": 5.0,
|
| 15 |
+
"flow_shift": 3.0,
|
| 16 |
+
"use_system_prompt": "None",
|
| 17 |
+
"drop_think": false,
|
| 18 |
+
"bot_task": "image",
|
| 19 |
+
"transformers_version": "4.50.0"
|
| 20 |
+
}
|
hunyuan.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
hunyuan_image_3_pipeline.py
ADDED
|
@@ -0,0 +1,879 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
|
| 2 |
+
# you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at
|
| 4 |
+
#
|
| 5 |
+
# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
|
| 6 |
+
#
|
| 7 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 10 |
+
# See the License for the specific language governing permissions and
|
| 11 |
+
# limitations under the License.
|
| 12 |
+
# ==============================================================================
|
| 13 |
+
#
|
| 14 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 15 |
+
#
|
| 16 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 17 |
+
# you may not use this file except in compliance with the License.
|
| 18 |
+
# You may obtain a copy of the License at
|
| 19 |
+
#
|
| 20 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 21 |
+
#
|
| 22 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 23 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 24 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 25 |
+
# See the License for the specific language governing permissions and
|
| 26 |
+
# limitations under the License.
|
| 27 |
+
# ==============================================================================================
|
| 28 |
+
|
| 29 |
+
import inspect
|
| 30 |
+
import math
|
| 31 |
+
from dataclasses import dataclass
|
| 32 |
+
from typing import Any, Callable, Dict, List
|
| 33 |
+
from typing import Optional, Tuple, Union
|
| 34 |
+
|
| 35 |
+
import numpy as np
|
| 36 |
+
import torch
|
| 37 |
+
from PIL import Image
|
| 38 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 39 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 40 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 41 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 42 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 43 |
+
from diffusers.utils import BaseOutput, logging
|
| 44 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 45 |
+
|
| 46 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def retrieve_timesteps(
|
| 50 |
+
scheduler,
|
| 51 |
+
num_inference_steps: Optional[int] = None,
|
| 52 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 53 |
+
timesteps: Optional[List[int]] = None,
|
| 54 |
+
sigmas: Optional[List[float]] = None,
|
| 55 |
+
**kwargs,
|
| 56 |
+
):
|
| 57 |
+
"""
|
| 58 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 59 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
scheduler (`SchedulerMixin`):
|
| 63 |
+
The scheduler to get timesteps from.
|
| 64 |
+
num_inference_steps (`int`):
|
| 65 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 66 |
+
must be `None`.
|
| 67 |
+
device (`str` or `torch.device`, *optional*):
|
| 68 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 69 |
+
timesteps (`List[int]`, *optional*):
|
| 70 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 71 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 72 |
+
sigmas (`List[float]`, *optional*):
|
| 73 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 74 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 78 |
+
second element is the number of inference steps.
|
| 79 |
+
"""
|
| 80 |
+
if timesteps is not None and sigmas is not None:
|
| 81 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 82 |
+
if timesteps is not None:
|
| 83 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 84 |
+
if not accepts_timesteps:
|
| 85 |
+
raise ValueError(
|
| 86 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 87 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 88 |
+
)
|
| 89 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 90 |
+
timesteps = scheduler.timesteps
|
| 91 |
+
num_inference_steps = len(timesteps)
|
| 92 |
+
elif sigmas is not None:
|
| 93 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 94 |
+
if not accept_sigmas:
|
| 95 |
+
raise ValueError(
|
| 96 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 97 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 98 |
+
)
|
| 99 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 100 |
+
timesteps = scheduler.timesteps
|
| 101 |
+
num_inference_steps = len(timesteps)
|
| 102 |
+
else:
|
| 103 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 104 |
+
timesteps = scheduler.timesteps
|
| 105 |
+
return timesteps, num_inference_steps
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
| 109 |
+
r"""
|
| 110 |
+
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
| 111 |
+
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 112 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
noise_cfg (`torch.Tensor`):
|
| 116 |
+
The predicted noise tensor for the guided diffusion process.
|
| 117 |
+
noise_pred_text (`torch.Tensor`):
|
| 118 |
+
The predicted noise tensor for the text-guided diffusion process.
|
| 119 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
| 120 |
+
A rescale factor applied to the noise predictions.
|
| 121 |
+
Returns:
|
| 122 |
+
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
| 123 |
+
"""
|
| 124 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
| 125 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
| 126 |
+
# rescale the results from guidance (fixes overexposure)
|
| 127 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
| 128 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
| 129 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
| 130 |
+
return noise_cfg
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@dataclass
|
| 134 |
+
class HunyuanImage3Text2ImagePipelineOutput(BaseOutput):
|
| 135 |
+
samples: Union[List[Any], np.ndarray]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dataclass
|
| 139 |
+
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
|
| 140 |
+
"""
|
| 141 |
+
Output class for the scheduler's `step` function output.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
| 145 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
| 146 |
+
denoising loop.
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
prev_sample: torch.FloatTensor
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
| 153 |
+
"""
|
| 154 |
+
Euler scheduler.
|
| 155 |
+
|
| 156 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
| 157 |
+
methods the library implements for all schedulers such as loading and saving.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
num_train_timesteps (`int`, defaults to 1000):
|
| 161 |
+
The number of diffusion steps to train the model.
|
| 162 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
| 163 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
| 164 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
| 165 |
+
shift (`float`, defaults to 1.0):
|
| 166 |
+
The shift value for the timestep schedule.
|
| 167 |
+
reverse (`bool`, defaults to `True`):
|
| 168 |
+
Whether to reverse the timestep schedule.
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
_compatibles = []
|
| 172 |
+
order = 1
|
| 173 |
+
|
| 174 |
+
@register_to_config
|
| 175 |
+
def __init__(
|
| 176 |
+
self,
|
| 177 |
+
num_train_timesteps: int = 1000,
|
| 178 |
+
shift: float = 1.0,
|
| 179 |
+
reverse: bool = True,
|
| 180 |
+
solver: str = "euler",
|
| 181 |
+
use_flux_shift: bool = False,
|
| 182 |
+
flux_base_shift: float = 0.5,
|
| 183 |
+
flux_max_shift: float = 1.15,
|
| 184 |
+
n_tokens: Optional[int] = None,
|
| 185 |
+
):
|
| 186 |
+
sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
|
| 187 |
+
|
| 188 |
+
if not reverse:
|
| 189 |
+
sigmas = sigmas.flip(0)
|
| 190 |
+
|
| 191 |
+
self.sigmas = sigmas
|
| 192 |
+
# the value fed to model
|
| 193 |
+
self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
|
| 194 |
+
self.timesteps_full = (sigmas * num_train_timesteps).to(dtype=torch.float32)
|
| 195 |
+
|
| 196 |
+
self._step_index = None
|
| 197 |
+
self._begin_index = None
|
| 198 |
+
|
| 199 |
+
self.supported_solver = [
|
| 200 |
+
"euler",
|
| 201 |
+
"heun-2", "midpoint-2",
|
| 202 |
+
"kutta-4",
|
| 203 |
+
]
|
| 204 |
+
if solver not in self.supported_solver:
|
| 205 |
+
raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}")
|
| 206 |
+
|
| 207 |
+
# empty dt and derivative (for heun)
|
| 208 |
+
self.derivative_1 = None
|
| 209 |
+
self.derivative_2 = None
|
| 210 |
+
self.derivative_3 = None
|
| 211 |
+
self.dt = None
|
| 212 |
+
|
| 213 |
+
@property
|
| 214 |
+
def step_index(self):
|
| 215 |
+
"""
|
| 216 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
| 217 |
+
"""
|
| 218 |
+
return self._step_index
|
| 219 |
+
|
| 220 |
+
@property
|
| 221 |
+
def begin_index(self):
|
| 222 |
+
"""
|
| 223 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
| 224 |
+
"""
|
| 225 |
+
return self._begin_index
|
| 226 |
+
|
| 227 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
| 228 |
+
def set_begin_index(self, begin_index: int = 0):
|
| 229 |
+
"""
|
| 230 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
begin_index (`int`):
|
| 234 |
+
The begin index for the scheduler.
|
| 235 |
+
"""
|
| 236 |
+
self._begin_index = begin_index
|
| 237 |
+
|
| 238 |
+
def _sigma_to_t(self, sigma):
|
| 239 |
+
return sigma * self.config.num_train_timesteps
|
| 240 |
+
|
| 241 |
+
@property
|
| 242 |
+
def state_in_first_order(self):
|
| 243 |
+
return self.derivative_1 is None
|
| 244 |
+
|
| 245 |
+
@property
|
| 246 |
+
def state_in_second_order(self):
|
| 247 |
+
return self.derivative_2 is None
|
| 248 |
+
|
| 249 |
+
@property
|
| 250 |
+
def state_in_third_order(self):
|
| 251 |
+
return self.derivative_3 is None
|
| 252 |
+
|
| 253 |
+
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None,
|
| 254 |
+
n_tokens: int = None):
|
| 255 |
+
"""
|
| 256 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
| 257 |
+
|
| 258 |
+
Args:
|
| 259 |
+
num_inference_steps (`int`):
|
| 260 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
| 261 |
+
device (`str` or `torch.device`, *optional*):
|
| 262 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 263 |
+
n_tokens (`int`, *optional*):
|
| 264 |
+
Number of tokens in the input sequence.
|
| 265 |
+
"""
|
| 266 |
+
self.num_inference_steps = num_inference_steps
|
| 267 |
+
|
| 268 |
+
sigmas = torch.linspace(1, 0, num_inference_steps + 1)
|
| 269 |
+
|
| 270 |
+
# Apply timestep shift
|
| 271 |
+
if self.config.use_flux_shift:
|
| 272 |
+
assert isinstance(n_tokens, int), "n_tokens should be provided for flux shift"
|
| 273 |
+
mu = self.get_lin_function(y1=self.config.flux_base_shift, y2=self.config.flux_max_shift)(n_tokens)
|
| 274 |
+
sigmas = self.flux_time_shift(mu, 1.0, sigmas)
|
| 275 |
+
elif self.config.shift != 1.:
|
| 276 |
+
sigmas = self.sd3_time_shift(sigmas)
|
| 277 |
+
|
| 278 |
+
if not self.config.reverse:
|
| 279 |
+
sigmas = 1 - sigmas
|
| 280 |
+
|
| 281 |
+
self.sigmas = sigmas
|
| 282 |
+
self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
|
| 283 |
+
self.timesteps_full = (sigmas * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
|
| 284 |
+
|
| 285 |
+
# empty dt and derivative (for kutta)
|
| 286 |
+
self.derivative_1 = None
|
| 287 |
+
self.derivative_2 = None
|
| 288 |
+
self.derivative_3 = None
|
| 289 |
+
self.dt = None
|
| 290 |
+
|
| 291 |
+
# Reset step index
|
| 292 |
+
self._step_index = None
|
| 293 |
+
|
| 294 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
| 295 |
+
if schedule_timesteps is None:
|
| 296 |
+
schedule_timesteps = self.timesteps
|
| 297 |
+
|
| 298 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
| 299 |
+
|
| 300 |
+
# The sigma index that is taken for the **very** first `step`
|
| 301 |
+
# is always the second index (or the last index if there is only 1)
|
| 302 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
| 303 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
| 304 |
+
pos = 1 if len(indices) > 1 else 0
|
| 305 |
+
|
| 306 |
+
return indices[pos].item()
|
| 307 |
+
|
| 308 |
+
def _init_step_index(self, timestep):
|
| 309 |
+
if self.begin_index is None:
|
| 310 |
+
if isinstance(timestep, torch.Tensor):
|
| 311 |
+
timestep = timestep.to(self.timesteps.device)
|
| 312 |
+
self._step_index = self.index_for_timestep(timestep)
|
| 313 |
+
else:
|
| 314 |
+
self._step_index = self._begin_index
|
| 315 |
+
|
| 316 |
+
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
|
| 317 |
+
return sample
|
| 318 |
+
|
| 319 |
+
@staticmethod
|
| 320 |
+
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
|
| 321 |
+
m = (y2 - y1) / (x2 - x1)
|
| 322 |
+
b = y1 - m * x1
|
| 323 |
+
return lambda x: m * x + b
|
| 324 |
+
|
| 325 |
+
@staticmethod
|
| 326 |
+
def flux_time_shift(mu: float, sigma: float, t: torch.Tensor):
|
| 327 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
| 328 |
+
|
| 329 |
+
def sd3_time_shift(self, t: torch.Tensor):
|
| 330 |
+
return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
|
| 331 |
+
|
| 332 |
+
def step(
|
| 333 |
+
self,
|
| 334 |
+
model_output: torch.FloatTensor,
|
| 335 |
+
timestep: Union[float, torch.FloatTensor],
|
| 336 |
+
sample: torch.FloatTensor,
|
| 337 |
+
pred_uncond: torch.FloatTensor = None,
|
| 338 |
+
generator: Optional[torch.Generator] = None,
|
| 339 |
+
n_tokens: Optional[int] = None,
|
| 340 |
+
return_dict: bool = True,
|
| 341 |
+
) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
|
| 342 |
+
"""
|
| 343 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
| 344 |
+
process from the learned model outputs (most often the predicted noise).
|
| 345 |
+
|
| 346 |
+
Args:
|
| 347 |
+
model_output (`torch.FloatTensor`):
|
| 348 |
+
The direct output from learned diffusion model.
|
| 349 |
+
timestep (`float`):
|
| 350 |
+
The current discrete timestep in the diffusion chain.
|
| 351 |
+
sample (`torch.FloatTensor`):
|
| 352 |
+
A current instance of a sample created by the diffusion process.
|
| 353 |
+
generator (`torch.Generator`, *optional*):
|
| 354 |
+
A random number generator.
|
| 355 |
+
n_tokens (`int`, *optional*):
|
| 356 |
+
Number of tokens in the input sequence.
|
| 357 |
+
return_dict (`bool`):
|
| 358 |
+
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
|
| 359 |
+
tuple.
|
| 360 |
+
|
| 361 |
+
Returns:
|
| 362 |
+
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
|
| 363 |
+
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
|
| 364 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
| 365 |
+
"""
|
| 366 |
+
|
| 367 |
+
if (
|
| 368 |
+
isinstance(timestep, int)
|
| 369 |
+
or isinstance(timestep, torch.IntTensor)
|
| 370 |
+
or isinstance(timestep, torch.LongTensor)
|
| 371 |
+
):
|
| 372 |
+
raise ValueError(
|
| 373 |
+
(
|
| 374 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
| 375 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
| 376 |
+
" one of the `scheduler.timesteps` as a timestep."
|
| 377 |
+
),
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
if self.step_index is None:
|
| 381 |
+
self._init_step_index(timestep)
|
| 382 |
+
|
| 383 |
+
# Upcast to avoid precision issues when computing prev_sample
|
| 384 |
+
sample = sample.to(torch.float32)
|
| 385 |
+
model_output = model_output.to(torch.float32)
|
| 386 |
+
pred_uncond = pred_uncond.to(torch.float32) if pred_uncond is not None else None
|
| 387 |
+
|
| 388 |
+
# dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
|
| 389 |
+
sigma = self.sigmas[self.step_index]
|
| 390 |
+
sigma_next = self.sigmas[self.step_index + 1]
|
| 391 |
+
|
| 392 |
+
last_inner_step = True
|
| 393 |
+
if self.config.solver == "euler":
|
| 394 |
+
derivative, dt, sample, last_inner_step = self.first_order_method(model_output, sigma, sigma_next, sample)
|
| 395 |
+
elif self.config.solver in ["heun-2", "midpoint-2"]:
|
| 396 |
+
derivative, dt, sample, last_inner_step = self.second_order_method(model_output, sigma, sigma_next, sample)
|
| 397 |
+
elif self.config.solver == "kutta-4":
|
| 398 |
+
derivative, dt, sample, last_inner_step = self.fourth_order_method(model_output, sigma, sigma_next, sample)
|
| 399 |
+
else:
|
| 400 |
+
raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}")
|
| 401 |
+
|
| 402 |
+
prev_sample = sample + derivative * dt
|
| 403 |
+
|
| 404 |
+
# Cast sample back to model compatible dtype
|
| 405 |
+
# prev_sample = prev_sample.to(model_output.dtype)
|
| 406 |
+
|
| 407 |
+
# upon completion increase step index by one
|
| 408 |
+
if last_inner_step:
|
| 409 |
+
self._step_index += 1
|
| 410 |
+
|
| 411 |
+
if not return_dict:
|
| 412 |
+
return (prev_sample,)
|
| 413 |
+
|
| 414 |
+
return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
|
| 415 |
+
|
| 416 |
+
def first_order_method(self, model_output, sigma, sigma_next, sample):
|
| 417 |
+
derivative = model_output
|
| 418 |
+
dt = sigma_next - sigma
|
| 419 |
+
return derivative, dt, sample, True
|
| 420 |
+
|
| 421 |
+
def second_order_method(self, model_output, sigma, sigma_next, sample):
|
| 422 |
+
if self.state_in_first_order:
|
| 423 |
+
# store for 2nd order step
|
| 424 |
+
self.derivative_1 = model_output
|
| 425 |
+
self.dt = sigma_next - sigma
|
| 426 |
+
self.sample = sample
|
| 427 |
+
|
| 428 |
+
derivative = model_output
|
| 429 |
+
if self.config.solver == 'heun-2':
|
| 430 |
+
dt = self.dt
|
| 431 |
+
elif self.config.solver == 'midpoint-2':
|
| 432 |
+
dt = self.dt / 2
|
| 433 |
+
else:
|
| 434 |
+
raise NotImplementedError(f"Solver {self.config.solver} not supported.")
|
| 435 |
+
last_inner_step = False
|
| 436 |
+
|
| 437 |
+
else:
|
| 438 |
+
if self.config.solver == 'heun-2':
|
| 439 |
+
derivative = 0.5 * (self.derivative_1 + model_output)
|
| 440 |
+
elif self.config.solver == 'midpoint-2':
|
| 441 |
+
derivative = model_output
|
| 442 |
+
else:
|
| 443 |
+
raise NotImplementedError(f"Solver {self.config.solver} not supported.")
|
| 444 |
+
|
| 445 |
+
# 3. take prev timestep & sample
|
| 446 |
+
dt = self.dt
|
| 447 |
+
sample = self.sample
|
| 448 |
+
last_inner_step = True
|
| 449 |
+
|
| 450 |
+
# free dt and derivative
|
| 451 |
+
# Note, this puts the scheduler in "first order mode"
|
| 452 |
+
self.derivative_1 = None
|
| 453 |
+
self.dt = None
|
| 454 |
+
self.sample = None
|
| 455 |
+
|
| 456 |
+
return derivative, dt, sample, last_inner_step
|
| 457 |
+
|
| 458 |
+
def fourth_order_method(self, model_output, sigma, sigma_next, sample):
|
| 459 |
+
if self.state_in_first_order:
|
| 460 |
+
self.derivative_1 = model_output
|
| 461 |
+
self.dt = sigma_next - sigma
|
| 462 |
+
self.sample = sample
|
| 463 |
+
derivative = model_output
|
| 464 |
+
dt = self.dt / 2
|
| 465 |
+
last_inner_step = False
|
| 466 |
+
|
| 467 |
+
elif self.state_in_second_order:
|
| 468 |
+
self.derivative_2 = model_output
|
| 469 |
+
derivative = model_output
|
| 470 |
+
dt = self.dt / 2
|
| 471 |
+
last_inner_step = False
|
| 472 |
+
|
| 473 |
+
elif self.state_in_third_order:
|
| 474 |
+
self.derivative_3 = model_output
|
| 475 |
+
derivative = model_output
|
| 476 |
+
dt = self.dt
|
| 477 |
+
last_inner_step = False
|
| 478 |
+
|
| 479 |
+
else:
|
| 480 |
+
derivative = (1/6 * self.derivative_1 + 1/3 * self.derivative_2 + 1/3 * self.derivative_3 +
|
| 481 |
+
1/6 * model_output)
|
| 482 |
+
|
| 483 |
+
# 3. take prev timestep & sample
|
| 484 |
+
dt = self.dt
|
| 485 |
+
sample = self.sample
|
| 486 |
+
last_inner_step = True
|
| 487 |
+
|
| 488 |
+
# free dt and derivative
|
| 489 |
+
# Note, this puts the scheduler in "first order mode"
|
| 490 |
+
self.derivative_1 = None
|
| 491 |
+
self.derivative_2 = None
|
| 492 |
+
self.derivative_3 = None
|
| 493 |
+
self.dt = None
|
| 494 |
+
self.sample = None
|
| 495 |
+
|
| 496 |
+
return derivative, dt, sample, last_inner_step
|
| 497 |
+
|
| 498 |
+
def __len__(self):
|
| 499 |
+
return self.config.num_train_timesteps
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
class ClassifierFreeGuidance:
|
| 503 |
+
def __init__(
|
| 504 |
+
self,
|
| 505 |
+
use_original_formulation: bool = False,
|
| 506 |
+
start: float = 0.0,
|
| 507 |
+
stop: float = 1.0,
|
| 508 |
+
):
|
| 509 |
+
super().__init__()
|
| 510 |
+
self.use_original_formulation = use_original_formulation
|
| 511 |
+
|
| 512 |
+
def __call__(
|
| 513 |
+
self,
|
| 514 |
+
pred_cond: torch.Tensor,
|
| 515 |
+
pred_uncond: Optional[torch.Tensor],
|
| 516 |
+
guidance_scale: float,
|
| 517 |
+
step: int,
|
| 518 |
+
) -> torch.Tensor:
|
| 519 |
+
|
| 520 |
+
shift = pred_cond - pred_uncond
|
| 521 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
| 522 |
+
pred = pred + guidance_scale * shift
|
| 523 |
+
|
| 524 |
+
return pred
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
class HunyuanImage3Text2ImagePipeline(DiffusionPipeline):
|
| 528 |
+
r"""
|
| 529 |
+
Pipeline for condition-to-sample generation using Stable Diffusion.
|
| 530 |
+
|
| 531 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 532 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 533 |
+
|
| 534 |
+
Args:
|
| 535 |
+
model ([`ModelMixin`]):
|
| 536 |
+
A model to denoise the diffused latents.
|
| 537 |
+
scheduler ([`SchedulerMixin`]):
|
| 538 |
+
A scheduler to be used in combination with `diffusion_model` to denoise the diffused latents. Can be one of
|
| 539 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 540 |
+
"""
|
| 541 |
+
|
| 542 |
+
model_cpu_offload_seq = ""
|
| 543 |
+
_optional_components = []
|
| 544 |
+
_exclude_from_cpu_offload = []
|
| 545 |
+
_callback_tensor_inputs = ["latents"]
|
| 546 |
+
|
| 547 |
+
def __init__(
|
| 548 |
+
self,
|
| 549 |
+
model,
|
| 550 |
+
scheduler: SchedulerMixin,
|
| 551 |
+
vae,
|
| 552 |
+
progress_bar_config: Dict[str, Any] = None,
|
| 553 |
+
):
|
| 554 |
+
super().__init__()
|
| 555 |
+
|
| 556 |
+
# ==========================================================================================
|
| 557 |
+
if progress_bar_config is None:
|
| 558 |
+
progress_bar_config = {}
|
| 559 |
+
if not hasattr(self, '_progress_bar_config'):
|
| 560 |
+
self._progress_bar_config = {}
|
| 561 |
+
self._progress_bar_config.update(progress_bar_config)
|
| 562 |
+
# ==========================================================================================
|
| 563 |
+
|
| 564 |
+
self.register_modules(
|
| 565 |
+
model=model,
|
| 566 |
+
scheduler=scheduler,
|
| 567 |
+
vae=vae,
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
# should be a tuple or a list corresponding to the size of latents (batch_size, channel, *size)
|
| 571 |
+
# if None, will be treated as a tuple of 1
|
| 572 |
+
self.latent_scale_factor = self.model.config.vae_downsample_factor
|
| 573 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.latent_scale_factor)
|
| 574 |
+
|
| 575 |
+
# Must start with APG_mode_
|
| 576 |
+
self.cfg_operator = ClassifierFreeGuidance()
|
| 577 |
+
|
| 578 |
+
@staticmethod
|
| 579 |
+
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
| 580 |
+
"""
|
| 581 |
+
Denormalize an image array to [0,1].
|
| 582 |
+
"""
|
| 583 |
+
return (images / 2 + 0.5).clamp(0, 1)
|
| 584 |
+
|
| 585 |
+
@staticmethod
|
| 586 |
+
def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
|
| 587 |
+
"""
|
| 588 |
+
Convert a PyTorch tensor to a NumPy image.
|
| 589 |
+
"""
|
| 590 |
+
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 591 |
+
return images
|
| 592 |
+
|
| 593 |
+
@staticmethod
|
| 594 |
+
def numpy_to_pil(images: np.ndarray):
|
| 595 |
+
"""
|
| 596 |
+
Convert a numpy image or a batch of images to a PIL image.
|
| 597 |
+
"""
|
| 598 |
+
if images.ndim == 3:
|
| 599 |
+
images = images[None, ...]
|
| 600 |
+
images = (images * 255).round().astype("uint8")
|
| 601 |
+
if images.shape[-1] == 1:
|
| 602 |
+
# special case for grayscale (single channel) images
|
| 603 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
| 604 |
+
else:
|
| 605 |
+
pil_images = [Image.fromarray(image) for image in images]
|
| 606 |
+
|
| 607 |
+
return pil_images
|
| 608 |
+
|
| 609 |
+
def prepare_extra_func_kwargs(self, func, kwargs):
|
| 610 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 611 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 612 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 613 |
+
# and should be between [0, 1]
|
| 614 |
+
extra_kwargs = {}
|
| 615 |
+
|
| 616 |
+
for k, v in kwargs.items():
|
| 617 |
+
accepts = k in set(inspect.signature(func).parameters.keys())
|
| 618 |
+
if accepts:
|
| 619 |
+
extra_kwargs[k] = v
|
| 620 |
+
return extra_kwargs
|
| 621 |
+
|
| 622 |
+
def prepare_latents(self, batch_size, latent_channel, image_size, dtype, device, generator, latents=None):
|
| 623 |
+
if self.latent_scale_factor is None:
|
| 624 |
+
latent_scale_factor = (1,) * len(image_size)
|
| 625 |
+
elif isinstance(self.latent_scale_factor, int):
|
| 626 |
+
latent_scale_factor = (self.latent_scale_factor,) * len(image_size)
|
| 627 |
+
elif isinstance(self.latent_scale_factor, tuple) or isinstance(self.latent_scale_factor, list):
|
| 628 |
+
assert len(self.latent_scale_factor) == len(image_size), \
|
| 629 |
+
"len(latent_scale_factor) shoudl be the same as len(image_size)"
|
| 630 |
+
latent_scale_factor = self.latent_scale_factor
|
| 631 |
+
else:
|
| 632 |
+
raise ValueError(
|
| 633 |
+
f"latent_scale_factor should be either None, int, tuple of int, or list of int, "
|
| 634 |
+
f"but got {self.latent_scale_factor}"
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
latents_shape = (
|
| 638 |
+
batch_size,
|
| 639 |
+
latent_channel,
|
| 640 |
+
*[int(s) // f for s, f in zip(image_size, latent_scale_factor)],
|
| 641 |
+
)
|
| 642 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 643 |
+
raise ValueError(
|
| 644 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 645 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
if latents is None:
|
| 649 |
+
latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
|
| 650 |
+
else:
|
| 651 |
+
latents = latents.to(device)
|
| 652 |
+
|
| 653 |
+
# Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
|
| 654 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 655 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 656 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 657 |
+
|
| 658 |
+
return latents
|
| 659 |
+
|
| 660 |
+
@property
|
| 661 |
+
def guidance_scale(self):
|
| 662 |
+
return self._guidance_scale
|
| 663 |
+
|
| 664 |
+
@property
|
| 665 |
+
def guidance_rescale(self):
|
| 666 |
+
return self._guidance_rescale
|
| 667 |
+
|
| 668 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 669 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 670 |
+
# corresponds to doing no classifier free guidance.
|
| 671 |
+
@property
|
| 672 |
+
def do_classifier_free_guidance(self):
|
| 673 |
+
return self._guidance_scale > 1.0
|
| 674 |
+
|
| 675 |
+
@property
|
| 676 |
+
def num_timesteps(self):
|
| 677 |
+
return self._num_timesteps
|
| 678 |
+
|
| 679 |
+
def set_scheduler(self, new_scheduler):
|
| 680 |
+
self.register_modules(scheduler=new_scheduler)
|
| 681 |
+
|
| 682 |
+
@torch.no_grad()
|
| 683 |
+
def __call__(
|
| 684 |
+
self,
|
| 685 |
+
batch_size: int,
|
| 686 |
+
image_size: List[int],
|
| 687 |
+
num_inference_steps: int = 50,
|
| 688 |
+
timesteps: List[int] = None,
|
| 689 |
+
sigmas: List[float] = None,
|
| 690 |
+
guidance_scale: float = 7.5,
|
| 691 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 692 |
+
latents: Optional[torch.Tensor] = None,
|
| 693 |
+
output_type: Optional[str] = "pil",
|
| 694 |
+
return_dict: bool = True,
|
| 695 |
+
guidance_rescale: float = 0.0,
|
| 696 |
+
callback_on_step_end: Optional[
|
| 697 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 698 |
+
] = None,
|
| 699 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 700 |
+
model_kwargs: Dict[str, Any] = None,
|
| 701 |
+
**kwargs,
|
| 702 |
+
):
|
| 703 |
+
r"""
|
| 704 |
+
The call function to the pipeline for generation.
|
| 705 |
+
|
| 706 |
+
Args:
|
| 707 |
+
prompt (`str` or `List[str]`):
|
| 708 |
+
The text to guide image generation.
|
| 709 |
+
image_size (`Tuple[int]` or `List[int]`):
|
| 710 |
+
The size (height, width) of the generated image.
|
| 711 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 712 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 713 |
+
expense of slower inference.
|
| 714 |
+
timesteps (`List[int]`, *optional*):
|
| 715 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 716 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 717 |
+
passed will be used. Must be in descending order.
|
| 718 |
+
sigmas (`List[float]`, *optional*):
|
| 719 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 720 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 721 |
+
will be used.
|
| 722 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 723 |
+
A higher guidance scale value encourages the model to generate samples closely linked to the
|
| 724 |
+
`condition` at the expense of lower sample quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 725 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 726 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 727 |
+
generation deterministic.
|
| 728 |
+
latents (`torch.Tensor`, *optional*):
|
| 729 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for sample
|
| 730 |
+
generation. Can be used to tweak the same generation with different conditions. If not provided,
|
| 731 |
+
a latents tensor is generated by sampling using the supplied random `generator`.
|
| 732 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 733 |
+
The output format of the generated sample.
|
| 734 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 735 |
+
Whether or not to return a [`~DiffusionPipelineOutput`] instead of a
|
| 736 |
+
plain tuple.
|
| 737 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
| 738 |
+
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
| 739 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
| 740 |
+
using zero terminal SNR.
|
| 741 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
| 742 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
| 743 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
| 744 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
| 745 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
| 746 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 747 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 748 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 749 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 750 |
+
|
| 751 |
+
Examples:
|
| 752 |
+
|
| 753 |
+
Returns:
|
| 754 |
+
[`~DiffusionPipelineOutput`] or `tuple`:
|
| 755 |
+
If `return_dict` is `True`, [`~DiffusionPipelineOutput`] is returned,
|
| 756 |
+
otherwise a `tuple` is returned where the first element is a list with the generated samples.
|
| 757 |
+
"""
|
| 758 |
+
|
| 759 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
| 760 |
+
pbar_steps = kwargs.pop("pbar_steps", None)
|
| 761 |
+
|
| 762 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 763 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 764 |
+
|
| 765 |
+
self._guidance_scale = guidance_scale
|
| 766 |
+
self._guidance_rescale = guidance_rescale
|
| 767 |
+
|
| 768 |
+
cfg_factor = 1 + self.do_classifier_free_guidance
|
| 769 |
+
|
| 770 |
+
# Define call parameters
|
| 771 |
+
device = self._execution_device
|
| 772 |
+
|
| 773 |
+
# Prepare timesteps
|
| 774 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 775 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas,
|
| 776 |
+
)
|
| 777 |
+
|
| 778 |
+
# Prepare latent variables
|
| 779 |
+
latents = self.prepare_latents(
|
| 780 |
+
batch_size=batch_size,
|
| 781 |
+
latent_channel=self.model.config.vae["latent_channels"],
|
| 782 |
+
image_size=image_size,
|
| 783 |
+
dtype=torch.bfloat16,
|
| 784 |
+
device=device,
|
| 785 |
+
generator=generator,
|
| 786 |
+
latents=latents,
|
| 787 |
+
)
|
| 788 |
+
|
| 789 |
+
# Prepare extra step kwargs.
|
| 790 |
+
_scheduler_step_extra_kwargs = self.prepare_extra_func_kwargs(
|
| 791 |
+
self.scheduler.step, {"generator": generator}
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
# Prepare model kwargs
|
| 795 |
+
input_ids = model_kwargs.pop("input_ids")
|
| 796 |
+
attention_mask = self.model._prepare_attention_mask_for_generation( # noqa
|
| 797 |
+
input_ids, self.model.generation_config, model_kwargs=model_kwargs,
|
| 798 |
+
)
|
| 799 |
+
model_kwargs["attention_mask"] = attention_mask.to(latents.device)
|
| 800 |
+
|
| 801 |
+
# Sampling loop
|
| 802 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 803 |
+
self._num_timesteps = len(timesteps)
|
| 804 |
+
|
| 805 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 806 |
+
for i, t in enumerate(timesteps):
|
| 807 |
+
# expand the latents if we are doing classifier free guidance
|
| 808 |
+
latent_model_input = torch.cat([latents] * cfg_factor)
|
| 809 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 810 |
+
|
| 811 |
+
t_expand = t.repeat(latent_model_input.shape[0])
|
| 812 |
+
|
| 813 |
+
model_inputs = self.model.prepare_inputs_for_generation(
|
| 814 |
+
input_ids,
|
| 815 |
+
images=latent_model_input,
|
| 816 |
+
timestep=t_expand,
|
| 817 |
+
**model_kwargs,
|
| 818 |
+
)
|
| 819 |
+
|
| 820 |
+
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
|
| 821 |
+
model_output = self.model(**model_inputs, first_step=(i == 0))
|
| 822 |
+
pred = model_output["diffusion_prediction"]
|
| 823 |
+
pred = pred.to(dtype=torch.float32)
|
| 824 |
+
|
| 825 |
+
# perform guidance
|
| 826 |
+
if self.do_classifier_free_guidance:
|
| 827 |
+
pred_cond, pred_uncond = pred.chunk(2)
|
| 828 |
+
pred = self.cfg_operator(pred_cond, pred_uncond, self.guidance_scale, step=i)
|
| 829 |
+
|
| 830 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
| 831 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
| 832 |
+
pred = rescale_noise_cfg(pred, pred_cond, guidance_rescale=self.guidance_rescale)
|
| 833 |
+
|
| 834 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 835 |
+
latents = self.scheduler.step(pred, t, latents, **_scheduler_step_extra_kwargs, return_dict=False)[0]
|
| 836 |
+
|
| 837 |
+
if i != len(timesteps) - 1:
|
| 838 |
+
model_kwargs = self.model._update_model_kwargs_for_generation( # noqa
|
| 839 |
+
model_output,
|
| 840 |
+
model_kwargs,
|
| 841 |
+
)
|
| 842 |
+
if input_ids.shape[1] != model_kwargs["position_ids"].shape[1]:
|
| 843 |
+
input_ids = torch.gather(input_ids, 1, index=model_kwargs["position_ids"])
|
| 844 |
+
|
| 845 |
+
if callback_on_step_end is not None:
|
| 846 |
+
callback_kwargs = {}
|
| 847 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 848 |
+
callback_kwargs[k] = locals()[k]
|
| 849 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 850 |
+
|
| 851 |
+
latents = callback_outputs.pop("latents", latents)
|
| 852 |
+
|
| 853 |
+
# call the callback, if provided
|
| 854 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 855 |
+
progress_bar.update()
|
| 856 |
+
|
| 857 |
+
if hasattr(self.vae.config, 'scaling_factor') and self.vae.config.scaling_factor:
|
| 858 |
+
latents = latents / self.vae.config.scaling_factor
|
| 859 |
+
if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
|
| 860 |
+
latents = latents + self.vae.config.shift_factor
|
| 861 |
+
|
| 862 |
+
if hasattr(self.vae, "ffactor_temporal"):
|
| 863 |
+
latents = latents.unsqueeze(2)
|
| 864 |
+
|
| 865 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
|
| 866 |
+
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
|
| 867 |
+
|
| 868 |
+
# b c t h w
|
| 869 |
+
if hasattr(self.vae, "ffactor_temporal"):
|
| 870 |
+
assert image.shape[2] == 1, "image should have shape [B, C, T, H, W] and T should be 1"
|
| 871 |
+
image = image.squeeze(2)
|
| 872 |
+
|
| 873 |
+
do_denormalize = [True] * image.shape[0]
|
| 874 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
| 875 |
+
|
| 876 |
+
if not return_dict:
|
| 877 |
+
return (image,)
|
| 878 |
+
|
| 879 |
+
return HunyuanImage3Text2ImagePipelineOutput(samples=image)
|
image_processor.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
|
| 2 |
+
# you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at
|
| 4 |
+
#
|
| 5 |
+
# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
|
| 6 |
+
#
|
| 7 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 10 |
+
# See the License for the specific language governing permissions and
|
| 11 |
+
# limitations under the License.
|
| 12 |
+
# ==============================================================================
|
| 13 |
+
|
| 14 |
+
from typing import Tuple
|
| 15 |
+
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from torchvision import transforms
|
| 18 |
+
from transformers import Siglip2ImageProcessorFast
|
| 19 |
+
|
| 20 |
+
from .tokenizer_wrapper import ImageInfo, JointImageInfo, ResolutionGroup
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def resize_and_crop(image: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
|
| 24 |
+
tw, th = target_size
|
| 25 |
+
w, h = image.size
|
| 26 |
+
|
| 27 |
+
tr = th / tw
|
| 28 |
+
r = h / w
|
| 29 |
+
|
| 30 |
+
# resize
|
| 31 |
+
if r < tr:
|
| 32 |
+
resize_height = th
|
| 33 |
+
resize_width = int(round(th / h * w))
|
| 34 |
+
else:
|
| 35 |
+
resize_width = tw
|
| 36 |
+
resize_height = int(round(tw / w * h))
|
| 37 |
+
|
| 38 |
+
image = image.resize((resize_width, resize_height), resample=Image.Resampling.LANCZOS)
|
| 39 |
+
|
| 40 |
+
# center crop
|
| 41 |
+
crop_top = int(round((resize_height - th) / 2.0))
|
| 42 |
+
crop_left = int(round((resize_width - tw) / 2.0))
|
| 43 |
+
|
| 44 |
+
image = image.crop((crop_left, crop_top, crop_left + tw, crop_top + th))
|
| 45 |
+
return image
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class HunyuanImage3ImageProcessor(object):
|
| 49 |
+
def __init__(self, config):
|
| 50 |
+
self.config = config
|
| 51 |
+
|
| 52 |
+
self.reso_group = ResolutionGroup(base_size=config.image_base_size)
|
| 53 |
+
self.vae_processor = transforms.Compose([
|
| 54 |
+
transforms.ToTensor(),
|
| 55 |
+
transforms.Normalize([0.5], [0.5]), # transform to [-1, 1]
|
| 56 |
+
])
|
| 57 |
+
self.vision_encoder_processor = Siglip2ImageProcessorFast.from_dict(config.vit_processor)
|
| 58 |
+
|
| 59 |
+
def build_image_info(self, image_size):
|
| 60 |
+
# parse image size (HxW, H:W, or <img_ratio_i>)
|
| 61 |
+
if isinstance(image_size, str):
|
| 62 |
+
if image_size.startswith("<img_ratio_"):
|
| 63 |
+
ratio_index = int(image_size.split("_")[-1].rstrip(">"))
|
| 64 |
+
reso = self.reso_group[ratio_index]
|
| 65 |
+
image_size = reso.height, reso.width
|
| 66 |
+
elif 'x' in image_size:
|
| 67 |
+
image_size = [int(s) for s in image_size.split('x')]
|
| 68 |
+
elif ':' in image_size:
|
| 69 |
+
image_size = [int(s) for s in image_size.split(':')]
|
| 70 |
+
else:
|
| 71 |
+
raise ValueError(
|
| 72 |
+
f"`image_size` should be in the format of 'HxW', 'H:W' or <img_ratio_i>, got {image_size}.")
|
| 73 |
+
assert len(image_size) == 2, f"`image_size` should be in the format of 'HxW', got {image_size}."
|
| 74 |
+
elif isinstance(image_size, (list, tuple)):
|
| 75 |
+
assert len(image_size) == 2 and all(isinstance(s, int) for s in image_size), \
|
| 76 |
+
f"`image_size` should be a tuple of two integers or a string in the format of 'HxW', got {image_size}."
|
| 77 |
+
else:
|
| 78 |
+
raise ValueError(f"`image_size` should be a tuple of two integers or a string in the format of 'WxH', "
|
| 79 |
+
f"got {image_size}.")
|
| 80 |
+
image_width, image_height = self.reso_group.get_target_size(image_size[1], image_size[0])
|
| 81 |
+
token_height = image_height // (self.config.vae_downsample_factor[0] * self.config.patch_size)
|
| 82 |
+
token_width = image_width // (self.config.vae_downsample_factor[1] * self.config.patch_size)
|
| 83 |
+
base_size, ratio_idx = self.reso_group.get_base_size_and_ratio_index(image_size[1], image_size[0])
|
| 84 |
+
image_info = ImageInfo(
|
| 85 |
+
image_type="gen_image", image_width=image_width, image_height=image_height,
|
| 86 |
+
token_width=token_width, token_height=token_height, base_size=base_size, ratio_index=ratio_idx,
|
| 87 |
+
)
|
| 88 |
+
return image_info
|
| 89 |
+
|
| 90 |
+
def preprocess(self, image: Image.Image):
|
| 91 |
+
# ==== VAE processor ====
|
| 92 |
+
image_width, image_height = self.reso_group.get_target_size(image.width, image.height)
|
| 93 |
+
resized_image = resize_and_crop(image, (image_width, image_height))
|
| 94 |
+
image_tensor = self.vae_processor(resized_image)
|
| 95 |
+
token_height = image_height // (self.config.vae_downsample_factor[0] * self.config.patch_size)
|
| 96 |
+
token_width = image_width // (self.config.vae_downsample_factor[1] * self.config.patch_size)
|
| 97 |
+
base_size, ratio_index = self.reso_group.get_base_size_and_ratio_index(width=image_width, height=image_height)
|
| 98 |
+
vae_image_info = ImageInfo(
|
| 99 |
+
image_type="vae",
|
| 100 |
+
image_tensor=image_tensor.unsqueeze(0), # include batch dim
|
| 101 |
+
image_width=image_width, image_height=image_height,
|
| 102 |
+
token_width=token_width, token_height=token_height,
|
| 103 |
+
base_size=base_size, ratio_index=ratio_index,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
# ==== ViT processor ====
|
| 107 |
+
inputs = self.vision_encoder_processor(image)
|
| 108 |
+
image = inputs["pixel_values"].squeeze(0) # seq_len x dim
|
| 109 |
+
pixel_attention_mask = inputs["pixel_attention_mask"].squeeze(0) # seq_len
|
| 110 |
+
spatial_shapes = inputs["spatial_shapes"].squeeze(0) # 2 (h, w)
|
| 111 |
+
vision_encoder_kwargs = dict(
|
| 112 |
+
pixel_attention_mask=pixel_attention_mask,
|
| 113 |
+
spatial_shapes=spatial_shapes,
|
| 114 |
+
)
|
| 115 |
+
vision_image_info = ImageInfo(
|
| 116 |
+
image_type="vit",
|
| 117 |
+
image_tensor=image.unsqueeze(0), # 1 x seq_len x dim
|
| 118 |
+
image_width=spatial_shapes[1].item() * self.config.vit_processor["patch_size"],
|
| 119 |
+
image_height=spatial_shapes[0].item() * self.config.vit_processor["patch_size"],
|
| 120 |
+
token_width=spatial_shapes[1].item(),
|
| 121 |
+
token_height=spatial_shapes[0].item(),
|
| 122 |
+
image_token_length=self.config.vit_processor["max_num_patches"],
|
| 123 |
+
# may not equal to token_width * token_height
|
| 124 |
+
)
|
| 125 |
+
return JointImageInfo(vae_image_info, vision_image_info, vision_encoder_kwargs)
|
model-0001-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dad22fa5e99dcda532c242aa4d4875f9ea6fd8b2ed59e39776dec4ea55baf4e5
|
| 3 |
+
size 5363066616
|
model-0002-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9987e8220f81b70d07b62f06ac6c92bb0faf38ccb0ddd3f30b65ed895ad4a2fb
|
| 3 |
+
size 5318937248
|
model-0003-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:79f8d4d1b23562299da3360ac7e2437a4dd24be30b86bc8db580521b5f9b2616
|
| 3 |
+
size 5344627472
|
model-0004-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4faf1357831b25b9f9637594312e9024ee0fa1e87c734e20afdde2845fdaa516
|
| 3 |
+
size 5327343192
|
model-0005-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:46189f8777c117c431e46cc57ec2328fe72050452119ac7bb676bdaca3f76575
|
| 3 |
+
size 5344103080
|
model-0006-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f9d5f386b7c2d0b171bd8a25f3f08e3150936fde2dfd92e9aa1f6e27dbf2e0d
|
| 3 |
+
size 5318937248
|
model-0007-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d30616044acead06484eacace50a4cab66267feb13555f235bac63d2540cf471
|
| 3 |
+
size 5344103088
|
model-0008-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:740ccbff8fa1dbb2847fe8c342654f7d24fa81f058065e82dfbccb89ce2743c1
|
| 3 |
+
size 5318937256
|
model-0009-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d5fc3df50de8591735d29f7acfece39b64b3735cccef176eb4a137f4ede68430
|
| 3 |
+
size 5344103088
|
model-0010-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f6058eb7527741d18c17131cb7810f11d8bd4c69cce10962e093e684413cd2a
|
| 3 |
+
size 5318937304
|
model-0011-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4c38d5fd2f18191d849b444e873ff91d3f048d8c4bcd71b3035ff0f7973ac273
|
| 3 |
+
size 5344103232
|
model-0012-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:688a6a818f6d164d345e3bb37c4f3fcee40cc7d458027d2a37f7486463843ec3
|
| 3 |
+
size 5318937400
|
model-0013-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f77757aa32fa67f75f8f8ec5bc831d358093483c2a8692bff7477378aea00f28
|
| 3 |
+
size 5344103232
|
model-0014-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3308c079c20008e1ac8852cfb986764064077278754492f2fd9ec893857b6489
|
| 3 |
+
size 5318937400
|
model-0015-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e32b467eb49473c7f42696db0916ca3275c01984c48a10433d78be4d351b7ff8
|
| 3 |
+
size 5344103232
|
model-0016-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b97d98195a45518bae971bc43c224225b60e1fbb8b2eb93115024d2bdf328dca
|
| 3 |
+
size 5318937400
|
model-0017-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2f00339bad7371e59f2d3642fd0575abafa92fc4509803f8fe5a64492185d2ab
|
| 3 |
+
size 5344103224
|
model-0018-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b48a59d090d396aa9801765485381f8255d442c2da2d9e98f1c21a68c6b83b1
|
| 3 |
+
size 5327859080
|
model-0019-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dd4e5a082f3db3b61774ce86675cfb171f33319fd3dd8f942cd952633834d334
|
| 3 |
+
size 5344111888
|
model-0020-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f27fc2c0eedfc6b99ebe07e244c9689e89fa06dc65216d9c07aa6067783f86b5
|
| 3 |
+
size 5318937392
|