chaitnya26 TencentOpen committed on
Commit 22ee710 · verified · 0 Parent(s):

Duplicate from tencent/HunyuanImage-3.0


Co-authored-by: TencentOpen <TencentOpen@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +40 -0
  2. LICENSE +78 -0
  3. README.md +502 -0
  4. __init__.py +0 -0
  5. assets/WECHAT.md +6 -0
  6. assets/banner.png +3 -0
  7. assets/banner_all.jpg +3 -0
  8. assets/framework.png +3 -0
  9. assets/gsb.png +3 -0
  10. assets/logo.png +3 -0
  11. assets/pg_imgs/image1.png +3 -0
  12. assets/pg_imgs/image2.png +3 -0
  13. assets/pg_imgs/image3.png +3 -0
  14. assets/pg_imgs/image4.png +3 -0
  15. assets/pg_imgs/image5.png +3 -0
  16. assets/pg_imgs/image6.png +3 -0
  17. assets/pg_imgs/image7.png +3 -0
  18. assets/pg_imgs/image8.png +3 -0
  19. assets/robot.png +3 -0
  20. assets/ssae_side_by_side_comparison.png +3 -0
  21. assets/ssae_side_by_side_heatmap.png +3 -0
  22. assets/user.png +3 -0
  23. assets/wechat.png +3 -0
  24. autoencoder_kl_3d.py +793 -0
  25. config.json +273 -0
  26. configuration_hunyuan.py +285 -0
  27. generation_config.json +20 -0
  28. hunyuan.py +0 -0
  29. hunyuan_image_3_pipeline.py +879 -0
  30. image_processor.py +125 -0
  31. model-0001-of-0032.safetensors +3 -0
  32. model-0002-of-0032.safetensors +3 -0
  33. model-0003-of-0032.safetensors +3 -0
  34. model-0004-of-0032.safetensors +3 -0
  35. model-0005-of-0032.safetensors +3 -0
  36. model-0006-of-0032.safetensors +3 -0
  37. model-0007-of-0032.safetensors +3 -0
  38. model-0008-of-0032.safetensors +3 -0
  39. model-0009-of-0032.safetensors +3 -0
  40. model-0010-of-0032.safetensors +3 -0
  41. model-0011-of-0032.safetensors +3 -0
  42. model-0012-of-0032.safetensors +3 -0
  43. model-0013-of-0032.safetensors +3 -0
  44. model-0014-of-0032.safetensors +3 -0
  45. model-0015-of-0032.safetensors +3 -0
  46. model-0016-of-0032.safetensors +3 -0
  47. model-0017-of-0032.safetensors +3 -0
  48. model-0018-of-0032.safetensors +3 -0
  49. model-0019-of-0032.safetensors +3 -0
  50. model-0020-of-0032.safetensors +3 -0
.gitattributes ADDED
@@ -0,0 +1,40 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ assets/banner_all.jpg filter=lfs diff=lfs merge=lfs -text
38
+ *.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/**/*.png filter=lfs diff=lfs merge=lfs -text
40
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,78 @@
1
+ TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
2
+ Tencent Hunyuan Image 3.0 Release Date: September 28, 2025
3
+ THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
4
+ By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
5
+ 1. DEFINITIONS.
6
+ a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
7
+ b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
8
+ c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
9
+ d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
10
+ e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
11
+ f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
12
+ g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
13
+ h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
14
+ i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
15
+ j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan Image 3.0 released at [
16
+ https://github.com/Tencent-Hunyuan/HunyuanImage-3.0;https://huggingface.co/tencent/HunyuanImage-3.0;https://huggingface.co/tencent/HunyuanImage-3.0-Instruct;https://modelscope.cn/models/Tencent-Hunyuan/HunyuanImage-3.0/;https://ai.gitcode.com/tencent_hunyuan/HunyuanImage-3.0].
17
+ k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
18
+ l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
19
+ m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
20
+ n. “including” shall mean including but not limited to.
21
+ 2. GRANT OF RIGHTS.
22
+ We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
23
+ 3. DISTRIBUTION.
24
+ You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
25
+ a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
26
+ b. You must cause any modified files to carry prominent notices stating that You changed the files;
27
+ c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
28
+ d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
29
+ You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
30
+ 4. ADDITIONAL COMMERCIAL TERMS.
31
+ If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
32
+ 5. RULES OF USE.
33
+ a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
34
+ b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
35
+ c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
36
+ 6. INTELLECTUAL PROPERTY.
37
+ a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
38
+ b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
39
+ c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
40
+ d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
41
+ 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
42
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
43
+ b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
44
+ c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
45
+ 8. SURVIVAL AND TERMINATION.
46
+ a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
47
+ b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
48
+ 9. GOVERNING LAW AND JURISDICTION.
49
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
50
+ b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
51
+
52
+ EXHIBIT A
53
+ ACCEPTABLE USE POLICY
54
+
55
+ Tencent reserves the right to update this Acceptable Use Policy from time to time.
56
+ Last modified: November 5, 2024
57
+
58
+ Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
59
+ 1. Outside the Territory;
60
+ 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
61
+ 3. To harm Yourself or others;
62
+ 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
63
+ 5. To override or circumvent the safety guardrails and safeguards We have put in place;
64
+ 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
65
+ 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
66
+ 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
67
+ 9. To intentionally defame, disparage or otherwise harass others;
68
+ 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
69
+ 11. To generate or disseminate personal identifiable information with the purpose of harming others;
70
+ 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including through the use of bot-generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
71
+ 13. To impersonate another individual without consent, authorization, or legal right;
72
+ 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
73
+ 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
74
+ 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
75
+ 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
76
+ 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
77
+ 19. For military purposes;
78
+ 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
README.md ADDED
@@ -0,0 +1,502 @@
1
+ ---
2
+ license: other
3
+ license_name: tencent-hunyuan-community
4
+ license_link: LICENSE
5
+ pipeline_tag: text-to-image
6
+ library_name: transformers
7
+ ---
8
+
9
+ <div align="center">
10
+
11
+ <img src="./assets/logo.png" alt="HunyuanImage-3.0 Logo" width="600">
12
+
13
+ # 🎨 HunyuanImage-3.0: A Powerful Native Multimodal Model for Image Generation
14
+
15
+ </div>
16
+
17
+
18
+ <div align="center">
19
+ <img src="./assets/banner.png" alt="HunyuanImage-3.0 Banner" width="800">
20
+
21
+ </div>
22
+
23
+ <div align="center">
24
+ <a href=https://hunyuan.tencent.com/image target="_blank"><img src=https://img.shields.io/badge/Official%20Site-333399.svg?logo=homepage height=22px></a>
25
+ <a href=https://huggingface.co/tencent/HunyuanImage-3.0 target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20Models-d96902.svg height=22px></a>
26
+ <a href=https://github.com/Tencent-Hunyuan/HunyuanImage-3.0 target="_blank"><img src= https://img.shields.io/badge/Page-bb8a2e.svg?logo=github height=22px></a>
27
+ <a href=https://arxiv.org/pdf/2509.23951 target="_blank"><img src=https://img.shields.io/badge/Report-b5212f.svg?logo=arxiv height=22px></a>
28
+ <a href=https://x.com/TencentHunyuan target="_blank"><img src=https://img.shields.io/badge/Hunyuan-black.svg?logo=x height=22px></a>
29
+ <a href=https://docs.qq.com/doc/DUVVadmhCdG9qRXBU target="_blank"><img src=https://img.shields.io/badge/📚-PromptHandBook-blue.svg?logo=book height=22px></a>
30
+ </div>
31
+
32
+
33
+ <p align="center">
34
+ 👏 Join our <a href="./assets/WECHAT.md" target="_blank">WeChat</a> and <a href="https://discord.gg/ehjWMqF5wY">Discord</a> |
35
+ 💻 <a href="https://hunyuan.tencent.com/modelSquare/home/play?modelId=289&from=/visual">Official website (官网): Try our model!</a>&nbsp;&nbsp;
36
+ </p>
37
+
38
+ ## 🔥🔥🔥 News
39
+ - **September 28, 2025**: 📖 **HunyuanImage-3.0 Technical Report Released** - Comprehensive technical documentation now available
40
+ - **September 28, 2025**: 🚀 **HunyuanImage-3.0 Open Source Release** - Inference code and model weights publicly available
41
+
42
+
43
+ ## 🧩 Community Contributions
44
+
45
+ If you develop or use HunyuanImage-3.0 in your projects, please let us know.
46
+
47
+ ## 📑 Open-source Plan
48
+
49
+ - HunyuanImage-3.0 (Image Generation Model)
50
+ - [x] Inference
51
+ - [x] HunyuanImage-3.0 Checkpoints
52
+ - [ ] HunyuanImage-3.0-Instruct Checkpoints (with reasoning)
53
+ - [ ] vLLM Support
54
+ - [ ] Distilled Checkpoints
55
+ - [ ] Image-to-Image Generation
56
+ - [ ] Multi-turn Interaction
57
+
58
+
59
+ ## 🗂️ Contents
60
+ - [🔥🔥🔥 News](#-news)
61
+ - [🧩 Community Contributions](#-community-contributions)
62
+ - [📑 Open-source Plan](#-open-source-plan)
63
+ - [📖 Introduction](#-introduction)
64
+ - [✨ Key Features](#-key-features)
65
+ - [🛠️ Dependencies and Installation](#-dependencies-and-installation)
66
+ - [💻 System Requirements](#-system-requirements)
67
+ - [📦 Environment Setup](#-environment-setup)
68
+ - [📥 Install Dependencies](#-install-dependencies)
69
+ - [Performance Optimizations](#performance-optimizations)
70
+ - [🚀 Usage](#-usage)
71
+ - [🔥 Quick Start with Transformers](#-quick-start-with-transformers)
72
+ - [🏠 Local Installation & Usage](#-local-installation--usage)
73
+ - [🎨 Interactive Gradio Demo](#-interactive-gradio-demo)
74
+ - [🧱 Model Cards](#-model-cards)
75
+ - [📝 Prompt Guide](#-prompt-guide)
76
+ - [Manually Writing Prompts](#manually-writing-prompts)
77
+ - [System Prompt For Automatic Rewriting the Prompt](#system-prompt-for-automatic-rewriting-the-prompt)
78
+ - [Advanced Tips](#advanced-tips)
79
+ - [More Cases](#more-cases)
80
+ - [📊 Evaluation](#-evaluation)
81
+ - [📚 Citation](#-citation)
82
+ - [🙏 Acknowledgements](#-acknowledgements)
83
+ - [🌟🚀 GitHub Star History](#-github-star-history)
84
+
85
+ ---
86
+
87
+ ## 📖 Introduction
88
+
89
+ **HunyuanImage-3.0** is a groundbreaking native multimodal model that unifies multimodal understanding and generation within an autoregressive framework. Our text-to-image module achieves performance **comparable to or surpassing** leading closed-source models.
90
+
91
+
92
+ <div align="center">
93
+ <img src="./assets/framework.png" alt="HunyuanImage-3.0 Framework" width="90%">
94
+ </div>
95
+
96
+ ## ✨ Key Features
97
+
98
+ * 🧠 **Unified Multimodal Architecture:** Moving beyond the prevalent DiT-based architectures, HunyuanImage-3.0 employs a unified autoregressive framework. This design enables a more direct and integrated modeling of text and image modalities, leading to surprisingly effective and contextually rich image generation.
99
+
100
+ * 🏆 **The Largest Image Generation MoE Model:** This is the largest open-source image generation Mixture of Experts (MoE) model to date. It features 64 experts and a total of 80 billion parameters, with 13 billion activated per token, significantly enhancing its capacity and performance.
101
+
102
+ * 🎨 **Superior Image Generation Performance:** Through rigorous dataset curation and advanced reinforcement learning post-training, we've achieved an optimal balance between semantic accuracy and visual excellence. The model demonstrates exceptional prompt adherence while delivering photorealistic imagery with stunning aesthetic quality and fine-grained details.
103
+
104
+ * 💭 **Intelligent World-Knowledge Reasoning:** The unified multimodal architecture endows HunyuanImage-3.0 with powerful reasoning capabilities. It leverages its extensive world knowledge to intelligently interpret user intent, automatically elaborating on sparse prompts with contextually appropriate details to produce superior, more complete visual outputs.
105
+
106
+
107
+ ## 🛠️ Dependencies and Installation
108
+
109
+ ### 💻 System Requirements
110
+
111
+ * 🖥️ **Operating System:** Linux
112
+ * 🎮 **GPU:** NVIDIA GPU with CUDA support
113
+ * 💾 **Disk Space:** 170GB for model weights
114
+ * 🧠 **GPU Memory:** ≥3×80GB (4×80GB recommended for better performance)
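+ 
+ As a rough sanity check (not an official figure), the weight memory follows from the parameter count, assuming 2-byte bf16/fp16 weights; activations, KV cache, and framework overhead come on top, which is why several 80GB GPUs are recommended:
+ 
+ ```python
+ # Back-of-the-envelope estimate only: 80B parameters at 2 bytes each (bf16/fp16).
+ total_params = 80e9
+ weight_gib = total_params * 2 / 1024**3
+ print(f"weights alone: ~{weight_gib:.0f} GiB")  # ~149 GiB, broadly consistent with the 170GB on-disk size
+ print(f"spread over 3 x 80GB GPUs: ~{weight_gib / 3:.0f} GiB each")  # ~50 GiB per GPU, before activations
+ ```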
115
+
116
+ ### 📦 Environment Setup
117
+
118
+ * 🐍 **Python:** 3.12+ (recommended and tested)
119
+ * 🔥 **PyTorch:** 2.7.1
120
+ * ⚡ **CUDA:** 12.8
121
+
122
+ ### 📥 Install Dependencies
123
+
124
+ ```bash
125
+ # 1. First install PyTorch (CUDA 12.8 Version)
126
+ pip install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu128
127
+
128
+ # 2. Then install tencentcloud-sdk
129
+ pip install -i https://mirrors.tencent.com/pypi/simple/ --upgrade tencentcloud-sdk-python
130
+
131
+ # 3. Then install other dependencies
132
+ pip install -r requirements.txt
133
+ ```
134
+
135
+ #### Performance Optimizations
136
+
137
+ For **up to 3x faster inference**, install these optimizations:
138
+
139
+ ```bash
140
+ # FlashAttention for faster attention computation
141
+ pip install flash-attn==2.8.3 --no-build-isolation
142
+
143
+ # FlashInfer for optimized MoE inference. v0.3.1 has been tested.
144
+ pip install flashinfer-python
145
+ ```
146
+ > 💡**Installation Tips:** It is critical that the CUDA version used by PyTorch matches the system's CUDA version.
147
+ > FlashInfer relies on this compatibility when compiling kernels at runtime. PyTorch 2.7.1+cu128 has been tested.
148
+ > GCC version >=9 is recommended for compiling FlashAttention and FlashInfer.
149
+
150
+ > ⚡ **Performance Tips:** These optimizations can significantly speed up your inference!
151
+
152
+ > 💡**Note:** When FlashInfer is enabled, the first inference may be slower (about 10 minutes) due to kernel compilation. Subsequent inferences on the same machine will be much faster.
153
+
154
+ ## 🚀 Usage
155
+
156
+ ### 🔥 Quick Start with Transformers
157
+
158
+ #### 1️⃣ Download model weights
159
+
160
+ ```bash
161
+ # Download from HuggingFace and rename the directory.
162
+ # Note that the directory name should not contain dots, which can cause issues when loading with Transformers.
163
+ hf download tencent/HunyuanImage-3.0 --local-dir ./HunyuanImage-3
164
+ ```
165
+
166
+ #### 2️⃣ Run with Transformers
167
+
168
+ ```python
169
+ from transformers import AutoModelForCausalLM
170
+
171
+ # Load the model
172
+ model_id = "./HunyuanImage-3"
173
+ # Currently we cannot load the model using the HF model_id `tencent/HunyuanImage-3.0` directly
174
+ # due to the dot in the name.
175
+
176
+ kwargs = dict(
177
+ attn_implementation="sdpa", # Use "flash_attention_2" if FlashAttention is installed
178
+ trust_remote_code=True,
179
+ torch_dtype="auto",
180
+ device_map="auto",
181
+ moe_impl="eager", # Use "flashinfer" if FlashInfer is installed
182
+ )
183
+
184
+ model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
185
+ model.load_tokenizer(model_id)
186
+
187
+ # generate the image
188
+ prompt = "A brown and white dog is running on the grass"
189
+ image = model.generate_image(prompt=prompt, stream=True)
190
+ image.save("image.png")
191
+ ```
192
+
193
+ ### 🏠 Local Installation & Usage
194
+
195
+ #### 1️⃣ Clone the Repository
196
+
197
+ ```bash
198
+ git clone https://github.com/Tencent-Hunyuan/HunyuanImage-3.0.git
199
+ cd HunyuanImage-3.0/
200
+ ```
201
+
202
+ #### 2️⃣ Download Model Weights
203
+
204
+ ```bash
205
+ # Download from HuggingFace
206
+ hf download tencent/HunyuanImage-3.0 --local-dir ./HunyuanImage-3
207
+ ```
208
+
209
+ #### 3️⃣ Run the Demo
210
+ The Pretrain Checkpoint does not automatically rewrite or enhance input prompts. For optimal results at this stage, we recommend using DeepSeek to rewrite prompts; you can apply for an API key at [Tencent Cloud](https://cloud.tencent.com/document/product/1772/115963#.E5.BF.AB.E9.80.9F.E6.8E.A5.E5.85.A5).
211
+
212
+ ```bash
213
+ # set env
214
+ export DEEPSEEK_KEY_ID="your_deepseek_key_id"
215
+ export DEEPSEEK_KEY_SECRET="your_deepseek_key_secret"
216
+
217
+ python3 run_image_gen.py --model-id ./HunyuanImage-3 --verbose 1 --sys-deepseek-prompt "universal" --prompt "A brown and white dog is running on the grass"
218
+ ```
219
+
220
+ #### 4️⃣ Command Line Arguments
221
+
222
+ | Arguments | Description | Default |
223
+ | ----------------------- | ------------------------------------------------------------ | ----------- |
224
+ | `--prompt` | Input prompt | (Required) |
225
+ | `--model-id` | Model path | (Required) |
226
+ | `--attn-impl` | Attention implementation. Either `sdpa` or `flash_attention_2`. | `sdpa` |
227
+ | `--moe-impl` | MoE implementation. Either `eager` or `flashinfer` | `eager` |
228
+ | `--seed` | Random seed for image generation | `None` |
229
+ | `--diff-infer-steps` | Diffusion infer steps | `50` |
230
+ | `--image-size` | Image resolution. Can be `auto`, an explicit size like `1280x768`, or an aspect ratio like `16:9` | `auto` |
231
+ | `--save` | Image save path. | `image.png` |
232
+ | `--verbose` | Verbose level. 0: No log; 1: log inference information. | `0` |
233
+ | `--rewrite` | Whether to enable rewriting | `1` |
234
+ | `--sys-deepseek-prompt` | Select sys-prompt from `universal` or `text_rendering` | `universal` |
235
+
236
+ ### 🎨 Interactive Gradio Demo
237
+
238
+ Launch an interactive web interface for easy text-to-image generation.
239
+
240
+ #### 1️⃣ Install Gradio
241
+
242
+ ```bash
243
+ pip install "gradio>=4.21.0"
244
+ ```
245
+
246
+ #### 2️⃣ Configure Environment
247
+
248
+ ```bash
249
+ # Set your model path
250
+ export MODEL_ID="path/to/your/model"
251
+
252
+ # Optional: Configure GPU usage (default: 0,1,2,3)
253
+ export GPUS="0,1,2,3"
254
+
255
+ # Optional: Configure host and port (default: 0.0.0.0:443)
256
+ export HOST="0.0.0.0"
257
+ export PORT="443"
258
+ ```
259
+
260
+ #### 3️⃣ Launch the Web Interface
261
+
262
+ **Basic Launch:**
263
+ ```bash
264
+ sh run_app.sh
265
+ ```
266
+
267
+ **With Performance Optimizations:**
268
+ ```bash
269
+ # Use both optimizations for maximum performance
270
+ sh run_app.sh --moe-impl flashinfer --attn-impl flash_attention_2
271
+ ```
272
+
273
+ #### 4️⃣ Access the Interface
274
+
275
+ > 🌐 **Web Interface:** Open your browser and navigate to `http://localhost:443` (or your configured port)
276
+
277
+
278
+ ## 🧱 Model Cards
279
+
280
+ | Model | Params | Download | Recommended VRAM | Supported |
281
+ |---------------------------| --- | --- | --- | --- |
282
+ | HunyuanImage-3.0 | 80B total (13B active) | [HuggingFace](https://huggingface.co/tencent/HunyuanImage-3.0) | ≥ 3 × 80 GB | ✅ Text-to-Image
283
+ | HunyuanImage-3.0-Instruct | 80B total (13B active) | [HuggingFace](https://huggingface.co/tencent/HunyuanImage-3.0-Instruct) | ≥ 3 × 80 GB | ✅ Text-to-Image<br>✅ Prompt Self-Rewrite <br>✅ CoT Think
284
+
285
+
286
+
287
+ Notes:
288
+ - Install performance extras (FlashAttention, FlashInfer) for faster inference.
289
+ - Multi‑GPU inference is recommended for the Base model.
290
+
291
+
292
+ ## 📝 Prompt Guide
293
+
294
+ ### Manually Writing Prompts.
295
+ The Pretrain Checkpoint does not automatically rewrite or enhance input prompts, while the Instruct Checkpoint can rewrite and enhance them with thinking. For optimal results at this stage, we recommend consulting our official guide on how to write effective prompts.
296
+
297
+ Reference: [HunyuanImage 3.0 Prompt Handbook](
298
+ https://docs.qq.com/doc/DUVVadmhCdG9qRXBU)
299
+
300
+
301
+ ### System Prompt For Automatic Rewriting the Prompt.
302
+
303
+ We've included two system prompts in the PE folder of this repository that leverage DeepSeek to automatically enhance user inputs:
304
+
305
+ * **system_prompt_universal**: This system prompt converts photographic-style and artistic prompts into detailed ones.
306
+ * **system_prompt_text_rendering**: This system prompt converts UI/poster/text-rendering prompts into detailed ones that suit the model.
307
+
308
+ Note that these system prompts are in Chinese because DeepSeek works better with Chinese system prompts. If you want to use them with an English-oriented model, you may translate them into English or refer to the comments in the PE files as a guide.
309
+
310
+ We have also created a [Yuanqi workflow](https://yuanqi.tencent.com/agent/H69VgtJdj3Dz) that implements the universal prompt; you can try it directly.
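+ 
+ A minimal sketch of the rewrite-then-generate flow, reusing the `model` loaded in the Quick Start section and assuming the PE system prompt is available as a local file. The LLM call is left as a placeholder (the official demo uses DeepSeek via the Tencent Cloud SDK); the file name and helper below are illustrative, not part of the released API:
+ 
+ ```python
+ # Illustrative sketch: `call_llm` is a stub for whatever chat-completion client you use,
+ # and the PE file path is an assumption about the repository layout.
+ def call_llm(system_prompt: str, user_prompt: str) -> str:
+     raise NotImplementedError("plug in your own LLM client (e.g. DeepSeek) here")
+ 
+ with open("PE/system_prompt_universal.txt", encoding="utf-8") as f:
+     system_prompt = f.read()
+ 
+ user_prompt = "A brown and white dog is running on the grass"
+ rewritten_prompt = call_llm(system_prompt, user_prompt)   # detailed, enhanced prompt
+ image = model.generate_image(prompt=rewritten_prompt, stream=True)
+ image.save("image_rewritten.png")
+ ```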
311
+
312
+ ### Advanced Tips
313
+ - **Content Priority**: Focus on describing the main subject and action first, followed by details about the environment and style. A more general description framework is: **Main subject and scene + Image quality and style + Composition and perspective + Lighting and atmosphere + Technical parameters**. Keywords can be added both before and after this structure.
314
+
315
+ - **Image resolution**: Our model not only supports multiple resolutions but also offers both **automatic and specified resolution** options. In auto mode, the model automatically predicts the image resolution based on the input prompt. In specified mode (like traditional DiT), the model outputs an image resolution that strictly aligns with the user's chosen resolution.
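+ 
+ For illustration, the two modes might look like this in Python, assuming `generate_image` exposes the same option as the CLI's `--image-size` flag under the keyword `image_size` (an assumption, not a documented parameter name):
+ 
+ ```python
+ # Assumed keyword `image_size`, mirroring the CLI flag --image-size.
+ prompt = "A brown and white dog is running on the grass"
+ 
+ # Auto mode: the model predicts a suitable resolution from the prompt.
+ img_auto = model.generate_image(prompt=prompt, image_size="auto", stream=True)
+ 
+ # Specified mode: the output strictly follows the requested size or aspect ratio.
+ img_fixed = model.generate_image(prompt=prompt, image_size="1280x768", stream=True)
+ img_wide = model.generate_image(prompt=prompt, image_size="16:9", stream=True)
+ ```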
316
+
317
+ ### More Cases
318
+ Our model can follow complex instructions to generate high‑quality, creative images.
319
+
320
+ <div align="center">
321
+ <img src="./assets/banner_all.jpg" width=100% alt="HunyuanImage 3.0 Demo">
322
+ </div>
323
+
324
+ Our model can effectively process very long text inputs, enabling users to precisely control the finer details of generated images. Extended prompts allow for intricate elements to be accurately captured, making it ideal for complex projects requiring precision and creativity.
325
+
326
+ <p align="center">
327
+ <table>
328
+ <thead>
329
+ </thead>
330
+ <tbody>
331
+ <tr>
332
+ <td>
333
+ <img src="./assets/pg_imgs/image1.png" width=100%><details>
334
+ <summary>Show prompt</summary>
335
+ A cinematic medium shot captures a single Asian woman seated on a chair within a dimly lit room, creating an intimate and theatrical atmosphere. The composition is focused on the subject, rendered with rich colors and intricate textures that evoke a nostalgic and moody feeling.
336
+
337
+ The primary subject is a young Asian woman with a thoughtful and expressive countenance, her gaze directed slightly away from the camera. She is seated in a relaxed yet elegant posture on an ornate, vintage armchair. The chair is upholstered in a deep red velvet, its fabric showing detailed, intricate textures and slight signs of wear. She wears a simple, elegant dress in a dark teal hue, the material catching the light in a way that reveals its fine-woven texture. Her skin has a soft, matte quality, and the light delicately models the contours of her face and arms.
338
+
339
+ The surrounding room is characterized by its vintage decor, which contributes to the historic and evocative mood. In the immediate background, partially blurred due to a shallow depth of field consistent with a f/2.8 aperture, the wall is covered with wallpaper featuring a subtle, damask pattern. The overall color palette is a carefully balanced interplay of deep teal and rich red hues, creating a visually compelling and cohesive environment. The entire scene is detailed, from the fibers of the upholstery to the subtle patterns on the wall.
340
+
341
+ The lighting is highly dramatic and artistic, defined by high contrast and pronounced shadow play. A single key light source, positioned off-camera, projects gobo lighting patterns onto the scene, casting intricate shapes of light and shadow across the woman and the back wall. These dramatic shadows create a strong sense of depth and a theatrical quality. While some shadows are deep and defined, others remain soft, gently wrapping around the subject and preventing the loss of detail in darker areas. The soft focus on the background enhances the intimate feeling, drawing all attention to the expressive subject. The overall image presents a cinematic, photorealistic photography style.
342
+ </details>
343
+ </td>
344
+ <td><img src="./assets/pg_imgs/image2.png" width=100%><details>
345
+ <summary>Show prompt</summary>
346
+ A cinematic, photorealistic medium shot captures a high-contrast urban street corner, defined by the sharp intersection of light and shadow. The primary subject is the exterior corner of a building, rendered in a low-saturation, realistic style.
347
+
348
+ The building wall, which occupies the majority of the frame, is painted a warm orange with a finely detailed, rough stucco texture. Horizontal white stripes run across its surface. The base of the building is constructed from large, rough-hewn stone blocks, showing visible particles and texture. On the left, illuminated side of the building, there is a single window with closed, dark-colored shutters. Adjacent to the window, a simple black pendant lamp hangs from a thin, taut rope, casting a distinct, sharp-edged shadow onto the sunlit orange wall. The composition is split diagonally, with the right side of the building enveloped in a deep brown shadow. At the bottom of the frame, a smooth concrete sidewalk is visible, upon which the dynamic silhouette of a person is captured mid-stride, walking from right to left.
349
+
350
+ In the shallow background, the faint, out-of-focus outlines of another building and the bare, skeletal branches of trees are softly visible, contributing to the quiet urban atmosphere and adding a sense of depth to the scene. These elements are rendered with minimal detail to keep the focus on the foreground architecture.
351
+
352
+ The scene is illuminated by strong, natural sunlight originating from the upper left, creating a dramatic chiaroscuro effect. This hard light source casts deep, well-defined shadows, producing a sharp contrast between the brightly lit warm orange surfaces and the deep brown shadow areas. The lighting highlights the fine details in the wall texture and stone particles, emphasizing the photorealistic quality. The overall presentation reflects a high-quality photorealistic photography style, infused with a cinematic film noir aesthetic.
353
+ </details>
354
+ </td>
355
+ </tr>
356
+ <tr>
357
+ <td>
358
+ <img src="./assets/pg_imgs/image3.png" width=100%><details>
359
+ <summary>Show prompt</summary>
360
+ 一幅极具视觉张力的杂志封面风格人像特写。画面主体是一个身着古风汉服的人物,构图采用了从肩部以上的超级近距离特写,人物占据了画面的绝大部分,形成了强烈的视觉冲击力。
361
+
362
+ 画面中的人物以一种慵懒的姿态出现,微微倾斜着头部,裸露的一侧肩膀线条流畅。她正用一种妩媚而直接的眼神凝视着镜头,双眼微张,眼神深邃,传递出一种神秘而勾人的气质。人物的面部特征精致,皮肤质感细腻,在特定的光线下,面部轮廓清晰分明,展现出一种古典与现代融合的时尚美感。
363
+
364
+ 整个画面的背景被设定为一种简约而高级的纯红色。这种红色色调深沉,呈现出哑光质感,既纯粹又无任何杂质,为整个暗黑神秘的氛围奠定了沉稳而富有张力的基调。这个纯色的背景有效地突出了前景中的人物主体,使得所有视觉焦点都集中在其身上。
365
+
366
+ 光线和氛围的营造是这幅杂志风海报的关键。一束暗橘色的柔和光线作为主光源,从人物的一侧斜上方投射下来,精准地勾勒出人物的脸颊、鼻梁和肩膀的轮廓,在皮肤上形成微妙的光影过渡。同时,人物的周身萦绕着一层暗淡且低饱和度的银白色辉光,如同清冷的月光,形成一道朦胧的轮廓光。这道银辉为人物增添了几分疏离的幽灵感,强化了整体暗黑风格的神秘气质。光影的强烈对比与色彩的独特搭配,共同塑造了这张充满故事感的特写画面。整体图像呈现出一种融合了古典元素的现代时尚摄影风格。
367
+ </details>
368
+ </td>
369
+ <td>
370
+ <img src="./assets/pg_imgs/image4.png" width=100%><details>
371
+ <summary>Show prompt</summary>
372
+ 一幅采用极简俯视视角的油画作品,画面主体由一道居中斜向的红色笔触构成。
373
+
374
+ 这道醒目的红色笔触运用了厚涂技法,颜料堆叠形成了强烈的物理厚度和三维立体感。它从画面的左上角附近延伸至右下角附近,构成一个动态的对角线。颜料表面可以清晰地看到画刀刮擦和笔刷拖曳留下的痕迹,边缘处的颜料层相对较薄,而中央部分则高高隆起,形成了不规则的起伏。
375
+
376
+ 在这道立体的红色颜料之上,巧妙地构建了一处精致的微缩景观。景观的核心是一片模拟红海滩的区域,由细腻的深红色颜料点缀而成,与下方基底的鲜红色形成丰富的层次对比。紧邻着“红海滩”的是一小片湖泊,由一层平滑且带有光泽的蓝色与白色混合颜料构成,质感如同平静无波的水面。湖泊边缘,一小撮芦苇丛生,由几根纤细挺拔的、用淡黄色和棕色颜料勾勒出的线条来表现。一只小巧的白鹭立于芦苇旁,其形态由一小块纯白色的厚涂颜料塑造,仅用一抹精炼的黑色颜料点出其尖喙,姿态优雅宁静。
377
+
378
+ 整个构图的背景是大面积的留白,呈现为一张带有细微凹凸纹理的白色纸质基底,这种极简处理极大地突出了中央的红色笔触及其上的微缩景观。
379
+
380
+ 光线从画面一侧柔和地照射下来,在厚涂的颜料堆叠处投下淡淡的、轮廓分明的阴影,进一步增强了画面的三维立体感和油画质感。整幅画面呈现出一种结合了厚涂技法的现代极简主义油画风格。
381
+ </details>
382
+ </td>
383
+ </tr>
384
+ <tr>
385
+ <td>
386
+ <img src="./assets/pg_imgs/image5.png" width=100%><details>
387
+ <summary>Show prompt</summary>
388
+ 整体画面采用一个二乘二的四宫格布局,以产品可视化的风格,展示了一只兔子在四种不同材质下的渲染效果。每个宫格内都有一只姿态完全相同的兔子模型,它呈坐姿,双耳竖立,面朝前方。所有宫格的背景均是统一的中性深灰色,这种简约背景旨在最大限度地突出每种材质的独特质感。
389
+
390
+ 左上角的宫格中,兔子模型由哑光白色石膏材质构成。其表面平滑、均匀且无反射,在模型的耳朵根部、四肢交接处等凹陷区域呈现出柔和的环境光遮蔽阴影,这种微妙的阴影变化凸显了其纯粹的几何形态,整体感觉像一个用于美术研究的基础模型。
391
+
392
+ 右上角的宫格中,兔子模型由晶莹剔透的无瑕疵玻璃制成。它展现了逼真的物理折射效果,透过其透明的身体看到的背景呈现出轻微的扭曲。清晰的镜面高光沿着其身体的曲线轮廓流动,表面上还能看到微弱而清晰的环境反射,赋予其一种精致而易碎的质感。
393
+
394
+ 左下角的宫格中,兔子模型呈现为带有拉丝纹理的钛金属材质。金属表面具有明显的各向异性反射效果,呈现出冷峻的灰调金属光泽。锐利明亮的高光和深邃的阴影形成了强烈对比,精确地定义了其坚固的三维形态,展现了工业设计般的美感。
395
+
396
+ 右下角的宫格中,兔子模型覆盖着一层柔软浓密的灰色毛绒。根根分明的绒毛清晰可见,创造出一种温暖、可触摸的质地。光线照射在绒毛的末梢,形成柔和的光晕效果,而毛绒内部的阴影则显得深邃而柔软,展现了高度写实的毛发渲染效果。
397
+
398
+ 整个四宫格由来自多个方向的、柔和均匀的影棚灯光照亮,确保了每种材质的细节和特性都得到清晰的展现,没有任何刺眼的阴影或过曝的高光。这张图像以一种高度写实的3D渲染风格呈现,完美地诠释了产品可视化的精髓
399
+ </details>
400
+ </td>
401
+ <td>
402
+ <img src="./assets/pg_imgs/image6.png" width=100%><details>
403
+ <summary>Show prompt</summary>
404
+ 由一个两行两列的网格构成,共包含四个独立的场景,每个场景都以不同的艺术风格描绘了一个小男孩(小明)一天中的不同活动。
405
+
406
+ 左上角的第一个场景,以超写实摄影风格呈现。画面主体是一个大约8岁的东亚小男孩,他穿着整洁的小学制服——一件白色短袖衬衫和蓝色短裤,脖子上系着红领巾。他背着一个蓝色的双肩书包,正走在去上学的路上。他位于画面的前景偏右侧,面带微笑,步伐轻快。场景设定���清晨,柔和的阳光从左上方照射下来,在人行道上投下清晰而柔和的影子。背景是绿树成荫的街道和模糊可见的学校铁艺大门,营造出宁静的早晨氛围。这张图片的细节表现极为丰富,可以清晰地看到男孩头发的光泽、衣服的褶皱纹理以及书包的帆布材质,完全展现了专业摄影的质感。
407
+
408
+ 右上角的第二个场景,采用日式赛璐璐动漫风格绘制。画面中,小男孩坐在家中的木质餐桌旁吃午饭。他的形象被动漫化,拥有大而明亮的眼睛和简洁的五官线条。他身穿一件简单的黄色T恤,正用筷子夹起碗里的米饭。桌上摆放着一碗汤和两盘家常菜。背景是一个温馨的室内环境,一扇明亮的窗户透进正午的阳光,窗外是蓝天白云。整个画面色彩鲜艳、饱和度高,角色轮廓线清晰明确,阴影部分采用平涂的色块处理,是典型的赛璐璐动漫风格。
409
+
410
+ 左下角的第三个场景,以细腻的铅笔素描风格呈现。画面描绘了下午在操场上踢足球的小男孩。整个图像由不同灰度的石墨色调构成,没有其他颜色。小男孩身穿运动短袖和短裤,身体呈前倾姿态,右脚正要踢向一个足球,动作充满动感。背景是空旷的操场和远处的球门,用简练的线条和排线勾勒。艺术家通过交叉排线和涂抹技巧来表现光影和体积感,足球上的阴影、人物身上的肌肉线条以及地面粗糙的质感都通过铅笔的笔触得到了充分的展现。这张铅笔画突出了素描的光影关系和线条美感。
411
+
412
+ 右下角的第四个场景,以文森特·梵高的后印象派油画风格进行诠释。画面描绘了夜晚时分,小男孩独自在河边钓鱼的景象。他坐在一块岩石上,手持一根简易的钓鱼竿,身影在深蓝色的夜幕下显得很渺小。整个画面的视觉焦点是天空和水面,天空布满了旋转、卷曲的星云,星星和月亮被描绘成巨大、发光的光团,使用了厚涂的油画颜料(Impasto),笔触粗犷而充满能量。深蓝、亮黄和白色的颜料在画布上相互交织,形成强烈的视觉冲击力。水面倒映着天空中扭曲的光影,整个场景充满了梵高作品中特有的强烈情感和动荡不安的美感。这幅画作是对梵高风格的深度致敬。
413
+ </details>
414
+ </td>
415
+ </tr>
416
+ <tr>
417
+ <td>
418
+ <img src="./assets/pg_imgs/image7.png" width=100%><details>
419
+ <summary>Show prompt</summary>
420
+ 以平视视角,呈现了一幅关于如何用素描技法绘制鹦鹉的九宫格教学图。整体构图规整,九个大小一致的方形画框以三行三列的形式均匀分布在浅灰色背景上,清晰地展示了从基本形状到最终成品的全过程。
421
+
422
+ 第一行从左至右展示了绘画的初始步骤。左上角的第一个画框中,用简洁的铅笔线条勾勒出鹦鹉的基本几何形态:一个圆形代表头部,一个稍大的椭圆形代表身体。右上角有一个小号的无衬线字体数字“1”。中间的第二个画框中,在基础形态上添加了三角形的鸟喙轮廓和一条长长的弧线作为尾巴的雏形,头部和身体的连接处线条变得更加流畅;右上角标有数字“2”。右侧的第三个画框中,进一步精确了鹦鹉的整体轮廓,勾勒出头部顶端的羽冠和清晰的眼部圆形轮廓;右上角标有数字“3”。
423
+
424
+ 第二行专注于结构与细节的添加,描绘了绘画的中期阶段。左侧的第四个画框里,鹦鹉的身体上添加了翅膀的基本形状,同时在身体下方画出了一根作为栖木的横向树枝,鹦鹉的爪子初步搭在树枝上;右上角标有数字“4”。中间的第五个画框中,开始细化翅膀和尾部的羽毛分组,用短促的线条表现出层次感,并清晰地画出爪子紧握树枝的细节;右上角标有数字“5”。右侧的第六个画框里,开始为鹦鹉添加初步的阴影,使用交叉排线的素描技法在腹部、翅膀下方和颈部制造出体积感;右上角标有数字“6”。
425
+
426
+ 第三行则展示了最终的润色与完成阶段。左下角的第七个画框中,素描的排线更加密集,阴影层次更加丰富,羽毛的纹理细节被仔细刻画出来,眼珠也添加了高光点缀,显得炯炯有神;右上角标有数字“7”。中间的第八个画框里,描绘的重点转移到栖木上,增加了树枝的纹理和节疤细节,同时整体调整了鹦鹉身上的光影关系,使立体感更为突出;右上角标有数字“8”。右下角的第九个画框是最终完成图,所有线条都经过了精炼,光影对比强烈,鹦鹉的羽毛质感、木质栖木的粗糙感都表现得淋漓尽致,呈现出一幅完整且细节丰富的素描作品;右上角标有数字“9”。
427
+
428
+ 整个画面的光线均匀而明亮,没有任何特定的光源方向,确保了每个教学步骤的视觉清晰度。整体呈现出一种清晰、有条理的数字插画教程风格。
429
+ </details>
430
+ </td>
431
+ <td>
432
+ <img src="./assets/pg_imgs/image8.png" width=100%><details>
433
+ <summary>Show prompt</summary>
434
+ 一张现代平面设计风格的海报占据了整个画面,构图简洁且中心突出。
435
+
436
+ 海报的主体是位于画面正中央的一只腾讯QQ企鹅。这只企鹅采用了圆润可爱的3D卡通渲染风格,身体主要为饱满的黑色,腹部为纯白色。它的眼睛大而圆,眼神好奇地直视前方。黄色的嘴巴小巧而立体,双脚同样为鲜明的黄色,稳稳地站立着。一条标志性的红色围巾整齐地系在它的脖子上,围巾的材质带有轻微的布料质感,末端自然下垂。企鹅的整体造型干净利落,边缘光滑,呈现出一种精致的数字插画质感。
437
+
438
+ 海报的背景是一种从上到下由浅蓝色平滑过渡到白色的柔和渐变,营造出一种开阔、明亮的空间感。在企鹅的身后,散布着一些淡淡的、模糊的圆形光斑和几道柔和的抽象光束,为这个简约的平面设计海报增添了微妙的深度和科技感。
439
+
440
+ 画面的底部区域是文字部分,排版居中对齐。上半部分是一行稍大的黑色黑体字,内容为“Hunyuan Image 3.0”。紧随其下的是一行字号略小的深灰色黑体字,内容为“原生多模态大模型”。两行文字清晰易读,与整体的现代平面设计风格保持一致。
441
+
442
+ 整体光线明亮、均匀,没有明显的阴影,突出了企鹅和文字信息,符合现代设计海报的视觉要求。这张图像呈现了现代、简洁的平面设计海报风格。
443
+ </details>
444
+ </td>
445
+ </tr>
446
+ </tbody>
447
+ </table>
448
+ </p>
449
+
450
+ ## 📊 Evaluation
451
+
452
+ * 🤖 **SSAE (Machine Evaluation)**
453
+ SSAE (Structured Semantic Alignment Evaluation) is an intelligent evaluation metric for image-text alignment based on advanced multimodal large language models (MLLMs). We extracted 3,500 key points across 12 categories, then used MLLMs to automatically evaluate and score each generated image against these key points based on its visual content. Mean Image Accuracy first averages the key-point scores within each image and then averages over images, while Global Accuracy averages the scores over all key points pooled across images.
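+ 
+ A minimal sketch of the two aggregations, assuming per-image lists of binary key-point scores (the scoring itself is done by MLLM judges; only the averaging is shown, with made-up numbers):
+ 
+ ```python
+ # scores[i] holds the key-point scores (1 = aligned, 0 = not) for generated image i,
+ # as judged by the MLLM evaluator. Illustrative data only.
+ scores = [[1, 1, 0], [1, 0, 0, 1], [1, 1, 1, 1, 0]]
+ 
+ # Mean Image Accuracy: average within each image first, then across images.
+ mean_image_acc = sum(sum(s) / len(s) for s in scores) / len(scores)
+ 
+ # Global Accuracy: pool all key points across images and average once.
+ flat = [x for s in scores for x in s]
+ global_acc = sum(flat) / len(flat)
+ 
+ print(round(mean_image_acc, 3), round(global_acc, 3))  # 0.656 0.667
+ ```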
454
+
455
+ <p align="center">
456
+ <img src="./assets/ssae_side_by_side_comparison.png" width=98% alt="Human Evaluation with Other Models">
457
+ </p>
458
+
459
+ <p align="center">
460
+ <img src="./assets/ssae_side_by_side_heatmap.png" width=98% alt="Human Evaluation with Other Models">
461
+ </p>
462
+
463
+
464
+ * 👥 **GSB (Human Evaluation)**
465
+
466
+ We adopted the GSB (Good/Same/Bad) evaluation method commonly used to assess the relative performance between two models from an overall image perception perspective. In total, we utilized 1,000 text prompts, generating an equal number of image samples for all compared models in a single run. For a fair comparison, we conducted inference only once for each prompt, avoiding any cherry-picking of results. When comparing with the baseline methods, we maintained the default settings for all selected models. The evaluation was performed by more than 100 professional evaluators.
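+ 
+ For reference, a common way to condense GSB votes into a single relative score is the win rate (good - bad) / total; this is a conventional formula, not necessarily the exact aggregation behind the chart below:
+ 
+ ```python
+ # Conventional GSB summary with illustrative vote counts.
+ good, same, bad = 420, 360, 220
+ total = good + same + bad
+ win_rate = (good - bad) / total
+ print(f"G {good/total:.1%}  S {same/total:.1%}  B {bad/total:.1%}  win rate {win_rate:+.1%}")
+ ```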
467
+
468
+ <p align="center">
469
+ <img src="./assets/gsb.png" width=98% alt="Human Evaluation with Other Models">
470
+ </p>
471
+
472
+
473
+ ## 📚 Citation
474
+
475
+ If you find HunyuanImage-3.0 useful in your research, please cite our work:
476
+
477
+ ```bibtex
478
+ @article{cao2025hunyuanimage,
479
+ title={HunyuanImage 3.0 Technical Report},
480
+ author={Cao, Siyu and Chen, Hangting and Chen, Peng and Cheng, Yiji and Cui, Yutao and Deng, Xinchi and Dong, Ying and Gong, Kipper and Gu, Tianpeng and Gu, Xiusen and others},
481
+ journal={arXiv preprint arXiv:2509.23951},
482
+ year={2025}
483
+ }
484
+ ```
485
+
486
+ ## 🙏 Acknowledgements
487
+
488
+ We extend our heartfelt gratitude to the following open-source projects and communities for their invaluable contributions:
489
+
490
+ * 🤗 [Transformers](https://github.com/huggingface/transformers) - State-of-the-art NLP library
491
+ * 🎨 [Diffusers](https://github.com/huggingface/diffusers) - Diffusion models library
492
+ * 🌐 [HuggingFace](https://huggingface.co/) - AI model hub and community
493
+ * ⚡ [FlashAttention](https://github.com/Dao-AILab/flash-attention) - Memory-efficient attention
494
+ * 🚀 [FlashInfer](https://github.com/flashinfer-ai/flashinfer) - Optimized inference engine
495
+
496
+ ## 🌟🚀 GitHub Star History
497
+
498
+ [![GitHub stars](https://img.shields.io/github/stars/Tencent-Hunyuan/HunyuanImage-3.0?style=social)](https://github.com/Tencent-Hunyuan/HunyuanImage-3.0)
499
+ [![GitHub forks](https://img.shields.io/github/forks/Tencent-Hunyuan/HunyuanImage-3.0?style=social)](https://github.com/Tencent-Hunyuan/HunyuanImage-3.0)
500
+
501
+
502
+ [![Star History Chart](https://api.star-history.com/svg?repos=Tencent-Hunyuan/HunyuanImage-3.0&type=Date)](https://www.star-history.com/#Tencent-Hunyuan/HunyuanImage-3.0&Date)
__init__.py ADDED
File without changes
assets/WECHAT.md ADDED
@@ -0,0 +1,6 @@
1
+ <div align="center">
2
+ <img src=wechat.png width="60%"/>
3
+
4
+ <p> 扫码关注混元图像系列工作,加入「 腾讯混元生图交流群 」 </p>
5
+ <p> Scan the QR code to join the "Tencent Hunyuan Image Generation Discussion Group" </p>
6
+ </div>
assets/banner.png ADDED

Git LFS Details

  • SHA256: 53bef578e373fc53c8c16d26a1011f85ce8d4f46aeac1222be7263a09a3d8c7f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
assets/banner_all.jpg ADDED

Git LFS Details

  • SHA256: 667e956a3c27f6722eceacebe907b56b6669cc35b706a4957893b7ae88b7fbc0
  • Pointer size: 133 Bytes
  • Size of remote file: 15.3 MB
assets/framework.png ADDED

Git LFS Details

  • SHA256: f6c0e6751b4bf0f30daeb6a2b7cdb4dc3276bd61ea58294ba528fd807a17473f
  • Pointer size: 131 Bytes
  • Size of remote file: 248 kB
assets/gsb.png ADDED

Git LFS Details

  • SHA256: 8570c87bbb477e206f61d99c036a53d3671dc589829b335656931dcf01194536
  • Pointer size: 131 Bytes
  • Size of remote file: 191 kB
assets/logo.png ADDED

Git LFS Details

  • SHA256: f59d594e65aff85c3ac35ff02aa7e14cccfb88d4c6296948efe5cea9a3bfb690
  • Pointer size: 130 Bytes
  • Size of remote file: 95.1 kB
assets/pg_imgs/image1.png ADDED

Git LFS Details

  • SHA256: a385db722efc89cff4d5e4afdb82c96156f223907d70f0d3c7eb8b3e59edbccb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.64 MB
assets/pg_imgs/image2.png ADDED

Git LFS Details

  • SHA256: 84a7d37d3ff8452c32ecb79f98692ad38f8f190dc201922a924ae2fda4515e12
  • Pointer size: 132 Bytes
  • Size of remote file: 1.7 MB
assets/pg_imgs/image3.png ADDED

Git LFS Details

  • SHA256: 913376a1ad5d10bc1549f9f28fc25a0c9f94a119b99434618bacacc7429996fa
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
assets/pg_imgs/image4.png ADDED

Git LFS Details

  • SHA256: 71dcfd968f4c76ccec2ccc1806e9ce97babed56c73441d744f7264433bf9339a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.55 MB
assets/pg_imgs/image5.png ADDED

Git LFS Details

  • SHA256: b93338e2f81f9809f8a9f674e0fe3da7c03de4fc4d7aba1819acb878384abb3e
  • Pointer size: 132 Bytes
  • Size of remote file: 3.31 MB
assets/pg_imgs/image6.png ADDED

Git LFS Details

  • SHA256: 84e7c73dafea831bf1ceb8c3dd76c16238c9b4a31ac30e3f46c38d61005b5895
  • Pointer size: 132 Bytes
  • Size of remote file: 2.02 MB
assets/pg_imgs/image7.png ADDED

Git LFS Details

  • SHA256: 701f2449436f8d46f537100bdaa63569586e39448e30fffe2e5bc3e95e558daa
  • Pointer size: 132 Bytes
  • Size of remote file: 1.55 MB
assets/pg_imgs/image8.png ADDED

Git LFS Details

  • SHA256: b80b97174d4f98030eda02d5dbaac2e294d814f086d957d47957482eb7b70251
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
assets/robot.png ADDED

Git LFS Details

  • SHA256: 2a5b09f264c3752199536e92ca57836119604a79e3d08471d2818d2d576dd79b
  • Pointer size: 130 Bytes
  • Size of remote file: 16.4 kB
assets/ssae_side_by_side_comparison.png ADDED

Git LFS Details

  • SHA256: 665dce959769e799a14fa7d176d4a676feb33193c6575db21da97458732488fc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.52 MB
assets/ssae_side_by_side_heatmap.png ADDED

Git LFS Details

  • SHA256: 00e2342afb5cabaf20b9b415587fb2986456f8c7c8cb96dd6ecc68455457045e
  • Pointer size: 131 Bytes
  • Size of remote file: 639 kB
assets/user.png ADDED

Git LFS Details

  • SHA256: 75543c163927df138a1c3d2958322e151ba259fc52fcd91bebb4cea92fc1af89
  • Pointer size: 130 Bytes
  • Size of remote file: 13.5 kB
assets/wechat.png ADDED

Git LFS Details

  • SHA256: 7bb1d5e06408b09ca6764ddd0a70fe9acfd045295c4144c99235f74336ea0169
  • Pointer size: 130 Bytes
  • Size of remote file: 30.9 kB
autoencoder_kl_3d.py ADDED
@@ -0,0 +1,793 @@
1
+ # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+ # ==============================================================================
13
+
14
+ from dataclasses import dataclass
15
+ from typing import Tuple, Optional
16
+ import math
17
+ import random
18
+ import numpy as np
19
+ from einops import rearrange
20
+ import torch
21
+ from torch import Tensor, nn
22
+ import torch.nn.functional as F
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+ from diffusers.utils.torch_utils import randn_tensor
28
+ from diffusers.utils import BaseOutput
29
+
30
+
31
+ class DiagonalGaussianDistribution(object):
32
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
33
+ if parameters.ndim == 3:
34
+ dim = 2 # (B, L, C)
35
+ elif parameters.ndim == 5 or parameters.ndim == 4:
36
+ dim = 1 # (B, C, T, H, W) / (B, C, H, W)
37
+ else:
38
+ raise NotImplementedError
39
+ self.parameters = parameters
40
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
41
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
42
+ self.deterministic = deterministic
43
+ self.std = torch.exp(0.5 * self.logvar)
44
+ self.var = torch.exp(self.logvar)
45
+ if self.deterministic:
46
+ self.var = self.std = torch.zeros_like(
47
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
48
+ )
49
+
50
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
51
+ # make sure sample is on the same device as the parameters and has same dtype
52
+ sample = randn_tensor(
53
+ self.mean.shape,
54
+ generator=generator,
55
+ device=self.parameters.device,
56
+ dtype=self.parameters.dtype,
57
+ )
58
+ x = self.mean + self.std * sample
59
+ return x
60
+
61
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
62
+ if self.deterministic:
63
+ return torch.Tensor([0.0])
64
+ else:
65
+ reduce_dim = list(range(1, self.mean.ndim))
66
+ if other is None:
67
+ return 0.5 * torch.sum(
68
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
69
+ dim=reduce_dim,
70
+ )
71
+ else:
72
+ return 0.5 * torch.sum(
73
+ torch.pow(self.mean - other.mean, 2) / other.var +
74
+ self.var / other.var -
75
+ 1.0 -
76
+ self.logvar +
77
+ other.logvar,
78
+ dim=reduce_dim,
79
+ )
80
+
81
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
82
+ if self.deterministic:
83
+ return torch.Tensor([0.0])
84
+ logtwopi = np.log(2.0 * np.pi)
85
+ return 0.5 * torch.sum(
86
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
87
+ dim=dims,
88
+ )
89
+
90
+ def mode(self) -> torch.Tensor:
91
+ return self.mean
92
+
93
+
94
+ @dataclass
95
+ class DecoderOutput(BaseOutput):
96
+ sample: torch.FloatTensor
97
+ posterior: Optional[DiagonalGaussianDistribution] = None
98
+
99
+
100
+ def swish(x: Tensor) -> Tensor:
101
+ return x * torch.sigmoid(x)
102
+
103
+
104
+ def forward_with_checkpointing(module, *inputs, use_checkpointing=False):
105
+ def create_custom_forward(module):
106
+ def custom_forward(*inputs):
107
+ return module(*inputs)
108
+ return custom_forward
109
+
110
+ if use_checkpointing:
111
+ return torch.utils.checkpoint.checkpoint(create_custom_forward(module), *inputs, use_reentrant=False)
112
+ else:
113
+ return module(*inputs)
114
+
115
+
116
+ class Conv3d(nn.Conv3d):
117
+ """
118
+ Perform Conv3d on temporal chunks; numerical differences from nn.Conv3d stay within 1e-5.
119
+ Only symmetric padding is supported.
120
+ """
121
+
122
+ def forward(self, input):
123
+ B, C, T, H, W = input.shape
124
+ memory_count = (C * T * H * W) * 2 / 1024**3
125
+ if memory_count > 2:
126
+ n_split = math.ceil(memory_count / 2)
127
+ assert n_split >= 2
128
+ chunks = torch.chunk(input, chunks=n_split, dim=-3)
129
+ padded_chunks = []
130
+ for i in range(len(chunks)):
131
+ if self.padding[0] > 0:
132
+ padded_chunk = F.pad(
133
+ chunks[i],
134
+ (0, 0, 0, 0, self.padding[0], self.padding[0]),
135
+ mode="constant" if self.padding_mode == "zeros" else self.padding_mode,
136
+ value=0,
137
+ )
138
+ if i > 0:
139
+ padded_chunk[:, :, :self.padding[0]] = chunks[i - 1][:, :, -self.padding[0]:]
140
+ if i < len(chunks) - 1:
141
+ padded_chunk[:, :, -self.padding[0]:] = chunks[i + 1][:, :, :self.padding[0]]
142
+ else:
143
+ padded_chunk = chunks[i]
144
+ padded_chunks.append(padded_chunk)
145
+ padding_bak = self.padding
146
+ self.padding = (0, self.padding[1], self.padding[2])
147
+ outputs = []
148
+ for i in range(len(padded_chunks)):
149
+ outputs.append(super().forward(padded_chunks[i]))
150
+ self.padding = padding_bak
151
+ return torch.cat(outputs, dim=-3)
152
+ else:
153
+ return super().forward(input)
154
+
155
+
156
+ class AttnBlock(nn.Module):
157
+ """ Attention with torch sdpa implementation. """
158
+ def __init__(self, in_channels: int):
159
+ super().__init__()
160
+ self.in_channels = in_channels
161
+
162
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
163
+
164
+ self.q = Conv3d(in_channels, in_channels, kernel_size=1)
165
+ self.k = Conv3d(in_channels, in_channels, kernel_size=1)
166
+ self.v = Conv3d(in_channels, in_channels, kernel_size=1)
167
+ self.proj_out = Conv3d(in_channels, in_channels, kernel_size=1)
168
+
169
+ def attention(self, h_: Tensor) -> Tensor:
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ b, c, f, h, w = q.shape
176
+ q = rearrange(q, "b c f h w -> b 1 (f h w) c").contiguous()
177
+ k = rearrange(k, "b c f h w -> b 1 (f h w) c").contiguous()
178
+ v = rearrange(v, "b c f h w -> b 1 (f h w) c").contiguous()
179
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
180
+
181
+ return rearrange(h_, "b 1 (f h w) c -> b c f h w", f=f, h=h, w=w, c=c, b=b)
182
+
183
+ def forward(self, x: Tensor) -> Tensor:
184
+ return x + self.proj_out(self.attention(x))
185
+
186
+
187
+ class ResnetBlock(nn.Module):
188
+ def __init__(self, in_channels: int, out_channels: int):
189
+ super().__init__()
190
+ self.in_channels = in_channels
191
+ out_channels = in_channels if out_channels is None else out_channels
192
+ self.out_channels = out_channels
193
+
194
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
195
+ self.conv1 = Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
196
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
197
+ self.conv2 = Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
198
+ if self.in_channels != self.out_channels:
199
+ self.nin_shortcut = Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
200
+
201
+ def forward(self, x):
202
+ h = x
203
+ h = self.norm1(h)
204
+ h = swish(h)
205
+ h = self.conv1(h)
206
+
207
+ h = self.norm2(h)
208
+ h = swish(h)
209
+ h = self.conv2(h)
210
+
211
+ if self.in_channels != self.out_channels:
212
+ x = self.nin_shortcut(x)
213
+ return x + h
214
+
215
+
216
+ class Downsample(nn.Module):
217
+ def __init__(self, in_channels: int, add_temporal_downsample: bool = True):
218
+ super().__init__()
219
+ self.add_temporal_downsample = add_temporal_downsample
220
+ stride = (2, 2, 2) if add_temporal_downsample else (1, 2, 2) # THW
221
+ # no asymmetric padding in torch conv, must do it ourselves
222
+ self.conv = Conv3d(in_channels, in_channels, kernel_size=3, stride=stride, padding=0)
223
+
224
+ def forward(self, x: Tensor):
225
+ spatial_pad = (0, 1, 0, 1, 0, 0) # WHT
226
+ x = nn.functional.pad(x, spatial_pad, mode="constant", value=0)
227
+
228
+ temporal_pad = (0, 0, 0, 0, 0, 1) if self.add_temporal_downsample else (0, 0, 0, 0, 1, 1)
229
+ x = nn.functional.pad(x, temporal_pad, mode="replicate")
230
+
231
+ x = self.conv(x)
232
+ return x
233
+
234
+
235
+ class DownsampleDCAE(nn.Module):
236
+ def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
237
+ super().__init__()
238
+ factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
239
+ assert out_channels % factor == 0
240
+ self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
241
+
242
+ self.add_temporal_downsample = add_temporal_downsample
243
+ self.group_size = factor * in_channels // out_channels
244
+
245
+ def forward(self, x: Tensor):
246
+ r1 = 2 if self.add_temporal_downsample else 1
247
+ h = self.conv(x)
248
+ h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
249
+ shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
250
+
251
+ B, C, T, H, W = shortcut.shape
252
+ shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
253
+ return h + shortcut
254
+
255
+
256
+ class Upsample(nn.Module):
257
+ def __init__(self, in_channels: int, add_temporal_upsample: bool = True):
258
+ super().__init__()
259
+ self.add_temporal_upsample = add_temporal_upsample
260
+ self.scale_factor = (2, 2, 2) if add_temporal_upsample else (1, 2, 2) # THW
261
+ self.conv = Conv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
262
+
263
+ def forward(self, x: Tensor):
264
+ x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
265
+ x = self.conv(x)
266
+ return x
267
+
268
+
269
+ class UpsampleDCAE(nn.Module):
270
+ def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
271
+ super().__init__()
272
+ factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
273
+ self.conv = Conv3d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
274
+
275
+ self.add_temporal_upsample = add_temporal_upsample
276
+ self.repeats = factor * out_channels // in_channels
277
+
278
+ def forward(self, x: Tensor):
279
+ r1 = 2 if self.add_temporal_upsample else 1
280
+ h = self.conv(x)
281
+ h = rearrange(h, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
282
+ shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
283
+ shortcut = rearrange(shortcut, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
284
+ return h + shortcut
285
+
286
+
287
+ class Encoder(nn.Module):
288
+ """
289
+ The encoder network of AutoencoderKLConv3D.
290
+ """
291
+ def __init__(
292
+ self,
293
+ in_channels: int,
294
+ z_channels: int,
295
+ block_out_channels: Tuple[int, ...],
296
+ num_res_blocks: int,
297
+ ffactor_spatial: int,
298
+ ffactor_temporal: int,
299
+ downsample_match_channel: bool = True,
300
+ ):
301
+ super().__init__()
302
+ assert block_out_channels[-1] % (2 * z_channels) == 0
303
+
304
+ self.z_channels = z_channels
305
+ self.block_out_channels = block_out_channels
306
+ self.num_res_blocks = num_res_blocks
307
+
308
+ # downsampling
309
+ self.conv_in = Conv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
310
+
311
+ self.down = nn.ModuleList()
312
+ block_in = block_out_channels[0]
313
+ for i_level, ch in enumerate(block_out_channels):
314
+ block = nn.ModuleList()
315
+ block_out = ch
316
+ for _ in range(self.num_res_blocks):
317
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
318
+ block_in = block_out
319
+ down = nn.Module()
320
+ down.block = block
321
+
322
+ add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial))
323
+ add_temporal_downsample = (add_spatial_downsample and
324
+ bool(i_level >= np.log2(ffactor_spatial // ffactor_temporal)))
325
+ if add_spatial_downsample or add_temporal_downsample:
326
+ assert i_level < len(block_out_channels) - 1
327
+ block_out = block_out_channels[i_level + 1] if downsample_match_channel else block_in
328
+ down.downsample = DownsampleDCAE(block_in, block_out, add_temporal_downsample)
329
+ block_in = block_out
330
+ self.down.append(down)
331
+
332
+ # middle
333
+ self.mid = nn.Module()
334
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
335
+ self.mid.attn_1 = AttnBlock(block_in)
336
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
337
+
338
+ # end
339
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
340
+ self.conv_out = Conv3d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
341
+
342
+ self.gradient_checkpointing = False
343
+
344
+ def forward(self, x: Tensor) -> Tensor:
345
+ use_checkpointing = bool(self.training and self.gradient_checkpointing)
346
+
347
+ # downsampling
348
+ h = self.conv_in(x)
349
+ for i_level in range(len(self.block_out_channels)):
350
+ for i_block in range(self.num_res_blocks):
351
+ h = forward_with_checkpointing(
352
+ self.down[i_level].block[i_block], h, use_checkpointing=use_checkpointing)
353
+ if hasattr(self.down[i_level], "downsample"):
354
+ h = forward_with_checkpointing(self.down[i_level].downsample, h, use_checkpointing=use_checkpointing)
355
+
356
+ # middle
357
+ h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
358
+ h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
359
+ h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
360
+
361
+ # end
362
+ group_size = self.block_out_channels[-1] // (2 * self.z_channels)
363
+ shortcut = rearrange(h, "b (c r) f h w -> b c r f h w", r=group_size).mean(dim=2)
364
+ h = self.norm_out(h)
365
+ h = swish(h)
366
+ h = self.conv_out(h)
367
+ h += shortcut
368
+ return h
369
+
370
+
371
+ class Decoder(nn.Module):
372
+ """
373
+ The decoder network of AutoencoderKLConv3D.
374
+ """
375
+ def __init__(
376
+ self,
377
+ z_channels: int,
378
+ out_channels: int,
379
+ block_out_channels: Tuple[int, ...],
380
+ num_res_blocks: int,
381
+ ffactor_spatial: int,
382
+ ffactor_temporal: int,
383
+ upsample_match_channel: bool = True,
384
+ ):
385
+ super().__init__()
386
+ assert block_out_channels[0] % z_channels == 0
387
+
388
+ self.z_channels = z_channels
389
+ self.block_out_channels = block_out_channels
390
+ self.num_res_blocks = num_res_blocks
391
+
392
+ # z to block_in
393
+ block_in = block_out_channels[0]
394
+ self.conv_in = Conv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
395
+
396
+ # middle
397
+ self.mid = nn.Module()
398
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
399
+ self.mid.attn_1 = AttnBlock(block_in)
400
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
401
+
402
+ # upsampling
403
+ self.up = nn.ModuleList()
404
+ for i_level, ch in enumerate(block_out_channels):
405
+ block = nn.ModuleList()
406
+ block_out = ch
407
+ for _ in range(self.num_res_blocks + 1):
408
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
409
+ block_in = block_out
410
+ up = nn.Module()
411
+ up.block = block
412
+
413
+ add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial))
414
+ add_temporal_upsample = bool(i_level < np.log2(ffactor_temporal))
415
+ if add_spatial_upsample or add_temporal_upsample:
416
+ assert i_level < len(block_out_channels) - 1
417
+ block_out = block_out_channels[i_level + 1] if upsample_match_channel else block_in
418
+ up.upsample = UpsampleDCAE(block_in, block_out, add_temporal_upsample)
419
+ block_in = block_out
420
+ self.up.append(up)
421
+
422
+ # end
423
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
424
+ self.conv_out = Conv3d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
425
+
426
+ self.gradient_checkpointing = False
427
+
428
+ def forward(self, z: Tensor) -> Tensor:
429
+ use_checkpointing = bool(self.training and self.gradient_checkpointing)
430
+
431
+ # z to block_in
432
+ repeats = self.block_out_channels[0] // (self.z_channels)
433
+ h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
434
+
435
+ # middle
436
+ h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
437
+ h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
438
+ h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
439
+
440
+ # upsampling
441
+ for i_level in range(len(self.block_out_channels)):
442
+ for i_block in range(self.num_res_blocks + 1):
443
+ h = forward_with_checkpointing(self.up[i_level].block[i_block], h, use_checkpointing=use_checkpointing)
444
+ if hasattr(self.up[i_level], "upsample"):
445
+ h = forward_with_checkpointing(self.up[i_level].upsample, h, use_checkpointing=use_checkpointing)
446
+
447
+ # end
448
+ h = self.norm_out(h)
449
+ h = swish(h)
450
+ h = self.conv_out(h)
451
+ return h
452
+
453
+
454
+ class AutoencoderKLConv3D(ModelMixin, ConfigMixin):
455
+ """
456
+ Autoencoder model with KL-regularized latent space based on 3D convolutions.
457
+ """
458
+ _supports_gradient_checkpointing = True
459
+
460
+ @register_to_config
461
+ def __init__(
462
+ self,
463
+ in_channels: int,
464
+ out_channels: int,
465
+ latent_channels: int,
466
+ block_out_channels: Tuple[int, ...],
467
+ layers_per_block: int,
468
+ ffactor_spatial: int,
469
+ ffactor_temporal: int,
470
+ sample_size: int,
471
+ sample_tsize: int,
472
+ scaling_factor: float = None,
473
+ shift_factor: Optional[float] = None,
474
+ downsample_match_channel: bool = True,
475
+ upsample_match_channel: bool = True,
476
+ only_encoder: bool = False, # only build encoder for saving memory
477
+ only_decoder: bool = False, # only build decoder for saving memory
478
+ ):
479
+ super().__init__()
480
+ self.ffactor_spatial = ffactor_spatial
481
+ self.ffactor_temporal = ffactor_temporal
482
+ self.scaling_factor = scaling_factor
483
+ self.shift_factor = shift_factor
484
+
485
+ # build model
486
+ if not only_decoder:
487
+ self.encoder = Encoder(
488
+ in_channels=in_channels,
489
+ z_channels=latent_channels,
490
+ block_out_channels=block_out_channels,
491
+ num_res_blocks=layers_per_block,
492
+ ffactor_spatial=ffactor_spatial,
493
+ ffactor_temporal=ffactor_temporal,
494
+ downsample_match_channel=downsample_match_channel,
495
+ )
496
+ if not only_encoder:
497
+ self.decoder = Decoder(
498
+ z_channels=latent_channels,
499
+ out_channels=out_channels,
500
+ block_out_channels=list(reversed(block_out_channels)),
501
+ num_res_blocks=layers_per_block,
502
+ ffactor_spatial=ffactor_spatial,
503
+ ffactor_temporal=ffactor_temporal,
504
+ upsample_match_channel=upsample_match_channel,
505
+ )
506
+
507
+ # slicing and tiling related
508
+ self.use_slicing = False
509
+ self.slicing_bsz = 1
510
+ self.use_spatial_tiling = False
511
+ self.use_temporal_tiling = False
512
+ self.use_tiling_during_training = False
513
+
514
+ # only relevant if vae tiling is enabled
515
+ self.tile_sample_min_size = sample_size
516
+ self.tile_latent_min_size = sample_size // ffactor_spatial
517
+ self.tile_sample_min_tsize = sample_tsize
518
+ self.tile_latent_min_tsize = sample_tsize // ffactor_temporal
519
+ self.tile_overlap_factor = 0.25
520
+
521
+ # use torch.compile for faster encode speed
522
+ self.use_compile = False
523
+
524
+ def _set_gradient_checkpointing(self, module, value=False):
525
+ if isinstance(module, (Encoder, Decoder)):
526
+ module.gradient_checkpointing = value
527
+
528
+ def enable_tiling_during_training(self, use_tiling: bool = True):
529
+ self.use_tiling_during_training = use_tiling
530
+
531
+ def disable_tiling_during_training(self):
532
+ self.enable_tiling_during_training(False)
533
+
534
+ def enable_temporal_tiling(self, use_tiling: bool = True):
535
+ self.use_temporal_tiling = use_tiling
536
+
537
+ def disable_temporal_tiling(self):
538
+ self.enable_temporal_tiling(False)
539
+
540
+ def enable_spatial_tiling(self, use_tiling: bool = True):
541
+ self.use_spatial_tiling = use_tiling
542
+
543
+ def disable_spatial_tiling(self):
544
+ self.enable_spatial_tiling(False)
545
+
546
+ def enable_tiling(self, use_tiling: bool = True):
547
+ self.enable_spatial_tiling(use_tiling)
548
+
549
+ def disable_tiling(self):
550
+ self.disable_spatial_tiling()
551
+
552
+ def enable_slicing(self):
553
+ self.use_slicing = True
554
+
555
+ def disable_slicing(self):
556
+ self.use_slicing = False
557
+
558
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
559
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
560
+ for x in range(blend_extent):
561
+ b[:, :, :, :, x] = \
562
+ a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
563
+ return b
564
+
565
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
566
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
567
+ for y in range(blend_extent):
568
+ b[:, :, :, y, :] = \
569
+ a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
570
+ return b
571
+
572
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
573
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
574
+ for x in range(blend_extent):
575
+ b[:, :, x, :, :] = \
576
+ a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
577
+ return b
578
+
579
+ def spatial_tiled_encode(self, x: torch.Tensor):
580
+ """ Spatial tiling for frames. """
581
+ B, C, T, H, W = x.shape
582
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
583
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) # 8 * 0.25 = 2
584
+ row_limit = self.tile_latent_min_size - blend_extent # 8 - 2 = 6
585
+
586
+ rows = []
587
+ for i in range(0, H, overlap_size):
588
+ row = []
589
+ for j in range(0, W, overlap_size):
590
+ tile = x[:, :, :, i: i + self.tile_sample_min_size, j: j + self.tile_sample_min_size]
591
+ tile = self.encoder(tile)
592
+ row.append(tile)
593
+ rows.append(row)
594
+ result_rows = []
595
+ for i, row in enumerate(rows):
596
+ result_row = []
597
+ for j, tile in enumerate(row):
598
+ if i > 0:
599
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
600
+ if j > 0:
601
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
602
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
603
+ result_rows.append(torch.cat(result_row, dim=-1))
604
+ moments = torch.cat(result_rows, dim=-2)
605
+ return moments
606
+
607
+ def temporal_tiled_encode(self, x: torch.Tensor):
608
+ """ Temporal tiling for frames. """
609
+ B, C, T, H, W = x.shape
610
+ overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor)) # 64 * (1 - 0.25) = 48
611
+ blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor) # 8 * 0.25 = 2
612
+ t_limit = self.tile_latent_min_tsize - blend_extent # 8 - 2 = 6
613
+
614
+ row = []
615
+ for i in range(0, T, overlap_size):
616
+ tile = x[:, :, i: i + self.tile_sample_min_tsize, :, :]
617
+ if self.use_spatial_tiling and (
618
+ tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
619
+ tile = self.spatial_tiled_encode(tile)
620
+ else:
621
+ tile = self.encoder(tile)
622
+ row.append(tile)
623
+ result_row = []
624
+ for i, tile in enumerate(row):
625
+ if i > 0:
626
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
627
+ result_row.append(tile[:, :, :t_limit, :, :])
628
+ moments = torch.cat(result_row, dim=-3)
629
+ return moments
630
+
631
+ def spatial_tiled_decode(self, z: torch.Tensor):
632
+ """ Spatial tiling for latents. """
633
+ B, C, T, H, W = z.shape
634
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
635
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) # 256 * 0.25 = 64
636
+ row_limit = self.tile_sample_min_size - blend_extent # 256 - 64 = 192
637
+
638
+ rows = []
639
+ for i in range(0, H, overlap_size):
640
+ row = []
641
+ for j in range(0, W, overlap_size):
642
+ tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size]
643
+ decoded = self.decoder(tile)
644
+ row.append(decoded)
645
+ rows.append(row)
646
+
647
+ result_rows = []
648
+ for i, row in enumerate(rows):
649
+ result_row = []
650
+ for j, tile in enumerate(row):
651
+ if i > 0:
652
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
653
+ if j > 0:
654
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
655
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
656
+ result_rows.append(torch.cat(result_row, dim=-1))
657
+ dec = torch.cat(result_rows, dim=-2)
658
+ return dec
659
+
660
+ def temporal_tiled_decode(self, z: torch.Tensor):
661
+ """ Temporal tiling for latents. """
662
+ B, C, T, H, W = z.shape
663
+ overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
664
+ blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) # 64 * 0.25 = 16
665
+ t_limit = self.tile_sample_min_tsize - blend_extent # 64 - 16 = 48
666
+ assert 0 < overlap_size < self.tile_latent_min_tsize
667
+
668
+ row = []
669
+ for i in range(0, T, overlap_size):
670
+ tile = z[:, :, i: i + self.tile_latent_min_tsize, :, :]
671
+ if self.use_spatial_tiling and (
672
+ tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
673
+ decoded = self.spatial_tiled_decode(tile)
674
+ else:
675
+ decoded = self.decoder(tile)
676
+ row.append(decoded)
677
+
678
+ result_row = []
679
+ for i, tile in enumerate(row):
680
+ if i > 0:
681
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
682
+ result_row.append(tile[:, :, :t_limit, :, :])
683
+ dec = torch.cat(result_row, dim=-3)
684
+ return dec
685
+
686
+ def encode(self, x: Tensor, return_dict: bool = True):
687
+ """
688
+ Encodes the input by passing through the encoder network.
689
+ Supports slicing and tiling for memory efficiency.
690
+ """
691
+ def _encode(x):
692
+ if self.use_temporal_tiling and x.shape[-3] > self.tile_sample_min_tsize:
693
+ return self.temporal_tiled_encode(x)
694
+ if self.use_spatial_tiling and (
695
+ x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
696
+ return self.spatial_tiled_encode(x)
697
+
698
+ if self.use_compile:
699
+ @torch.compile
700
+ def encoder(x):
701
+ return self.encoder(x)
702
+ return encoder(x)
703
+ return self.encoder(x)
704
+
705
+ if len(x.shape) != 5: # (B, C, T, H, W)
706
+ x = x[:, :, None]
707
+ assert len(x.shape) == 5 # (B, C, T, H, W)
708
+ if x.shape[2] == 1:
709
+ x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
710
+ else:
711
+ assert x.shape[2] != self.ffactor_temporal and x.shape[2] % self.ffactor_temporal == 0
712
+
713
+ if self.use_slicing and x.shape[0] > 1:
714
+ if self.slicing_bsz == 1:
715
+ encoded_slices = [_encode(x_slice) for x_slice in x.split(1)]
716
+ else:
717
+ sections = [self.slicing_bsz] * (x.shape[0] // self.slicing_bsz)
718
+ if x.shape[0] % self.slicing_bsz != 0:
719
+ sections.append(x.shape[0] % self.slicing_bsz)
720
+ encoded_slices = [_encode(x_slice) for x_slice in x.split(sections)]
721
+ h = torch.cat(encoded_slices)
722
+ else:
723
+ h = _encode(x)
724
+ posterior = DiagonalGaussianDistribution(h)
725
+
726
+ if not return_dict:
727
+ return (posterior,)
728
+
729
+ return AutoencoderKLOutput(latent_dist=posterior)
730
+
731
+ def decode(self, z: Tensor, return_dict: bool = True, generator=None):
732
+ """
733
+ Decodes the input by passing through the decoder network.
734
+ Supports slicing and tiling for memory efficiency.
735
+ """
736
+ def _decode(z):
737
+ if self.use_temporal_tiling and z.shape[-3] > self.tile_latent_min_tsize:
738
+ return self.temporal_tiled_decode(z)
739
+ if self.use_spatial_tiling and (
740
+ z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
741
+ return self.spatial_tiled_decode(z)
742
+ return self.decoder(z)
743
+
744
+ if self.use_slicing and z.shape[0] > 1:
745
+ decoded_slices = [_decode(z_slice) for z_slice in z.split(1)]
746
+ decoded = torch.cat(decoded_slices)
747
+ else:
748
+ decoded = _decode(z)
749
+
750
+ if z.shape[-3] == 1:
751
+ decoded = decoded[:, :, -1:]
752
+
753
+ if not return_dict:
754
+ return (decoded,)
755
+
756
+ return DecoderOutput(sample=decoded)
757
+
758
+ def forward(
759
+ self,
760
+ sample: torch.Tensor,
761
+ sample_posterior: bool = False,
762
+ return_posterior: bool = True,
763
+ return_dict: bool = True
764
+ ):
765
+ posterior = self.encode(sample).latent_dist
766
+ z = posterior.sample() if sample_posterior else posterior.mode()
767
+ dec = self.decode(z).sample
768
+ return DecoderOutput(sample=dec, posterior=posterior) if return_dict else (dec, posterior)
769
+
770
+ def random_reset_tiling(self, x: torch.Tensor):
771
+ if x.shape[-3] == 1:
772
+ self.disable_spatial_tiling()
773
+ self.disable_temporal_tiling()
774
+ return
775
+
776
+ # Use fixed shape here
777
+ min_sample_size = int(1 / self.tile_overlap_factor) * self.ffactor_spatial
778
+ min_sample_tsize = int(1 / self.tile_overlap_factor) * self.ffactor_temporal
779
+ sample_size = random.choice([None, 1 * min_sample_size, 2 * min_sample_size, 3 * min_sample_size])
780
+ if sample_size is None:
781
+ self.disable_spatial_tiling()
782
+ else:
783
+ self.tile_sample_min_size = sample_size
784
+ self.tile_latent_min_size = sample_size // self.ffactor_spatial
785
+ self.enable_spatial_tiling()
786
+
787
+ sample_tsize = random.choice([None, 1 * min_sample_tsize, 2 * min_sample_tsize, 3 * min_sample_tsize])
788
+ if sample_tsize is None:
789
+ self.disable_temporal_tiling()
790
+ else:
791
+ self.tile_sample_min_tsize = sample_tsize
792
+ self.tile_latent_min_tsize = sample_tsize // self.ffactor_temporal
793
+ self.enable_temporal_tiling()
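
The autoencoder above depends only on `torch`, `einops`, and `diffusers`, so it can be exercised on its own. The following is a minimal sketch, not part of the upstream code: it assumes the file is importable as `autoencoder_kl_3d`, builds the model with the hyperparameters from the `vae` block of `config.json` below (random weights, CPU), and round-trips a single image through `encode`/`decode` to show the shapes involved.

```python
# Minimal sketch (assumptions: autoencoder_kl_3d.py is on PYTHONPATH; weights are
# randomly initialized, so only the tensor shapes are meaningful). Hyperparameters
# mirror the "vae" block of config.json in this commit.
import torch
from autoencoder_kl_3d import AutoencoderKLConv3D

vae = AutoencoderKLConv3D(
    in_channels=3,
    out_channels=3,
    latent_channels=32,
    block_out_channels=(128, 256, 512, 1024, 1024),
    layers_per_block=2,
    ffactor_spatial=16,     # 16x spatial compression
    ffactor_temporal=4,     # 4x temporal compression
    sample_size=384,
    sample_tsize=96,
    scaling_factor=0.562679178327931,
).eval()

with torch.no_grad():
    image = torch.randn(1, 3, 256, 256)        # (B, C, H, W); encode() adds the T axis itself
    posterior = vae.encode(image).latent_dist  # DiagonalGaussianDistribution
    z = posterior.mode()                       # (1, 32, 1, 16, 16): H/16, W/16, single latent frame
    recon = vae.decode(z).sample               # (1, 3, 1, 256, 256)

print(z.shape, recon.shape)
```

Tiled and sliced encoding (`enable_spatial_tiling`, `enable_slicing`) only change how the same computation is chunked, so the shapes above are unaffected.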
config.json ADDED
@@ -0,0 +1,273 @@
1
+ {
2
+ "add_classification_head": false,
3
+ "anyres_pooling_size": 2,
4
+ "anyres_vit_max_image_size": null,
5
+ "anyres_vit_two_views": false,
6
+ "architectures": [
7
+ "HunyuanImage3ForCausalMM"
8
+ ],
9
+ "attention_bias": false,
10
+ "attention_dropout": 0.0,
11
+ "attention_head_dim": 128,
12
+ "auto_map": {
13
+ "AutoConfig": "configuration_hunyuan.HunyuanImage3Config",
14
+ "AutoModel": "hunyuan.HunyuanImage3Model",
15
+ "AutoModelForCausalLM": "hunyuan.HunyuanImage3ForCausalMM"
16
+ },
17
+ "bos_token_id": 127958,
18
+ "cla_share_factor": 2,
19
+ "class_num": 0,
20
+ "dense_list": [
21
+ 4096,
22
+ 0
23
+ ],
24
+ "eod_token_id": 3,
25
+ "eos_token_id": 127957,
26
+ "group_limited_greedy": false,
27
+ "hidden_act": "silu",
28
+ "hidden_size": 4096,
29
+ "im_end_id": 128001,
30
+ "im_newline_id": 11,
31
+ "im_start_id": 128000,
32
+ "image_token_id": 128006,
33
+ "initializer_range": 0.02,
34
+ "intermediate_size": 3072,
35
+ "kv_lora_rank": null,
36
+ "mask_init_id": 12,
37
+ "max_position_embeddings": 12800,
38
+ "mlp_bias": false,
39
+ "model_type": "hunyuan_image_3_moe",
40
+ "moe_drop_tokens": false,
41
+ "moe_intermediate_size": [
42
+ 3072,
43
+ 3072,
44
+ 3072,
45
+ 3072,
46
+ 3072,
47
+ 3072,
48
+ 3072,
49
+ 3072,
50
+ 3072,
51
+ 3072,
52
+ 3072,
53
+ 3072,
54
+ 3072,
55
+ 3072,
56
+ 3072,
57
+ 3072,
58
+ 3072,
59
+ 3072,
60
+ 3072,
61
+ 3072,
62
+ 3072,
63
+ 3072,
64
+ 3072,
65
+ 3072,
66
+ 3072,
67
+ 3072,
68
+ 3072,
69
+ 3072,
70
+ 3072,
71
+ 3072,
72
+ 3072,
73
+ 3072
74
+ ],
75
+ "moe_layer_num_skipped": 0,
76
+ "moe_random_routing_dropped_token": false,
77
+ "moe_topk": [
78
+ 8,
79
+ 8,
80
+ 8,
81
+ 8,
82
+ 8,
83
+ 8,
84
+ 8,
85
+ 8,
86
+ 8,
87
+ 8,
88
+ 8,
89
+ 8,
90
+ 8,
91
+ 8,
92
+ 8,
93
+ 8,
94
+ 8,
95
+ 8,
96
+ 8,
97
+ 8,
98
+ 8,
99
+ 8,
100
+ 8,
101
+ 8,
102
+ 8,
103
+ 8,
104
+ 8,
105
+ 8,
106
+ 8,
107
+ 8,
108
+ 8,
109
+ 8
110
+ ],
111
+ "n_group": false,
112
+ "norm_topk_prob": true,
113
+ "norm_type": "rms",
114
+ "num_attention_heads": 32,
115
+ "num_experts": 64,
116
+ "num_hidden_layers": 32,
117
+ "num_key_value_heads": 8,
118
+ "num_media_embeds": 257,
119
+ "num_shared_expert": [
120
+ 1,
121
+ 1,
122
+ 1,
123
+ 1,
124
+ 1,
125
+ 1,
126
+ 1,
127
+ 1,
128
+ 1,
129
+ 1,
130
+ 1,
131
+ 1,
132
+ 1,
133
+ 1,
134
+ 1,
135
+ 1,
136
+ 1,
137
+ 1,
138
+ 1,
139
+ 1,
140
+ 1,
141
+ 1,
142
+ 1,
143
+ 1,
144
+ 1,
145
+ 1,
146
+ 1,
147
+ 1,
148
+ 1,
149
+ 1,
150
+ 1,
151
+ 1
152
+ ],
153
+ "pad_id": 128009,
154
+ "pad_token_id": 128009,
155
+ "pool_type": "last",
156
+ "position_embedding_xdrope": false,
157
+ "pretraining_tp": 1,
158
+ "q_lora_rank": null,
159
+ "qk_nope_head_dim": null,
160
+ "qk_rope_head_dim": null,
161
+ "rms_norm_eps": 1e-05,
162
+ "rope_scaling": {
163
+ "alpha": 1.0,
164
+ "beta_fast": 32,
165
+ "beta_slow": 1,
166
+ "factor": 1.0,
167
+ "mscale": 1.0,
168
+ "mscale_all_dim": 1.0,
169
+ "type": "custom"
170
+ },
171
+ "rope_theta": 10000.0,
172
+ "routed_scaling_factor": false,
173
+ "skip_cls_token": false,
174
+ "text_end_id": 7,
175
+ "text_start_id": 6,
176
+ "tie_word_embeddings": false,
177
+ "topk_group": false,
178
+ "torch_dtype": "bfloat16",
179
+ "transformers_version": "4.50.0",
180
+ "use_cache": true,
181
+ "use_cla": false,
182
+ "use_mixed_mlp_moe": true,
183
+ "use_mla": false,
184
+ "use_qk_norm": true,
185
+ "use_rotary_pos_emb": true,
186
+ "v_head_dim": null,
187
+ "video_end_id": 10,
188
+ "video_start_id": 9,
189
+ "vit_add_patchemb_bias": false,
190
+ "vit_input_resolution": 224,
191
+ "vit_mapping_type": "resampler",
192
+ "vit_norm_type": "fused",
193
+ "vit_patch": 1,
194
+ "vit_path": null,
195
+ "vit_remove_prenorm": false,
196
+ "vit_token": 64,
197
+ "vit_type": null,
198
+ "vit_used_rms_norm": false,
199
+ "vocab_size": 133120,
200
+ "xdrope_section": null,
201
+ "head_dim": 128,
202
+ "vae_downsample_factor": [
203
+ 16,
204
+ 16
205
+ ],
206
+ "vae": {
207
+ "_class_name": "AutoencoderKLConv3D",
208
+ "block_out_channels": [
209
+ 128,
210
+ 256,
211
+ 512,
212
+ 1024,
213
+ 1024
214
+ ],
215
+ "in_channels": 3,
216
+ "out_channels": 3,
217
+ "latent_channels": 32,
218
+ "layers_per_block": 2,
219
+ "ffactor_spatial": 16,
220
+ "ffactor_temporal": 4,
221
+ "sample_size": 384,
222
+ "sample_tsize": 96,
223
+ "downsample_match_channel": true,
224
+ "upsample_match_channel": true,
225
+ "scaling_factor": 0.562679178327931
226
+ },
227
+ "vit": {
228
+ "_attn_implementation": "sdpa",
229
+ "attention_dropout": 0.0,
230
+ "hidden_act": "gelu_pytorch_tanh",
231
+ "hidden_size": 1152,
232
+ "intermediate_size": 4304,
233
+ "layer_norm_eps": 1e-06,
234
+ "num_attention_heads": 16,
235
+ "num_channels": 3,
236
+ "num_hidden_layers": 27,
237
+ "num_patches": 256,
238
+ "patch_size": 16,
239
+ "torch_dtype": "float32",
240
+ "output_attentions": false,
241
+ "output_hidden_states": false,
242
+ "use_return_dict": true
243
+ },
244
+ "vit_processor": {
245
+ "do_convert_rgb": null,
246
+ "do_normalize": true,
247
+ "do_rescale": true,
248
+ "do_resize": true,
249
+ "image_mean": [
250
+ 0.5,
251
+ 0.5,
252
+ 0.5
253
+ ],
254
+ "image_processor_type": "Siglip2ImageProcessorFast",
255
+ "image_std": [
256
+ 0.5,
257
+ 0.5,
258
+ 0.5
259
+ ],
260
+ "max_num_patches": 1024,
261
+ "patch_size": 16,
262
+ "processor_class": "Siglip2Processor",
263
+ "resample": 2,
264
+ "rescale_factor": 0.00392156862745098
265
+ },
266
+ "vit_aligner": {
267
+ "projector_type": "mlp_gelu",
268
+ "input_dim": 1152,
269
+ "n_embed": 4096,
270
+ "depth": 2,
271
+ "torch_dtype": "float32"
272
+ }
273
+ }
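
The top-level fields of `config.json` describe the MoE language-model backbone, while the nested `vae`, `vit`, `vit_processor`, and `vit_aligner` blocks configure the image components. A few relations have to hold between them; the sketch below (an illustration, not part of the repository) reads the raw JSON and checks them.

```python
# Minimal sketch (assumption: run from a local checkout that contains config.json).
import json

with open("config.json") as f:
    cfg = json.load(f)

# Per-head width is hidden_size split across the attention heads: 4096 / 32 = 128.
assert cfg["attention_head_dim"] == cfg["hidden_size"] // cfg["num_attention_heads"]

# The MoE lists carry one entry per transformer layer (32 layers).
assert len(cfg["moe_topk"]) == cfg["num_hidden_layers"]
assert len(cfg["moe_intermediate_size"]) == cfg["num_hidden_layers"]
assert len(cfg["num_shared_expert"]) == cfg["num_hidden_layers"]

# The backbone's notion of VAE compression must match the VAE itself: (16, 16).
assert list(cfg["vae_downsample_factor"]) == [cfg["vae"]["ffactor_spatial"]] * 2

print("config.json is internally consistent")
```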
configuration_hunyuan.py ADDED
@@ -0,0 +1,285 @@
1
+ # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+ # ==============================================================================
13
+
14
+ from transformers.configuration_utils import PretrainedConfig
15
+ from transformers.utils import logging
16
+ from typing import List, Union
17
+
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class HunyuanImage3Config(PretrainedConfig):
23
+ r"""
24
+ This is the configuration class to store the configuration of a [`HunyuanImage3Model`]. It is used to instantiate
25
+ a Hunyuan model according to the specified arguments, defining the model architecture. Instantiating a
26
+ configuration with the defaults will yield a similar configuration to that of the Hunyuan-7B.
27
+
28
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
29
+ documentation from [`PretrainedConfig`] for more information.
30
+
31
+
32
+ Args:
33
+ vocab_size (`int`, *optional*, defaults to 32000):
34
+ Vocabulary size of the Hunyuan Image 3 model. Defines the number of different tokens that can be
35
+ represented by the `input_ids` passed when calling [`HunyuanImage3Model`]
36
+ hidden_size (`int`, *optional*, defaults to 4096):
37
+ Dimension of the hidden representations.
38
+ intermediate_size (`int`, *optional*, defaults to 11008):
39
+ Dimension of the MLP representations or shared MLP representations.
40
+ moe_intermediate_size (`int` or `List`, *optional*, defaults to 11008):
41
+ Dimension of the MLP representations in MoE. Use a list if you want a different size per layer.
42
+ num_hidden_layers (`int`, *optional*, defaults to 32):
43
+ Number of hidden layers in the Transformer decoder.
44
+ num_attention_heads (`int`, *optional*, defaults to 32):
45
+ Number of attention heads for each attention layer in the Transformer decoder.
46
+ num_key_value_heads (`int`, *optional*):
47
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
48
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
49
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
50
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
51
+ by meanpooling all the original heads within that group. For more details checkout [this
52
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
53
+ `num_attention_heads`.
54
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
55
+ The non-linear activation function (function or string) in the decoder.
56
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
57
+ The maximum sequence length that this model might ever be used with.
58
+ initializer_range (`float`, *optional*, defaults to 0.02):
59
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
60
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
61
+ The epsilon used by the rms normalization layers.
62
+ use_cache (`bool`, *optional*, defaults to `True`):
63
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
64
+ relevant if `config.is_decoder=True`.
65
+ pad_token_id (`int`, *optional*):
66
+ Padding token id.
67
+ bos_token_id (`int`, *optional*, defaults to 1):
68
+ Beginning of stream token id.
69
+ eos_token_id (`int`, *optional*, defaults to 2):
70
+ End of stream token id.
71
+ pretraining_tp (`int`, *optional*, defaults to 1):
72
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
73
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
74
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
75
+ issue](https://github.com/pytorch/pytorch/issues/76232).
76
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
77
+ Whether to tie weight embeddings
78
+ rope_theta (`float`, *optional*, defaults to 10000.0):
79
+ The base period of the RoPE embeddings.
80
+ rope_scaling (`Dict`, *optional*):
81
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
82
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
83
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
84
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
85
+ these scaling strategies behave:
86
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
87
+ experimental feature, subject to breaking API changes in future versions.
88
+ attention_bias (`bool`, *optional*, defaults to `False`):
89
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
90
+ attention_dropout (`float`, *optional*, defaults to 0.0):
91
+ The dropout ratio for the attention probabilities.
92
+ use_qk_norm (`bool`, *optional*, defaults to `False`):
93
+ Whether to apply normalization to the query and key in attention (QK norm)
94
+ use_cla (`bool`, *optional*, defaults to `False`):
95
+ Whether to use CLA in attention
96
+ cla_share_factor (`int`, *optional*, defaults to 1):
97
+ The share factor of CLA
98
+ num_experts (`int` or `List`, *optional*, defaults to 1):
99
+ The number of experts for moe. If it is a list, it will be used as the number of experts for each layer.
100
+ num_shared_expert (`int` or `List`, *optional*, defaults to 1):
101
+ The number of shared experts for moe. If it is a list, it will be used as the number of shared experts
102
+ for each layer.
103
+ moe_topk (`int` or `List`, *optional*, defaults to 1):
104
+ The topk value for moe. If it is a list, it will be used as the topk value for each layer.
105
+ capacity_factor (`float` or `List`, *optional*, defaults to 1.0, currently unused):
106
+ The capacity factor for moe. If it is a list, it will be used as the capacity factor for each layer.
107
+ moe_layer_num_skipped (`int`, *optional*, defaults to 0):
108
+ First moe_layer_num_skipped layers do not use MoE.
109
+ """
110
+
111
+ model_type = "Hunyuan"
112
+ keys_to_ignore_at_inference = ["past_key_values"]
113
+
114
+ def __init__(
115
+ self,
116
+ vocab_size=290943,
117
+ hidden_size=4096,
118
+ intermediate_size: int=11008,
119
+ moe_intermediate_size: Union[int, List]=None,
120
+ num_hidden_layers=32,
121
+ num_attention_heads=32,
122
+ num_key_value_heads=None,
123
+ attention_head_dim=None,
124
+ hidden_act="silu",
125
+ max_position_embeddings=2048,
126
+ initializer_range=0.02,
127
+ rms_norm_eps=1e-5,
128
+ use_cache=True,
129
+ pad_token_id=0,
130
+ bos_token_id=1,
131
+ eos_token_id=2,
132
+ eod_token_id=3,
133
+ im_start_id=4,
134
+ im_end_id=5,
135
+ text_start_id=6,
136
+ text_end_id=7,
137
+ image_token_id=8,
138
+ video_start_id=9,
139
+ video_end_id=10,
140
+ im_newline_id=11,
141
+ mask_init_id=12,
142
+ pretraining_tp=1,
143
+ tie_word_embeddings=False,
144
+ rope_theta=10000.0,
145
+ rope_scaling=None,
146
+ attention_bias=False,
147
+ mlp_bias=False,
148
+ attention_dropout=0.0,
149
+ use_qk_norm=False,
150
+ use_rotary_pos_emb=True,
151
+ use_cla=False,
152
+ cla_share_factor=1,
153
+ norm_type="hf_rms",
154
+ num_experts: Union[int, List] = 1,
155
+ use_mixed_mlp_moe=False,
156
+ num_shared_expert: Union[int, List] = 1,
157
+ moe_topk: Union[int, List] = 1,
158
+ capacity_factor: float = 1.0,
159
+ moe_drop_tokens=False,
160
+ moe_random_routing_dropped_token=False,
161
+ use_mla=False,
162
+ kv_lora_rank=512,
163
+ q_lora_rank=1536,
164
+ qk_rope_head_dim=64,
165
+ v_head_dim=128,
166
+ qk_nope_head_dim=128,
167
+ moe_layer_num_skipped=0,
168
+ norm_topk_prob=True,
169
+ routed_scaling_factor=1.0,
170
+ group_limited_greedy=False,
171
+ n_group=None,
172
+ topk_group=None,
173
+ add_classification_head=False,
174
+ class_num=0,
175
+ pool_type="last",
176
+ pad_id=-1,
177
+ # Added
178
+ moe_impl="eager",
179
+ vae_downsample_factor=(16, 16), # (h, w)
180
+ img_proj_type="unet",
181
+ patch_size=1,
182
+ patch_embed_hidden_dim=1024,
183
+ image_base_size=1024,
184
+ vae=None,
185
+ vit=None,
186
+ vit_processor=None,
187
+ vit_aligner=None,
188
+ **kwargs,
189
+ ):
190
+ self.vocab_size = vocab_size
191
+ self.max_position_embeddings = max_position_embeddings
192
+ self.hidden_size = hidden_size
193
+ self.intermediate_size = intermediate_size
194
+ self.moe_intermediate_size = moe_intermediate_size
195
+ self.num_hidden_layers = num_hidden_layers
196
+ self.num_attention_heads = num_attention_heads
197
+ self.moe_impl = moe_impl
198
+ self.num_experts = num_experts
199
+ self.use_mixed_mlp_moe = use_mixed_mlp_moe
200
+ self.num_shared_expert = num_shared_expert
201
+ self.moe_topk = moe_topk
202
+ self.capacity_factor = capacity_factor
203
+ self.moe_drop_tokens = moe_drop_tokens
204
+ self.moe_random_routing_dropped_token = moe_random_routing_dropped_token
205
+
206
+ if attention_head_dim is not None:
207
+ self.attention_head_dim = attention_head_dim
208
+ else:
209
+ self.attention_head_dim = self.hidden_size // num_attention_heads
210
+
211
+ # for backward compatibility
212
+ if num_key_value_heads is None:
213
+ num_key_value_heads = num_attention_heads
214
+
215
+ self.num_key_value_heads = num_key_value_heads
216
+ self.hidden_act = hidden_act
217
+ self.initializer_range = initializer_range
218
+ self.rms_norm_eps = rms_norm_eps
219
+ self.pretraining_tp = pretraining_tp
220
+ self.use_cache = use_cache
221
+ self.rope_theta = rope_theta
222
+ self.rope_scaling = rope_scaling
223
+ self.attention_bias = attention_bias
224
+ self.mlp_bias = mlp_bias
225
+ self.attention_dropout = attention_dropout
226
+ self.use_qk_norm = use_qk_norm
227
+ self.use_rotary_pos_emb = use_rotary_pos_emb
228
+ self.use_cla = use_cla
229
+ self.cla_share_factor = cla_share_factor
230
+ self.norm_type = norm_type
231
+ # MLA args
232
+ self.use_mla = use_mla
233
+ self.kv_lora_rank = kv_lora_rank
234
+ self.q_lora_rank = q_lora_rank
235
+ self.qk_rope_head_dim = qk_rope_head_dim
236
+ self.qk_nope_head_dim = qk_nope_head_dim
237
+ self.v_head_dim = v_head_dim
238
+
239
+ # DeepSeek related args
240
+ self.moe_layer_num_skipped = moe_layer_num_skipped
241
+ self.norm_topk_prob = norm_topk_prob
242
+ self.routed_scaling_factor = routed_scaling_factor
243
+ self.group_limited_greedy = group_limited_greedy
244
+ self.n_group = n_group
245
+ self.topk_group = topk_group
246
+ self.add_classification_head = add_classification_head
247
+ self.class_num = class_num
248
+ self.pool_type = pool_type
249
+ self.pad_id = pad_id
250
+
251
+ if self.class_num is not None:
252
+ self.dense_list = [self.hidden_size, self.class_num]
253
+
254
+ # ViT args
255
+ self.vit = vit
256
+ self.vit_processor = vit_processor
257
+ self.vit_aligner = vit_aligner
258
+
259
+ # Image Gen args
260
+ self.vae = vae
261
+ self.vae_downsample_factor = vae_downsample_factor
262
+ self.img_proj_type = img_proj_type
263
+ self.patch_size = patch_size
264
+ self.patch_embed_hidden_dim = patch_embed_hidden_dim
265
+ self.image_base_size = image_base_size
266
+
267
+ # token id
268
+ self.eod_token_id = eod_token_id
269
+ self.im_start_id = im_start_id
270
+ self.im_end_id = im_end_id
271
+ self.text_start_id = text_start_id
272
+ self.text_end_id = text_end_id
273
+ self.image_token_id = image_token_id
274
+ self.video_start_id = video_start_id
275
+ self.video_end_id = video_end_id
276
+ self.im_newline_id = im_newline_id
277
+ self.mask_init_id = mask_init_id
278
+
279
+ super().__init__(
280
+ pad_token_id=pad_token_id,
281
+ bos_token_id=bos_token_id,
282
+ eos_token_id=eos_token_id,
283
+ tie_word_embeddings=tie_word_embeddings,
284
+ **kwargs,
285
+ )
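
As a quick check of how the defaults in this class interact, the sketch below (illustrative only; it assumes `configuration_hunyuan.py` is importable from the working directory) builds a config with a handful of the values used in `config.json` and shows the derived per-head dimension and grouped-query setting.

```python
# Minimal sketch (assumption: configuration_hunyuan.py is on PYTHONPATH).
from configuration_hunyuan import HunyuanImage3Config

config = HunyuanImage3Config(
    vocab_size=133120,
    hidden_size=4096,
    num_attention_heads=32,
    num_key_value_heads=8,       # 8 KV heads for 32 query heads -> grouped-query attention
    num_hidden_layers=32,
    num_experts=64,
    moe_topk=[8] * 32,           # route each token to 8 of 64 experts, per layer
    num_shared_expert=[1] * 32,
    use_mixed_mlp_moe=True,
    use_qk_norm=True,
    max_position_embeddings=12800,
)

print(config.attention_head_dim)   # 128, derived as hidden_size // num_attention_heads
print(config.num_key_value_heads)  # 8
```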
generation_config.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "disable_compile": true,
3
+ "eos_token_id": [
4
+ 127957
5
+ ],
6
+ "pad_token_id": 128009,
7
+ "do_sample": true,
8
+ "top_k": 1024,
9
+ "top_p": 0.95,
10
+ "temperature": 0.6,
11
+ "max_length": 12800,
12
+ "sequence_template": "pretrain",
13
+ "diff_infer_steps": 50,
14
+ "diff_guidance_scale": 5.0,
15
+ "flow_shift": 3.0,
16
+ "use_system_prompt": "None",
17
+ "drop_think": false,
18
+ "bot_task": "image",
19
+ "transformers_version": "4.50.0"
20
+ }
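
`generation_config.json` mixes standard text-sampling fields (`do_sample`, `top_k`, `top_p`, `temperature`) with image-generation knobs (`diff_infer_steps`, `diff_guidance_scale`, `flow_shift`, `bot_task`) that are consumed by the custom modeling code in `hunyuan.py`, whose diff is not rendered on this page. A hedged sketch of loading and overriding it with the standard `transformers` API:

```python
# Minimal sketch (assumptions: run from a local checkout of this repo; transformers'
# GenerationConfig keeps non-standard keys such as diff_infer_steps as plain
# attributes -- how they are interpreted is defined by hunyuan.py, not shown here).
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained(".")   # picks up ./generation_config.json

print(gen_cfg.top_p, gen_cfg.temperature)  # 0.95 0.6
print(gen_cfg.diff_infer_steps)            # 50 (custom field)

gen_cfg.diff_guidance_scale = 7.5          # stronger guidance for image generation
gen_cfg.diff_infer_steps = 100             # more denoising steps
```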
hunyuan.py ADDED
The diff for this file is too large to render. See raw diff
 
hunyuan_image_3_pipeline.py ADDED
@@ -0,0 +1,879 @@
1
+ # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+ # ==============================================================================
13
+ #
14
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
15
+ #
16
+ # Licensed under the Apache License, Version 2.0 (the "License");
17
+ # you may not use this file except in compliance with the License.
18
+ # You may obtain a copy of the License at
19
+ #
20
+ # http://www.apache.org/licenses/LICENSE-2.0
21
+ #
22
+ # Unless required by applicable law or agreed to in writing, software
23
+ # distributed under the License is distributed on an "AS IS" BASIS,
24
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25
+ # See the License for the specific language governing permissions and
26
+ # limitations under the License.
27
+ # ==============================================================================================
28
+
29
+ import inspect
30
+ import math
31
+ from dataclasses import dataclass
32
+ from typing import Any, Callable, Dict, List
33
+ from typing import Optional, Tuple, Union
34
+
35
+ import numpy as np
36
+ import torch
37
+ from PIL import Image
38
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
39
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
40
+ from diffusers.image_processor import VaeImageProcessor
41
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
42
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
43
+ from diffusers.utils import BaseOutput, logging
44
+ from diffusers.utils.torch_utils import randn_tensor
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+
49
+ def retrieve_timesteps(
50
+ scheduler,
51
+ num_inference_steps: Optional[int] = None,
52
+ device: Optional[Union[str, torch.device]] = None,
53
+ timesteps: Optional[List[int]] = None,
54
+ sigmas: Optional[List[float]] = None,
55
+ **kwargs,
56
+ ):
57
+ """
58
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
59
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
60
+
61
+ Args:
62
+ scheduler (`SchedulerMixin`):
63
+ The scheduler to get timesteps from.
64
+ num_inference_steps (`int`):
65
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
66
+ must be `None`.
67
+ device (`str` or `torch.device`, *optional*):
68
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
69
+ timesteps (`List[int]`, *optional*):
70
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
71
+ `num_inference_steps` and `sigmas` must be `None`.
72
+ sigmas (`List[float]`, *optional*):
73
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
74
+ `num_inference_steps` and `timesteps` must be `None`.
75
+
76
+ Returns:
77
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
78
+ second element is the number of inference steps.
79
+ """
80
+ if timesteps is not None and sigmas is not None:
81
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
82
+ if timesteps is not None:
83
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
84
+ if not accepts_timesteps:
85
+ raise ValueError(
86
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
87
+ f" timestep schedules. Please check whether you are using the correct scheduler."
88
+ )
89
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
90
+ timesteps = scheduler.timesteps
91
+ num_inference_steps = len(timesteps)
92
+ elif sigmas is not None:
93
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
94
+ if not accept_sigmas:
95
+ raise ValueError(
96
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
97
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
98
+ )
99
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
100
+ timesteps = scheduler.timesteps
101
+ num_inference_steps = len(timesteps)
102
+ else:
103
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
104
+ timesteps = scheduler.timesteps
105
+ return timesteps, num_inference_steps
106
+
107
+
108
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
109
+ r"""
110
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
111
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
112
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
113
+
114
+ Args:
115
+ noise_cfg (`torch.Tensor`):
116
+ The predicted noise tensor for the guided diffusion process.
117
+ noise_pred_text (`torch.Tensor`):
118
+ The predicted noise tensor for the text-guided diffusion process.
119
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
120
+ A rescale factor applied to the noise predictions.
121
+ Returns:
122
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
123
+ """
124
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
125
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
126
+ # rescale the results from guidance (fixes overexposure)
127
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
128
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
129
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
130
+ return noise_cfg
131
+
132
+
133
+ @dataclass
134
+ class HunyuanImage3Text2ImagePipelineOutput(BaseOutput):
135
+ samples: Union[List[Any], np.ndarray]
136
+
137
+
138
+ @dataclass
139
+ class FlowMatchDiscreteSchedulerOutput(BaseOutput):
140
+ """
141
+ Output class for the scheduler's `step` function output.
142
+
143
+ Args:
144
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
145
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
146
+ denoising loop.
147
+ """
148
+
149
+ prev_sample: torch.FloatTensor
150
+
151
+
152
+ class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
153
+ """
154
+ Flow-matching discrete scheduler supporting Euler, Heun (2nd-order), midpoint (2nd-order), and classic 4th-order Runge-Kutta solvers.
155
+
156
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
157
+ methods the library implements for all schedulers such as loading and saving.
158
+
159
+ Args:
160
+ num_train_timesteps (`int`, defaults to 1000):
161
+ The number of diffusion steps to train the model.
162
+ solver (`str`, defaults to `"euler"`):
164
+ The ODE solver used during sampling. One of `"euler"`, `"heun-2"`, `"midpoint-2"`, or
165
+ `"kutta-4"`.
165
+ shift (`float`, defaults to 1.0):
166
+ The shift value for the timestep schedule.
167
+ reverse (`bool`, defaults to `True`):
168
+ Whether to reverse the timestep schedule.
169
+ """
170
+
171
+ _compatibles = []
172
+ order = 1
173
+
174
+ @register_to_config
175
+ def __init__(
176
+ self,
177
+ num_train_timesteps: int = 1000,
178
+ shift: float = 1.0,
179
+ reverse: bool = True,
180
+ solver: str = "euler",
181
+ use_flux_shift: bool = False,
182
+ flux_base_shift: float = 0.5,
183
+ flux_max_shift: float = 1.15,
184
+ n_tokens: Optional[int] = None,
185
+ ):
186
+ sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
187
+
188
+ if not reverse:
189
+ sigmas = sigmas.flip(0)
190
+
191
+ self.sigmas = sigmas
192
+ # the value fed to model
193
+ self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
194
+ self.timesteps_full = (sigmas * num_train_timesteps).to(dtype=torch.float32)
195
+
196
+ self._step_index = None
197
+ self._begin_index = None
198
+
199
+ self.supported_solver = [
200
+ "euler",
201
+ "heun-2", "midpoint-2",
202
+ "kutta-4",
203
+ ]
204
+ if solver not in self.supported_solver:
205
+ raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}")
206
+
207
+ # empty dt and derivative (for heun)
208
+ self.derivative_1 = None
209
+ self.derivative_2 = None
210
+ self.derivative_3 = None
211
+ self.dt = None
212
+
213
+ @property
214
+ def step_index(self):
215
+ """
216
+ The index counter for the current timestep. It increases by 1 after each scheduler step.
217
+ """
218
+ return self._step_index
219
+
220
+ @property
221
+ def begin_index(self):
222
+ """
223
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
224
+ """
225
+ return self._begin_index
226
+
227
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
228
+ def set_begin_index(self, begin_index: int = 0):
229
+ """
230
+ Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
231
+
232
+ Args:
233
+ begin_index (`int`):
234
+ The begin index for the scheduler.
235
+ """
236
+ self._begin_index = begin_index
237
+
238
+ def _sigma_to_t(self, sigma):
239
+ return sigma * self.config.num_train_timesteps
240
+
241
+ @property
242
+ def state_in_first_order(self):
243
+ return self.derivative_1 is None
244
+
245
+ @property
246
+ def state_in_second_order(self):
247
+ return self.derivative_2 is None
248
+
249
+ @property
250
+ def state_in_third_order(self):
251
+ return self.derivative_3 is None
252
+
253
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None,
254
+ n_tokens: int = None):
255
+ """
256
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
257
+
258
+ Args:
259
+ num_inference_steps (`int`):
260
+ The number of diffusion steps used when generating samples with a pre-trained model.
261
+ device (`str` or `torch.device`, *optional*):
262
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
263
+ n_tokens (`int`, *optional*):
264
+ Number of tokens in the input sequence.
265
+ """
266
+ self.num_inference_steps = num_inference_steps
267
+
268
+ sigmas = torch.linspace(1, 0, num_inference_steps + 1)
269
+
270
+ # Apply timestep shift
271
+ if self.config.use_flux_shift:
272
+ assert isinstance(n_tokens, int), "n_tokens should be provided for flux shift"
273
+ mu = self.get_lin_function(y1=self.config.flux_base_shift, y2=self.config.flux_max_shift)(n_tokens)
274
+ sigmas = self.flux_time_shift(mu, 1.0, sigmas)
275
+ elif self.config.shift != 1.:
276
+ sigmas = self.sd3_time_shift(sigmas)
277
+
278
+ if not self.config.reverse:
279
+ sigmas = 1 - sigmas
280
+
281
+ self.sigmas = sigmas
282
+ self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
283
+ self.timesteps_full = (sigmas * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
284
+
285
+ # empty dt and derivative (for kutta)
286
+ self.derivative_1 = None
287
+ self.derivative_2 = None
288
+ self.derivative_3 = None
289
+ self.dt = None
290
+
291
+ # Reset step index
292
+ self._step_index = None
293
+
294
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
295
+ if schedule_timesteps is None:
296
+ schedule_timesteps = self.timesteps
297
+
298
+ indices = (schedule_timesteps == timestep).nonzero()
299
+
300
+ # The sigma index that is taken for the **very** first `step`
301
+ # is always the second index (or the last index if there is only 1)
302
+ # This way we can ensure we don't accidentally skip a sigma in
303
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
304
+ pos = 1 if len(indices) > 1 else 0
305
+
306
+ return indices[pos].item()
307
+
308
+ def _init_step_index(self, timestep):
309
+ if self.begin_index is None:
310
+ if isinstance(timestep, torch.Tensor):
311
+ timestep = timestep.to(self.timesteps.device)
312
+ self._step_index = self.index_for_timestep(timestep)
313
+ else:
314
+ self._step_index = self._begin_index
315
+
316
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
317
+ return sample
318
+
319
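    # The three helpers below implement the timestep-shift schedules used in `set_timesteps`:
    # - get_lin_function: linearly maps the image token count to a shift parameter `mu`
    #   (flux_base_shift at 256 tokens, flux_max_shift at 4096 tokens by default).
    # - flux_time_shift: resolution-dependent shift, t' = exp(mu) / (exp(mu) + (1/t - 1)**sigma).
    # - sd3_time_shift: static shift, t' = shift * t / (1 + (shift - 1) * t).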
+ @staticmethod
320
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
321
+ m = (y2 - y1) / (x2 - x1)
322
+ b = y1 - m * x1
323
+ return lambda x: m * x + b
324
+
325
+ @staticmethod
326
+ def flux_time_shift(mu: float, sigma: float, t: torch.Tensor):
327
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
328
+
329
+ def sd3_time_shift(self, t: torch.Tensor):
330
+ return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
331
+
332
+ def step(
333
+ self,
334
+ model_output: torch.FloatTensor,
335
+ timestep: Union[float, torch.FloatTensor],
336
+ sample: torch.FloatTensor,
337
+ pred_uncond: torch.FloatTensor = None,
338
+ generator: Optional[torch.Generator] = None,
339
+ n_tokens: Optional[int] = None,
340
+ return_dict: bool = True,
341
+ ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
342
+ """
343
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
344
+ process from the learned model outputs (most often the predicted noise).
345
+
346
+ Args:
347
+ model_output (`torch.FloatTensor`):
348
+ The direct output from learned diffusion model.
349
+ timestep (`float`):
350
+ The current discrete timestep in the diffusion chain.
351
+ sample (`torch.FloatTensor`):
352
+ A current instance of a sample created by the diffusion process.
353
+ generator (`torch.Generator`, *optional*):
354
+ A random number generator.
355
+ n_tokens (`int`, *optional*):
356
+ Number of tokens in the input sequence.
357
+ return_dict (`bool`):
358
+ Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or
359
+ tuple.
360
+
361
+ Returns:
362
+ [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
363
+ If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is
364
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
365
+ """
366
+
367
+ if (
368
+ isinstance(timestep, int)
369
+ or isinstance(timestep, torch.IntTensor)
370
+ or isinstance(timestep, torch.LongTensor)
371
+ ):
372
+ raise ValueError(
373
+ (
374
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
375
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
376
+ " one of the `scheduler.timesteps` as a timestep."
377
+ ),
378
+ )
379
+
380
+ if self.step_index is None:
381
+ self._init_step_index(timestep)
382
+
383
+ # Upcast to avoid precision issues when computing prev_sample
384
+ sample = sample.to(torch.float32)
385
+ model_output = model_output.to(torch.float32)
386
+ pred_uncond = pred_uncond.to(torch.float32) if pred_uncond is not None else None
387
+
388
+ # dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
389
+ sigma = self.sigmas[self.step_index]
390
+ sigma_next = self.sigmas[self.step_index + 1]
391
+
392
+ last_inner_step = True
393
+ if self.config.solver == "euler":
394
+ derivative, dt, sample, last_inner_step = self.first_order_method(model_output, sigma, sigma_next, sample)
395
+ elif self.config.solver in ["heun-2", "midpoint-2"]:
396
+ derivative, dt, sample, last_inner_step = self.second_order_method(model_output, sigma, sigma_next, sample)
397
+ elif self.config.solver == "kutta-4":
398
+ derivative, dt, sample, last_inner_step = self.fourth_order_method(model_output, sigma, sigma_next, sample)
399
+ else:
400
+ raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}")
401
+
402
+ prev_sample = sample + derivative * dt
403
+
404
+ # Cast sample back to model compatible dtype
405
+ # prev_sample = prev_sample.to(model_output.dtype)
406
+
407
+ # upon completion increase step index by one
408
+ if last_inner_step:
409
+ self._step_index += 1
410
+
411
+ if not return_dict:
412
+ return (prev_sample,)
413
+
414
+ return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
415
+
416
+ def first_order_method(self, model_output, sigma, sigma_next, sample):
417
+ derivative = model_output
418
+ dt = sigma_next - sigma
419
+ return derivative, dt, sample, True
420
+
421
+ def second_order_method(self, model_output, sigma, sigma_next, sample):
422
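        # Two-stage solver: the first call stores the derivative and the full dt, then steps by
        # dt (heun-2) or dt/2 (midpoint-2); the second call combines the stored derivative with
        # the new one (heun-2 averages the two, midpoint-2 keeps only the midpoint derivative)
        # and takes the full step from the stashed sample before resetting the state.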
+ if self.state_in_first_order:
423
+ # store for 2nd order step
424
+ self.derivative_1 = model_output
425
+ self.dt = sigma_next - sigma
426
+ self.sample = sample
427
+
428
+ derivative = model_output
429
+ if self.config.solver == 'heun-2':
430
+ dt = self.dt
431
+ elif self.config.solver == 'midpoint-2':
432
+ dt = self.dt / 2
433
+ else:
434
+ raise NotImplementedError(f"Solver {self.config.solver} not supported.")
435
+ last_inner_step = False
436
+
437
+ else:
438
+ if self.config.solver == 'heun-2':
439
+ derivative = 0.5 * (self.derivative_1 + model_output)
440
+ elif self.config.solver == 'midpoint-2':
441
+ derivative = model_output
442
+ else:
443
+ raise NotImplementedError(f"Solver {self.config.solver} not supported.")
444
+
445
+ # 3. take prev timestep & sample
446
+ dt = self.dt
447
+ sample = self.sample
448
+ last_inner_step = True
449
+
450
+ # free dt and derivative
451
+ # Note, this puts the scheduler in "first order mode"
452
+ self.derivative_1 = None
453
+ self.dt = None
454
+ self.sample = None
455
+
456
+ return derivative, dt, sample, last_inner_step
457
+
458
+ def fourth_order_method(self, model_output, sigma, sigma_next, sample):
459
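        # Classic RK4 staging: stages 1 and 2 step by dt/2 from the stashed sample, stage 3 steps
        # by the full dt, and the final stage combines the four derivatives with weights
        # 1/6, 1/3, 1/3, 1/6 before taking the full step from the original sample.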
+ if self.state_in_first_order:
460
+ self.derivative_1 = model_output
461
+ self.dt = sigma_next - sigma
462
+ self.sample = sample
463
+ derivative = model_output
464
+ dt = self.dt / 2
465
+ last_inner_step = False
466
+
467
+ elif self.state_in_second_order:
468
+ self.derivative_2 = model_output
469
+ derivative = model_output
470
+ dt = self.dt / 2
471
+ last_inner_step = False
472
+
473
+ elif self.state_in_third_order:
474
+ self.derivative_3 = model_output
475
+ derivative = model_output
476
+ dt = self.dt
477
+ last_inner_step = False
478
+
479
+ else:
480
+ derivative = (1/6 * self.derivative_1 + 1/3 * self.derivative_2 + 1/3 * self.derivative_3 +
481
+ 1/6 * model_output)
482
+
483
+ # 3. take prev timestep & sample
484
+ dt = self.dt
485
+ sample = self.sample
486
+ last_inner_step = True
487
+
488
+ # free dt and derivative
489
+ # Note, this puts the scheduler in "first order mode"
490
+ self.derivative_1 = None
491
+ self.derivative_2 = None
492
+ self.derivative_3 = None
493
+ self.dt = None
494
+ self.sample = None
495
+
496
+ return derivative, dt, sample, last_inner_step
497
+
498
+ def __len__(self):
499
+ return self.config.num_train_timesteps
500
+
501
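# [Editor's sketch] Driving the scheduler directly, not part of the original commit; the scheduler
# settings, latent shape, and zero "model output" are placeholders for the real network prediction.
def _example_scheduler_loop():
    scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=1000, shift=3.0, solver="euler")
    scheduler.set_timesteps(num_inference_steps=20, device="cpu")
    sample = torch.randn(1, 16, 64, 64)
    for t in scheduler.timesteps:
        model_output = torch.zeros_like(sample)  # stand-in for the model's velocity prediction
        sample = scheduler.step(model_output, t, sample).prev_sample
    return sample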
+
502
+ class ClassifierFreeGuidance:
503
+ def __init__(
504
+ self,
505
+ use_original_formulation: bool = False,
506
+ start: float = 0.0,
507
+ stop: float = 1.0,
508
+ ):
509
+ super().__init__()
510
+ self.use_original_formulation = use_original_formulation
511
+
512
+ def __call__(
513
+ self,
514
+ pred_cond: torch.Tensor,
515
+ pred_uncond: Optional[torch.Tensor],
516
+ guidance_scale: float,
517
+ step: int,
518
+ ) -> torch.Tensor:
519
+
520
+ shift = pred_cond - pred_uncond
521
+ pred = pred_cond if self.use_original_formulation else pred_uncond
522
+ pred = pred + guidance_scale * shift
523
+
524
+ return pred
525
+
526
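# [Editor's note] Worked example of the combination above: with pred_uncond = 0.2, pred_cond = 0.8
# and guidance_scale = 7.5, the default formulation returns 0.2 + 7.5 * (0.8 - 0.2) = 4.7, while
# use_original_formulation=True returns 0.8 + 7.5 * (0.8 - 0.2) = 5.3.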
+
527
+ class HunyuanImage3Text2ImagePipeline(DiffusionPipeline):
528
+ r"""
529
+ Pipeline for text-conditioned image generation with HunyuanImage-3.0.
530
+
531
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
532
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
533
+
534
+ Args:
535
+ model ([`ModelMixin`]):
536
+ A model to denoise the diffused latents.
537
+ scheduler ([`SchedulerMixin`]):
538
+ A scheduler to be used in combination with `model` to denoise the diffused latents, typically
540
+ the `FlowMatchDiscreteScheduler` defined in this module.
540
+ """
541
+
542
+ model_cpu_offload_seq = ""
543
+ _optional_components = []
544
+ _exclude_from_cpu_offload = []
545
+ _callback_tensor_inputs = ["latents"]
546
+
547
+ def __init__(
548
+ self,
549
+ model,
550
+ scheduler: SchedulerMixin,
551
+ vae,
552
+ progress_bar_config: Dict[str, Any] = None,
553
+ ):
554
+ super().__init__()
555
+
556
+ # ==========================================================================================
557
+ if progress_bar_config is None:
558
+ progress_bar_config = {}
559
+ if not hasattr(self, '_progress_bar_config'):
560
+ self._progress_bar_config = {}
561
+ self._progress_bar_config.update(progress_bar_config)
562
+ # ==========================================================================================
563
+
564
+ self.register_modules(
565
+ model=model,
566
+ scheduler=scheduler,
567
+ vae=vae,
568
+ )
569
+
570
+ # `latent_scale_factor` may be None, an int, or a tuple/list with one entry per spatial
571
+ # dimension of the latents; if None, it is treated as a factor of 1 for every dimension.
572
+ self.latent_scale_factor = self.model.config.vae_downsample_factor
573
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.latent_scale_factor)
574
+
575
+ # Classifier-free guidance operator used to combine conditional and unconditional predictions
576
+ self.cfg_operator = ClassifierFreeGuidance()
577
+
578
+ @staticmethod
579
+ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
580
+ """
581
+ Denormalize an image array to [0,1].
582
+ """
583
+ return (images / 2 + 0.5).clamp(0, 1)
584
+
585
+ @staticmethod
586
+ def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
587
+ """
588
+ Convert a PyTorch tensor to a NumPy image.
589
+ """
590
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
591
+ return images
592
+
593
+ @staticmethod
594
+ def numpy_to_pil(images: np.ndarray):
595
+ """
596
+ Convert a numpy image or a batch of images to a PIL image.
597
+ """
598
+ if images.ndim == 3:
599
+ images = images[None, ...]
600
+ images = (images * 255).round().astype("uint8")
601
+ if images.shape[-1] == 1:
602
+ # special case for grayscale (single channel) images
603
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
604
+ else:
605
+ pil_images = [Image.fromarray(image) for image in images]
606
+
607
+ return pil_images
608
+
609
+ def prepare_extra_func_kwargs(self, func, kwargs):
610
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
611
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
612
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
613
+ # and should be between [0, 1]
614
+ extra_kwargs = {}
615
+
616
+ for k, v in kwargs.items():
617
+ accepts = k in set(inspect.signature(func).parameters.keys())
618
+ if accepts:
619
+ extra_kwargs[k] = v
620
+ return extra_kwargs
621
+
622
+ def prepare_latents(self, batch_size, latent_channel, image_size, dtype, device, generator, latents=None):
623
+ if self.latent_scale_factor is None:
624
+ latent_scale_factor = (1,) * len(image_size)
625
+ elif isinstance(self.latent_scale_factor, int):
626
+ latent_scale_factor = (self.latent_scale_factor,) * len(image_size)
627
+ elif isinstance(self.latent_scale_factor, tuple) or isinstance(self.latent_scale_factor, list):
628
+ assert len(self.latent_scale_factor) == len(image_size), \
629
+ "len(latent_scale_factor) shoudl be the same as len(image_size)"
630
+ latent_scale_factor = self.latent_scale_factor
631
+ else:
632
+ raise ValueError(
633
+ f"latent_scale_factor should be either None, int, tuple of int, or list of int, "
634
+ f"but got {self.latent_scale_factor}"
635
+ )
636
+
637
+ latents_shape = (
638
+ batch_size,
639
+ latent_channel,
640
+ *[int(s) // f for s, f in zip(image_size, latent_scale_factor)],
641
+ )
642
+ if isinstance(generator, list) and len(generator) != batch_size:
643
+ raise ValueError(
644
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
645
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
646
+ )
647
+
648
+ if latents is None:
649
+ latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
650
+ else:
651
+ latents = latents.to(device)
652
+
653
+ # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
654
+ if hasattr(self.scheduler, "init_noise_sigma"):
655
+ # scale the initial noise by the standard deviation required by the scheduler
656
+ latents = latents * self.scheduler.init_noise_sigma
657
+
658
+ return latents
659
+
660
+ @property
661
+ def guidance_scale(self):
662
+ return self._guidance_scale
663
+
664
+ @property
665
+ def guidance_rescale(self):
666
+ return self._guidance_rescale
667
+
668
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
669
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
670
+ # corresponds to doing no classifier free guidance.
671
+ @property
672
+ def do_classifier_free_guidance(self):
673
+ return self._guidance_scale > 1.0
674
+
675
+ @property
676
+ def num_timesteps(self):
677
+ return self._num_timesteps
678
+
679
+ def set_scheduler(self, new_scheduler):
680
+ self.register_modules(scheduler=new_scheduler)
681
+
682
+ @torch.no_grad()
683
+ def __call__(
684
+ self,
685
+ batch_size: int,
686
+ image_size: List[int],
687
+ num_inference_steps: int = 50,
688
+ timesteps: List[int] = None,
689
+ sigmas: List[float] = None,
690
+ guidance_scale: float = 7.5,
691
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
692
+ latents: Optional[torch.Tensor] = None,
693
+ output_type: Optional[str] = "pil",
694
+ return_dict: bool = True,
695
+ guidance_rescale: float = 0.0,
696
+ callback_on_step_end: Optional[
697
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
698
+ ] = None,
699
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
700
+ model_kwargs: Dict[str, Any] = None,
701
+ **kwargs,
702
+ ):
703
+ r"""
704
+ The call function to the pipeline for generation.
705
+
706
+ Args:
707
+ batch_size (`int`):
708
+ The number of samples to generate.
709
+ image_size (`Tuple[int]` or `List[int]`):
710
+ The size (height, width) of the generated image.
711
+ num_inference_steps (`int`, *optional*, defaults to 50):
712
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
713
+ expense of slower inference.
714
+ timesteps (`List[int]`, *optional*):
715
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
716
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
717
+ passed will be used. Must be in descending order.
718
+ sigmas (`List[float]`, *optional*):
719
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
720
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
721
+ will be used.
722
+ guidance_scale (`float`, *optional*, defaults to 7.5):
723
+ A higher guidance scale value encourages the model to generate samples closely linked to the
724
+ `condition` at the expense of lower sample quality. Guidance scale is enabled when `guidance_scale > 1`.
725
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
726
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
727
+ generation deterministic.
728
+ latents (`torch.Tensor`, *optional*):
729
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for sample
730
+ generation. Can be used to tweak the same generation with different conditions. If not provided,
731
+ a latents tensor is generated by sampling using the supplied random `generator`.
732
+ output_type (`str`, *optional*, defaults to `"pil"`):
733
+ The output format of the generated sample.
734
+ return_dict (`bool`, *optional*, defaults to `True`):
735
+ Whether or not to return a [`HunyuanImage3Text2ImagePipelineOutput`] instead of a
736
+ plain tuple.
737
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
738
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
739
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
740
+ using zero terminal SNR.
741
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
742
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
743
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
744
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
745
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
746
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
747
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
748
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
749
+ `._callback_tensor_inputs` attribute of your pipeline class.
750
+
751
+ Examples:
752
+
753
+ Returns:
754
+ [`HunyuanImage3Text2ImagePipelineOutput`] or `tuple`:
755
+ If `return_dict` is `True`, [`HunyuanImage3Text2ImagePipelineOutput`] is returned,
756
+ otherwise a `tuple` is returned where the first element is a list with the generated samples.
757
+ """
758
+
759
+ callback_steps = kwargs.pop("callback_steps", None)
760
+ pbar_steps = kwargs.pop("pbar_steps", None)
761
+
762
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
763
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
764
+
765
+ self._guidance_scale = guidance_scale
766
+ self._guidance_rescale = guidance_rescale
767
+
768
+ cfg_factor = 1 + self.do_classifier_free_guidance
769
+
770
+ # Define call parameters
771
+ device = self._execution_device
772
+
773
+ # Prepare timesteps
774
+ timesteps, num_inference_steps = retrieve_timesteps(
775
+ self.scheduler, num_inference_steps, device, timesteps, sigmas,
776
+ )
777
+
778
+ # Prepare latent variables
779
+ latents = self.prepare_latents(
780
+ batch_size=batch_size,
781
+ latent_channel=self.model.config.vae["latent_channels"],
782
+ image_size=image_size,
783
+ dtype=torch.bfloat16,
784
+ device=device,
785
+ generator=generator,
786
+ latents=latents,
787
+ )
788
+
789
+ # Prepare extra step kwargs.
790
+ _scheduler_step_extra_kwargs = self.prepare_extra_func_kwargs(
791
+ self.scheduler.step, {"generator": generator}
792
+ )
793
+
794
+ # Prepare model kwargs
795
+ input_ids = model_kwargs.pop("input_ids")
796
+ attention_mask = self.model._prepare_attention_mask_for_generation( # noqa
797
+ input_ids, self.model.generation_config, model_kwargs=model_kwargs,
798
+ )
799
+ model_kwargs["attention_mask"] = attention_mask.to(latents.device)
800
+
801
+ # Sampling loop
802
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
803
+ self._num_timesteps = len(timesteps)
804
+
805
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
806
+ for i, t in enumerate(timesteps):
807
+ # expand the latents if we are doing classifier free guidance
808
+ latent_model_input = torch.cat([latents] * cfg_factor)
809
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
810
+
811
+ t_expand = t.repeat(latent_model_input.shape[0])
812
+
813
+ model_inputs = self.model.prepare_inputs_for_generation(
814
+ input_ids,
815
+ images=latent_model_input,
816
+ timestep=t_expand,
817
+ **model_kwargs,
818
+ )
819
+
820
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
821
+ model_output = self.model(**model_inputs, first_step=(i == 0))
822
+ pred = model_output["diffusion_prediction"]
823
+ pred = pred.to(dtype=torch.float32)
824
+
825
+ # perform guidance
826
+ if self.do_classifier_free_guidance:
827
+ pred_cond, pred_uncond = pred.chunk(2)
828
+ pred = self.cfg_operator(pred_cond, pred_uncond, self.guidance_scale, step=i)
829
+
830
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
831
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
832
+ pred = rescale_noise_cfg(pred, pred_cond, guidance_rescale=self.guidance_rescale)
833
+
834
+ # compute the previous noisy sample x_t -> x_t-1
835
+ latents = self.scheduler.step(pred, t, latents, **_scheduler_step_extra_kwargs, return_dict=False)[0]
836
+
837
+ if i != len(timesteps) - 1:
838
+ model_kwargs = self.model._update_model_kwargs_for_generation( # noqa
839
+ model_output,
840
+ model_kwargs,
841
+ )
842
+ if input_ids.shape[1] != model_kwargs["position_ids"].shape[1]:
843
+ input_ids = torch.gather(input_ids, 1, index=model_kwargs["position_ids"])
844
+
845
+ if callback_on_step_end is not None:
846
+ callback_kwargs = {}
847
+ for k in callback_on_step_end_tensor_inputs:
848
+ callback_kwargs[k] = locals()[k]
849
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
850
+
851
+ latents = callback_outputs.pop("latents", latents)
852
+
853
+ # call the callback, if provided
854
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
855
+ progress_bar.update()
856
+
857
+ if hasattr(self.vae.config, 'scaling_factor') and self.vae.config.scaling_factor:
858
+ latents = latents / self.vae.config.scaling_factor
859
+ if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
860
+ latents = latents + self.vae.config.shift_factor
861
+
862
+ if hasattr(self.vae, "ffactor_temporal"):
863
+ latents = latents.unsqueeze(2)
864
+
865
+ with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
866
+ image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
867
+
868
+ # b c t h w
869
+ if hasattr(self.vae, "ffactor_temporal"):
870
+ assert image.shape[2] == 1, "image should have shape [B, C, T, H, W] and T should be 1"
871
+ image = image.squeeze(2)
872
+
873
+ do_denormalize = [True] * image.shape[0]
874
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
875
+
876
+ if not return_dict:
877
+ return (image,)
878
+
879
+ return HunyuanImage3Text2ImagePipelineOutput(samples=image)
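# [Editor's sketch] Rough end-to-end use of the pipeline, not part of the original commit. It
# assumes `model`, `vae`, and tokenized `model_kwargs` (containing at least "input_ids") have
# already been prepared by the surrounding HunyuanImage-3.0 code; the scheduler shift value and
# image size below are illustrative.
def _example_pipeline_call(model, vae, model_kwargs):
    pipeline = HunyuanImage3Text2ImagePipeline(
        model=model,
        scheduler=FlowMatchDiscreteScheduler(shift=3.0),
        vae=vae,
    )
    output = pipeline(
        batch_size=1,
        image_size=(1024, 1024),
        num_inference_steps=50,
        guidance_scale=7.5,
        model_kwargs=model_kwargs,
    )
    return output.samples[0]  # a PIL.Image when output_type="pil" (the default)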
image_processor.py ADDED
@@ -0,0 +1,125 @@
1
+ # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+ # ==============================================================================
13
+
14
+ from typing import Tuple
15
+
16
+ from PIL import Image
17
+ from torchvision import transforms
18
+ from transformers import Siglip2ImageProcessorFast
19
+
20
+ from .tokenizer_wrapper import ImageInfo, JointImageInfo, ResolutionGroup
21
+
22
+
23
+ def resize_and_crop(image: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
24
+ tw, th = target_size
25
+ w, h = image.size
26
+
27
+ tr = th / tw
28
+ r = h / w
29
+
30
+ # resize
31
+ if r < tr:
32
+ resize_height = th
33
+ resize_width = int(round(th / h * w))
34
+ else:
35
+ resize_width = tw
36
+ resize_height = int(round(tw / w * h))
37
+
38
+ image = image.resize((resize_width, resize_height), resample=Image.Resampling.LANCZOS)
39
+
40
+ # center crop
41
+ crop_top = int(round((resize_height - th) / 2.0))
42
+ crop_left = int(round((resize_width - tw) / 2.0))
43
+
44
+ image = image.crop((crop_left, crop_top, crop_left + tw, crop_top + th))
45
+ return image
46
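# [Editor's sketch] Minimal illustration of `resize_and_crop`, not part of the original commit:
# the image is scaled to cover the target aspect ratio, then center-cropped to exactly (tw, th).
def _example_resize_and_crop():
    src = Image.new("RGB", (1920, 1080), color="white")
    out = resize_and_crop(src, (1024, 1024))
    assert out.size == (1024, 1024)
    return out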
+
47
+
48
+ class HunyuanImage3ImageProcessor(object):
49
+ def __init__(self, config):
50
+ self.config = config
51
+
52
+ self.reso_group = ResolutionGroup(base_size=config.image_base_size)
53
+ self.vae_processor = transforms.Compose([
54
+ transforms.ToTensor(),
55
+ transforms.Normalize([0.5], [0.5]), # transform to [-1, 1]
56
+ ])
57
+ self.vision_encoder_processor = Siglip2ImageProcessorFast.from_dict(config.vit_processor)
58
+
59
+ def build_image_info(self, image_size):
60
+ # parse image size (HxW, H:W, or <img_ratio_i>)
61
+ if isinstance(image_size, str):
62
+ if image_size.startswith("<img_ratio_"):
63
+ ratio_index = int(image_size.split("_")[-1].rstrip(">"))
64
+ reso = self.reso_group[ratio_index]
65
+ image_size = reso.height, reso.width
66
+ elif 'x' in image_size:
67
+ image_size = [int(s) for s in image_size.split('x')]
68
+ elif ':' in image_size:
69
+ image_size = [int(s) for s in image_size.split(':')]
70
+ else:
71
+ raise ValueError(
72
+ f"`image_size` should be in the format of 'HxW', 'H:W' or <img_ratio_i>, got {image_size}.")
73
+ assert len(image_size) == 2, f"`image_size` should be in the format of 'HxW', got {image_size}."
74
+ elif isinstance(image_size, (list, tuple)):
75
+ assert len(image_size) == 2 and all(isinstance(s, int) for s in image_size), \
76
+ f"`image_size` should be a tuple of two integers or a string in the format of 'HxW', got {image_size}."
77
+ else:
78
+ raise ValueError(f"`image_size` should be a tuple of two integers or a string in the format of 'WxH', "
79
+ f"got {image_size}.")
80
+ image_width, image_height = self.reso_group.get_target_size(image_size[1], image_size[0])
81
+ token_height = image_height // (self.config.vae_downsample_factor[0] * self.config.patch_size)
82
+ token_width = image_width // (self.config.vae_downsample_factor[1] * self.config.patch_size)
83
+ base_size, ratio_idx = self.reso_group.get_base_size_and_ratio_index(image_size[1], image_size[0])
84
+ image_info = ImageInfo(
85
+ image_type="gen_image", image_width=image_width, image_height=image_height,
86
+ token_width=token_width, token_height=token_height, base_size=base_size, ratio_index=ratio_idx,
87
+ )
88
+ return image_info
89
+
90
+ def preprocess(self, image: Image.Image):
91
+ # ==== VAE processor ====
92
+ image_width, image_height = self.reso_group.get_target_size(image.width, image.height)
93
+ resized_image = resize_and_crop(image, (image_width, image_height))
94
+ image_tensor = self.vae_processor(resized_image)
95
+ token_height = image_height // (self.config.vae_downsample_factor[0] * self.config.patch_size)
96
+ token_width = image_width // (self.config.vae_downsample_factor[1] * self.config.patch_size)
97
+ base_size, ratio_index = self.reso_group.get_base_size_and_ratio_index(width=image_width, height=image_height)
98
+ vae_image_info = ImageInfo(
99
+ image_type="vae",
100
+ image_tensor=image_tensor.unsqueeze(0), # include batch dim
101
+ image_width=image_width, image_height=image_height,
102
+ token_width=token_width, token_height=token_height,
103
+ base_size=base_size, ratio_index=ratio_index,
104
+ )
105
+
106
+ # ==== ViT processor ====
107
+ inputs = self.vision_encoder_processor(image)
108
+ image = inputs["pixel_values"].squeeze(0) # seq_len x dim
109
+ pixel_attention_mask = inputs["pixel_attention_mask"].squeeze(0) # seq_len
110
+ spatial_shapes = inputs["spatial_shapes"].squeeze(0) # 2 (h, w)
111
+ vision_encoder_kwargs = dict(
112
+ pixel_attention_mask=pixel_attention_mask,
113
+ spatial_shapes=spatial_shapes,
114
+ )
115
+ vision_image_info = ImageInfo(
116
+ image_type="vit",
117
+ image_tensor=image.unsqueeze(0), # 1 x seq_len x dim
118
+ image_width=spatial_shapes[1].item() * self.config.vit_processor["patch_size"],
119
+ image_height=spatial_shapes[0].item() * self.config.vit_processor["patch_size"],
120
+ token_width=spatial_shapes[1].item(),
121
+ token_height=spatial_shapes[0].item(),
122
+ image_token_length=self.config.vit_processor["max_num_patches"],
123
+ # may not be equal to token_width * token_height
124
+ )
125
+ return JointImageInfo(vae_image_info, vision_image_info, vision_encoder_kwargs)
model-0001-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dad22fa5e99dcda532c242aa4d4875f9ea6fd8b2ed59e39776dec4ea55baf4e5
3
+ size 5363066616
model-0002-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9987e8220f81b70d07b62f06ac6c92bb0faf38ccb0ddd3f30b65ed895ad4a2fb
3
+ size 5318937248
model-0003-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79f8d4d1b23562299da3360ac7e2437a4dd24be30b86bc8db580521b5f9b2616
3
+ size 5344627472
model-0004-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4faf1357831b25b9f9637594312e9024ee0fa1e87c734e20afdde2845fdaa516
3
+ size 5327343192
model-0005-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46189f8777c117c431e46cc57ec2328fe72050452119ac7bb676bdaca3f76575
3
+ size 5344103080
model-0006-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f9d5f386b7c2d0b171bd8a25f3f08e3150936fde2dfd92e9aa1f6e27dbf2e0d
3
+ size 5318937248
model-0007-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d30616044acead06484eacace50a4cab66267feb13555f235bac63d2540cf471
3
+ size 5344103088
model-0008-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:740ccbff8fa1dbb2847fe8c342654f7d24fa81f058065e82dfbccb89ce2743c1
3
+ size 5318937256
model-0009-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5fc3df50de8591735d29f7acfece39b64b3735cccef176eb4a137f4ede68430
3
+ size 5344103088
model-0010-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f6058eb7527741d18c17131cb7810f11d8bd4c69cce10962e093e684413cd2a
3
+ size 5318937304
model-0011-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c38d5fd2f18191d849b444e873ff91d3f048d8c4bcd71b3035ff0f7973ac273
3
+ size 5344103232
model-0012-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:688a6a818f6d164d345e3bb37c4f3fcee40cc7d458027d2a37f7486463843ec3
3
+ size 5318937400
model-0013-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f77757aa32fa67f75f8f8ec5bc831d358093483c2a8692bff7477378aea00f28
3
+ size 5344103232
model-0014-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3308c079c20008e1ac8852cfb986764064077278754492f2fd9ec893857b6489
3
+ size 5318937400
model-0015-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e32b467eb49473c7f42696db0916ca3275c01984c48a10433d78be4d351b7ff8
3
+ size 5344103232
model-0016-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b97d98195a45518bae971bc43c224225b60e1fbb8b2eb93115024d2bdf328dca
3
+ size 5318937400
model-0017-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f00339bad7371e59f2d3642fd0575abafa92fc4509803f8fe5a64492185d2ab
3
+ size 5344103224
model-0018-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b48a59d090d396aa9801765485381f8255d442c2da2d9e98f1c21a68c6b83b1
3
+ size 5327859080
model-0019-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd4e5a082f3db3b61774ce86675cfb171f33319fd3dd8f942cd952633834d334
3
+ size 5344111888
model-0020-of-0032.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f27fc2c0eedfc6b99ebe07e244c9689e89fa06dc65216d9c07aa6067783f86b5
3
+ size 5318937392