Upload folder using huggingface_hub
Browse files- .gitattributes +4 -0
- LICENSE +78 -0
- README.md +258 -0
- README_CN.md +259 -0
- assets/bench_agent_overview_v3.jpg +3 -0
- assets/bench_claw_agent.png +0 -0
- assets/bench_claw_agent2.jpg +3 -0
- assets/bench_context.jpg +3 -0
- assets/bench_stem.jpg +3 -0
- assets/logo-en.png +0 -0
- assets/logo-zh.png +0 -0
- train/README.md +298 -0
- train/README_CN.md +298 -0
- train/ds_zero2_no_offload.json +25 -0
- train/ds_zero3_no_offload.json +41 -0
- train/ds_zero3_offload.json +41 -0
- train/ds_zero3_offload_no_auto.json +36 -0
- train/example_data.jsonl +8 -0
- train/llama_factory_support/dataset_info.json +16 -0
- train/llama_factory_support/ds_zero2_offload_lora.json +33 -0
- train/llama_factory_support/ds_zero3_offload.json +40 -0
- train/llama_factory_support/hy_v3_full_sft.yaml +57 -0
- train/llama_factory_support/hy_v3_lora_sft.yaml +56 -0
- train/llama_factory_support/hy_v3_patches.py +384 -0
- train/llama_factory_support/hy_v3_template.py +46 -0
- train/llama_factory_support/train_hy_v3.py +72 -0
- train/llama_factory_support/train_lf.sh +114 -0
- train/merge_lora_weight.py +54 -0
- train/merge_lora_weight.sh +1 -0
- train/requirements.txt +12 -0
- train/tools/check_converted.py +455 -0
- train/tools/convert_ckpt_to_outer.py +641 -0
- train/train.py +564 -0
- train/train.sh +118 -0
- train/train_lora.sh +123 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/bench_agent_overview_v3.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/bench_claw_agent2.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/bench_context.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/bench_stem.jpg filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TENCENT HY COMMUNITY LICENSE AGREEMENT
|
| 2 |
+
Tencent Hy3 preview Release Date: April 23, 2026
|
| 3 |
+
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
| 4 |
+
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hy Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
| 5 |
+
1. DEFINITIONS.
|
| 6 |
+
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
| 7 |
+
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hy Works or any portion or element thereof set forth herein.
|
| 8 |
+
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hy made publicly available by Tencent.
|
| 9 |
+
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
| 10 |
+
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hy Works for any purpose and in any field of use.
|
| 11 |
+
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hy and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
| 12 |
+
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hy or any Model Derivative of Tencent Hy; (ii) works based on Tencent Hy or any Model Derivative of Tencent Hy; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hy or any Model Derivative of Tencent Hy, to that model in order to cause that model to perform similarly to Tencent Hy or a Model Derivative of Tencent Hy, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hy or a Model Derivative of Tencent Hy for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
| 13 |
+
h. “Output” shall mean the information and/or content output of Tencent Hy or a Model Derivative that results from operating or otherwise using Tencent Hy or a Model Derivative, including via a Hosted Service.
|
| 14 |
+
i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
|
| 15 |
+
j. “Tencent Hy” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hy3 preview released at [https://huggingface.co/tencent/Hy3-preview; https://github.com/Tencent-Hunyuan/Hy3-preview].
|
| 16 |
+
k. “Tencent Hy Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
| 17 |
+
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
|
| 18 |
+
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
| 19 |
+
n. “including” shall mean including but not limited to.
|
| 20 |
+
2. GRANT OF RIGHTS.
|
| 21 |
+
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
| 22 |
+
3. DISTRIBUTION.
|
| 23 |
+
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hy Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
| 24 |
+
a. You must provide all such Third Party recipients of the Tencent Hy Works or products or services using them a copy of this Agreement;
|
| 25 |
+
b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
| 26 |
+
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hy Works; and (ii) mark the products or services developed by using the Tencent Hy Works to indicate that the product/service is “Powered by Tencent Hy”; and
|
| 27 |
+
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hy is licensed under the Tencent Hy Community License Agreement, Copyright © 2026 Tencent. All Rights Reserved. The trademark rights of “Tencent Hy” are owned by Tencent or its affiliate.”
|
| 28 |
+
e. In the event that You use, integrate, implement, or otherwise deploy the Tencent Hy Works, in whole or in part, to provide, enable, or support any service, product, or functionality to third parties, You shall clearly, accurately, and prominently disclose to all end users the full legal name and entity of the actual provider of such service, product, or functionality. You shall expressly and conspicuously state that Tencent is not affiliated with, associated with, sponsoring, or endorsing any such service, product, or functionality. You shall not use or display any name, logo, trademark, trade name, or other indicia of Tencent in any manner that could be construed as, or be likely to create, confusion, deception, or a false impression regarding any relationship, affiliation, sponsorship, or endorsement by Tencent.
|
| 29 |
+
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hy Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
| 30 |
+
4. ADDITIONAL COMMERCIAL TERMS.
|
| 31 |
+
If, on the Tencent Hy version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
| 32 |
+
5. RULES OF USE.
|
| 33 |
+
a. Your use of the Tencent Hy Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hy Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hy Works and You must provide notice to subsequent users to whom You distribute that Tencent Hy Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
| 34 |
+
b. You must not use the Tencent Hy Works or any Output or results of the Tencent Hy Works to improve any other AI model (other than Tencent Hy or Model Derivatives thereof).
|
| 35 |
+
c. You must not use, reproduce, modify, distribute, or display the Tencent Hy Works, Output or results of the Tencent Hy Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
| 36 |
+
6. INTELLECTUAL PROPERTY.
|
| 37 |
+
a. Subject to Tencent’s ownership of Tencent Hy Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
| 38 |
+
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hy Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hy Works. Tencent hereby grants You a license to use “Tencent Hy” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
| 39 |
+
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hy Works.
|
| 40 |
+
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
| 41 |
+
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
| 42 |
+
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hy Works or to grant any license thereto.
|
| 43 |
+
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE Tencent Hy WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE Tencent Hy WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE Tencent Hy WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
| 44 |
+
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE Tencent Hy WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 45 |
+
8. SURVIVAL AND TERMINATION.
|
| 46 |
+
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
| 47 |
+
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hy Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
| 48 |
+
9. GOVERNING LAW AND JURISDICTION.
|
| 49 |
+
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
| 50 |
+
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
| 51 |
+
|
| 52 |
+
EXHIBIT A
|
| 53 |
+
ACCEPTABLE USE POLICY
|
| 54 |
+
|
| 55 |
+
Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
| 56 |
+
Last modified: December 30, 2025
|
| 57 |
+
|
| 58 |
+
Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hy. You agree not to use Tencent Hy or Model Derivatives:
|
| 59 |
+
1. Outside the Territory;
|
| 60 |
+
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
| 61 |
+
3. To harm Yourself or others;
|
| 62 |
+
4. To repurpose or distribute output from Tencent Hy or any Model Derivatives to harm Yourself or others;
|
| 63 |
+
5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
| 64 |
+
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
| 65 |
+
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
| 66 |
+
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
| 67 |
+
9. To intentionally defame, disparage or otherwise harass others;
|
| 68 |
+
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
| 69 |
+
11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
| 70 |
+
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
| 71 |
+
13. To impersonate another individual without consent, authorization, or legal right;
|
| 72 |
+
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
| 73 |
+
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
| 74 |
+
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
| 75 |
+
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
| 76 |
+
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
| 77 |
+
19. For military purposes;
|
| 78 |
+
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
README.md
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="left">
|
| 2 |
+
<a href="README_CN.md">中文</a> | English
|
| 3 |
+
</p>
|
| 4 |
+
<br>
|
| 5 |
+
|
| 6 |
+
<p align="center">
|
| 7 |
+
<img src="assets/logo-en.png" width="400"/> <br>
|
| 8 |
+
</p>
|
| 9 |
+
|
| 10 |
+
<div align="center" style="line-height: 1;">
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
[](#license)
|
| 14 |
+
|
| 15 |
+
[](https://huggingface.co/tencent/Hy3-preview)
|
| 16 |
+
|
| 17 |
+
[](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview)
|
| 18 |
+
|
| 19 |
+
[](https://cnb.cool/ai-models/tencent/Hy3-preview)
|
| 20 |
+
|
| 21 |
+
[](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview)
|
| 22 |
+
|
| 23 |
+
</div>
|
| 24 |
+
|
| 25 |
+
<p align="center">
|
| 26 |
+
🖥️ <a href="https://aistudio.tencent.com/"><b>Official Website</b></a> |
|
| 27 |
+
💬 <a href="https://github.com/Tencent-Hunyuan/Hy3-preview"><b>GitHub</b></a></p>
|
| 28 |
+
|
| 29 |
+
---
|
| 30 |
+
|
| 31 |
+
## Table of Contents
|
| 32 |
+
|
| 33 |
+
- [Model Introduction](#model-introduction)
|
| 34 |
+
- [Highlights](#highlights)
|
| 35 |
+
- [Benchmark Results](#benchmark-results)
|
| 36 |
+
- [STEM & Reasoning](#stem--reasoning)
|
| 37 |
+
- [Context Learning & Instruction Following](#context-learning--instruction-following)
|
| 38 |
+
- [Code & Agent](#code--agent)
|
| 39 |
+
- [News](#news)
|
| 40 |
+
- [Model Links](#model-links)
|
| 41 |
+
- [Quickstart](#quickstart)
|
| 42 |
+
- [Deployment](#deployment)
|
| 43 |
+
- [vLLM](#vllm)
|
| 44 |
+
- [SGLang](#sglang)
|
| 45 |
+
- [Training](#training)
|
| 46 |
+
- [Quantization](#quantization)
|
| 47 |
+
- [License](#license)
|
| 48 |
+
- [Contact Us](#contact-us)
|
| 49 |
+
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
## Model Introduction
|
| 53 |
+
|
| 54 |
+
**Hy3 preview** is a 295B-parameter Mixture-of-Experts (MoE) model with 21B active parameters and 3.8B MTP layer parameters, developed by the Tencent Hy Team. Hy3 preview is the first model trained on our rebuilt infrastructure, and the strongest we've shipped so far. It improves significantly on complex reasoning, instruction following, context learning, coding, and agent tasks.
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
| Property | Value |
|
| 58 |
+
|:---|:---|
|
| 59 |
+
| Architecture | Mixture-of-Experts (MoE) |
|
| 60 |
+
| Total Parameters | 295B |
|
| 61 |
+
| Activated Parameters | 21B |
|
| 62 |
+
| MTP Layer Parameters | 3.8B |
|
| 63 |
+
| Number of Layers (excluding MTP layer) | 80 |
|
| 64 |
+
| Number of MTP Layers | 1 |
|
| 65 |
+
| Attention Heads | 64 (GQA, 8 KV heads, head dim 128) |
|
| 66 |
+
| Hidden Size | 4096 |
|
| 67 |
+
| Intermediate Size | 13312 |
|
| 68 |
+
| Context Length | 256K |
|
| 69 |
+
| Vocabulary Size | 120832 |
|
| 70 |
+
| Number of Experts | 192 experts, top-8 activated |
|
| 71 |
+
| Supported Precisions | BF16 |
|
| 72 |
+
|
| 73 |
+
## Highlights
|
| 74 |
+
|
| 75 |
+
- **STEM & Reasoning** — Complex reasoning underpins everything else. Hy3 preview performs well on challenging STEM benchmarks like FrontierScience-Olympiad and IMOAnswerBench, and achieved excellent results in the Tsinghua Qiuzhen College Math PhD qualifying exam (Spring '26) and the China High School Biology Olympiad (CHSBO 2025), demonstrating generalizable reasoning capacity.
|
| 76 |
+
|
| 77 |
+
- **Context Learning & Instruction Following** — Real-world tasks require the ability to parse messy, lengthy contexts and follow complex rules. We built CL-bench and CL-bench-Life from our own business scenarios to innovatively measure context learning ability. Hy3 preview exhibits solid gains in both context learning and instruction following capabilities.
|
| 78 |
+
|
| 79 |
+
- **Code & Agent** — Coding and agents saw the biggest gains. With a rebuilt RL infrastructure and larger-scale training tasks, we posted competitive scores across mainstream coding agent benchmarks (SWE-bench Verified, Terminal-Bench 2.0) and search agent benchmarks (BrowseComp, WideSearch).
|
| 80 |
+
|
| 81 |
+
## Benchmark Results
|
| 82 |
+
|
| 83 |
+
### Pre-trained Model Performance
|
| 84 |
+
|
| 85 |
+
| Category | Benchmark (Metric) | # Shots | Kimi-K2 BASE | DeepSeek-V3 BASE | GLM-4.5 BASE | Hy3 preview-Base |
|
| 86 |
+
|---|---|---|---|---|---|---|
|
| 87 |
+
| | #ActivatedParams | - | 32B | 37B | 32B | 21B |
|
| 88 |
+
| | #TotalParams | - | 1043B | 671B | 355B | 295B |
|
| 89 |
+
| **English** | MMLU | 5-shot | **88.24** | 87.68 | 87.73 | 87.42 |
|
| 90 |
+
| | MMLU-Pro | 5-shot | **65.98** | 63.98 | 63.67 | 65.76 |
|
| 91 |
+
| | MMLU-Redux | 5-shot | **87.18** | 86.81 | 86.56 | 86.86 |
|
| 92 |
+
| | ARC-Challenge | 0-shot | **96.66** | 94.65 | 96.32 | 95.99 |
|
| 93 |
+
| | DROP | 5-shot | 86.40 | **86.50** | 82.90 | 85.50 |
|
| 94 |
+
| | PIQA | 4-shot | **84.93** | 84.22 | 84.71 | 84.39 |
|
| 95 |
+
| | SuperGPQA | 5-shot | 51.10 | 46.17 | 49.64 | **51.60** |
|
| 96 |
+
| | SimpleQA | 5-shot | **34.37** | 26.15 | 29.26 | 26.47 |
|
| 97 |
+
| **Code** | MBPP-plus | 3-shot | **81.35** | 75.47 | 78.05 | 78.71 |
|
| 98 |
+
| | CRUXEval-I | 3-shot | 68.01 | 67.79 | 68.51 | **71.19** |
|
| 99 |
+
| | CRUXEval-O | 3-shot | 69.62 | **71.00** | 67.75 | 68.38 |
|
| 100 |
+
| | LiveCodeBench-v6 | 1-shot | 30.86 | 29.31 | 27.43 | **34.86** |
|
| 101 |
+
| **Math** | GSM8K | 4-shot | 93.46 | 88.15 | 90.06 | **95.37** |
|
| 102 |
+
| | MATH | 4-shot | 71.20 | 59.37 | 61.00 | **76.28** |
|
| 103 |
+
| | CMath | 4-shot | 90.83 | 85.50 | 89.33 | **91.17** |
|
| 104 |
+
| **Chinese** | C-Eval | 5-shot | **91.51** | 90.35 | 85.84 | 89.80 |
|
| 105 |
+
| | CMMLU | 5-shot | **90.72** | 87.90 | 86.46 | 89.61 |
|
| 106 |
+
| | Chinese-simpleQA | 5-shot | **74.58** | 68.72 | 68.49 | 69.73 |
|
| 107 |
+
| **Multilingual** | MMMLU | 5-shot | 77.63 | 79.54 | 79.26 | **80.15** |
|
| 108 |
+
| | INCLUDE | 5-shot | 75.66 | 77.86 | 76.27 | **78.64** |
|
| 109 |
+
|
| 110 |
+
### Instruct Model Performance
|
| 111 |
+
|
| 112 |
+
#### STEM & Reasoning
|
| 113 |
+
|
| 114 |
+
Complex reasoning underpins everything else. Hy3 preview performs well on challenging STEM benchmarks like FrontierScience-Olympiad and IMOAnswerBench. It also achieved excellent results in the Tsinghua Qiuzhen College Math PhD qualifying exam (Spring '26) and the China High School Biology Olympiad (CHSBO 2025), demonstrating a high degree of generalizable reasoning capacity.
|
| 115 |
+
|
| 116 |
+
<p align="center"><img src="assets/bench_stem.jpg" width="800" alt="STEM & Reasoning benchmarks"/></p>
|
| 117 |
+
|
| 118 |
+
#### Context Learning & Instruction Following
|
| 119 |
+
|
| 120 |
+
Real-world tasks require the ability to parse messy, lengthy contexts and follow complex rules. We built CL-bench and CL-bench-Life from our own business scenarios to innovatively measure context learning ability. Hy3 preview exhibits solid gains in both context learning and instruction following capabilities.
|
| 121 |
+
|
| 122 |
+
<p align="center"><img src="assets/bench_context.jpg" width="800" alt="Context Learning & Instruction Following benchmarks"/></p>
|
| 123 |
+
|
| 124 |
+
#### Code & Agent
|
| 125 |
+
|
| 126 |
+
Coding and agents saw the biggest gains. With a rebuilt RL infrastructure and larger-scale training tasks, we posted competitive scores across mainstream coding agent benchmarks (SWE-bench Verified, Terminal-Bench 2.0) and search agent benchmarks (BrowseComp, WideSearch).
|
| 127 |
+
|
| 128 |
+
<p align="center"><img src="assets/bench_agent_overview_v3.jpg" width="800" alt="Agent benchmarks overview"/></p>
|
| 129 |
+
|
| 130 |
+
Coding is about whether a model can execute in a development environment. Search is about whether it can find and combine information from the open web. Both matter for complex agent scenarios like OpenClaw. Hy3 preview scores well on ClawEval and WildClawBench — a sign that its agent capabilities are becoming practical.
|
| 131 |
+
|
| 132 |
+
<p align="center"><img src="assets/bench_claw_agent.png" width="800" alt="Claw Agent benchmarks"/></p>
|
| 133 |
+
|
| 134 |
+
Beyond public benchmarks, we built internal evaluation sets to test the model in real development scenarios. On Hy-Backend (backend-focused tasks), Hy-Vibe Bench (real-user dev workflows), and Hy-SWE Max, Hy3 preview scores competitively against other open-source models.
|
| 135 |
+
|
| 136 |
+
<p align="center"><img src="assets/bench_claw_agent2.jpg" width="800" alt="Internal benchmarks"/></p>
|
| 137 |
+
|
| 138 |
+
## News
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
* **[2026-04-23]** 🔥 We open-source **Hy3 preview** model weights on [Hugging Face](https://huggingface.co/tencent/Hy3-preview), [ModelScope](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview), and [GitCode](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview).
|
| 142 |
+
|
| 143 |
+
## Model Links
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
| Model Name | Description | Hugging Face | ModelScope | GitCode |
|
| 147 |
+
|:---|:---|:---:|:---:|:---:|
|
| 148 |
+
| Hy3 preview | Instruct model | 🤗 [Model](https://huggingface.co/tencent/Hy3-preview) | [Model](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview) | [Model](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview) |
|
| 149 |
+
| Hy3 preview-Base | Pre-trained base model | 🤗 [Model](https://huggingface.co/tencent/Hy3-preview-Base) | [Model](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview-Base) | [Model](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview-Base) |
|
| 150 |
+
|
| 151 |
+
## Quickstart
|
| 152 |
+
|
| 153 |
+
Deploy Hy3 preview with [vLLM](#vllm) or [SGLang](#sglang) first, then call the OpenAI-compatible API:
|
| 154 |
+
|
| 155 |
+
```python
|
| 156 |
+
from openai import OpenAI
|
| 157 |
+
|
| 158 |
+
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
|
| 159 |
+
|
| 160 |
+
response = client.chat.completions.create(
|
| 161 |
+
model="tencent/Hy3-preview",
|
| 162 |
+
messages=[
|
| 163 |
+
{"role": "user", "content": "Hello! Can you briefly introduce yourself?"},
|
| 164 |
+
],
|
| 165 |
+
temperature=0.9,
|
| 166 |
+
top_p=1.0,
|
| 167 |
+
# reasoning_effort: "no_think" (default, direct response), "low", "high" (deep chain-of-thought)
|
| 168 |
+
extra_body={"chat_template_kwargs": {"reasoning_effort": "no_think"}},
|
| 169 |
+
)
|
| 170 |
+
print(response.choices[0].message.content)
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
> **Recommended parameters**: `temperature=0.9`, `top_p=1.0`.
|
| 174 |
+
>
|
| 175 |
+
> **Reasoning mode**: Set `reasoning_effort` to `"high"` for complex tasks (math, coding, reasoning) or `"no_think"` for direct responses.
|
| 176 |
+
|
| 177 |
+
See the [Deployment](#deployment) section below for how to start the API server.
|
| 178 |
+
|
| 179 |
+
## Deployment
|
| 180 |
+
|
| 181 |
+
Hy3-preview has 295B parameters in total. To serve it on 8 GPUs, we recommend using H20-3e or other GPUs with larger memory capacity.
|
| 182 |
+
|
| 183 |
+
### vLLM
|
| 184 |
+
|
| 185 |
+
Build vLLM from source:
|
| 186 |
+
```bash
|
| 187 |
+
uv venv --python 3.12 --seed --managed-python
|
| 188 |
+
source .venv/bin/activate
|
| 189 |
+
git clone https://github.com/vllm-project/vllm.git
|
| 190 |
+
cd vllm
|
| 191 |
+
uv pip install --editable . --torch-backend=auto
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
Start the vLLM server with MTP enabled:
|
| 195 |
+
|
| 196 |
+
```bash
|
| 197 |
+
vllm serve tencent/Hy3-preview \
|
| 198 |
+
--tensor-parallel-size 8 \
|
| 199 |
+
--speculative-config.method mtp \
|
| 200 |
+
--speculative-config.num_speculative_tokens 1 \
|
| 201 |
+
--tool-call-parser hy_v3 \
|
| 202 |
+
--reasoning-parser hy_v3 \
|
| 203 |
+
--enable-auto-tool-choice \
|
| 204 |
+
--served-model-name hy3-preview
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
### SGLang
|
| 208 |
+
|
| 209 |
+
Build SGLang from source:
|
| 210 |
+
```bash
|
| 211 |
+
git clone https://github.com/sgl-project/sglang
|
| 212 |
+
cd sglang
|
| 213 |
+
pip3 install pip --upgrade
|
| 214 |
+
pip3 install "transformers>=5.6.0"
|
| 215 |
+
pip3 install -e "python"
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
Launch SGLang server with MTP enabled:
|
| 219 |
+
|
| 220 |
+
```bash
|
| 221 |
+
python3 -m sglang.launch_server \
|
| 222 |
+
--model tencent/Hy3-preview \
|
| 223 |
+
--tp 8 \
|
| 224 |
+
--tool-call-parser hunyuan \
|
| 225 |
+
--reasoning-parser hunyuan \
|
| 226 |
+
--speculative-num-steps 1 \
|
| 227 |
+
--speculative-eagle-topk 1 \
|
| 228 |
+
--speculative-num-draft-tokens 2 \
|
| 229 |
+
--speculative-algorithm EAGLE \
|
| 230 |
+
--served-model-name hy3-preview
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
## Training
|
| 234 |
+
|
| 235 |
+
Hy3 preview provides a complete model training pipeline, supporting both full fine-tuning and LoRA fine-tuning, with DeepSpeed ZeRO configurations and LLaMA-Factory integration.
|
| 236 |
+
|
| 237 |
+
For detailed training documentation, please refer to: [Training Guide](./train/README.md)
|
| 238 |
+
|
| 239 |
+
## Quantization
|
| 240 |
+
|
| 241 |
+
We provide [AngelSlim](https://github.com/tencent/AngelSlim), a more accessible, comprehensive, and efficient toolkit for large model compression. AngelSlim supports a comprehensive suite of compression tools for large-scale multimodal models, including common quantization algorithms, low-bit quantization, and speculative sampling.
|
| 242 |
+
|
| 243 |
+
## License
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
Hy3 preview is released under the **Tencent Hy Community License Agreement**. See [LICENSE](./LICENSE) for details.
|
| 247 |
+
|
| 248 |
+
## Contact Us
|
| 249 |
+
|
| 250 |
+
If you would like to leave a message for our R&D and product teams, welcome to contact us. You can also reach us via email:
|
| 251 |
+
|
| 252 |
+
📧 **hunyuan_opensource@tencent.com**
|
| 253 |
+
|
| 254 |
+
---
|
| 255 |
+
|
| 256 |
+
<p align="center">
|
| 257 |
+
<i>Hy3 preview is developed by the Tencent Hy Team.</i>
|
| 258 |
+
</p>
|
README_CN.md
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="left">
|
| 2 |
+
<a href="README.md">English</a> | 中文
|
| 3 |
+
</p>
|
| 4 |
+
<br>
|
| 5 |
+
|
| 6 |
+
<p align="center">
|
| 7 |
+
<img src="assets/logo-zh.png" width="400"/> <br>
|
| 8 |
+
</p>
|
| 9 |
+
|
| 10 |
+
<div align="center" style="line-height: 1;">
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
[](#许可证)
|
| 14 |
+
|
| 15 |
+
[](https://huggingface.co/tencent/Hy3-preview)
|
| 16 |
+
|
| 17 |
+
[](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview)
|
| 18 |
+
|
| 19 |
+
[](https://cnb.cool/ai-models/tencent/Hy3-preview)
|
| 20 |
+
|
| 21 |
+
[](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview)
|
| 22 |
+
|
| 23 |
+
</div>
|
| 24 |
+
|
| 25 |
+
<p align="center">
|
| 26 |
+
🖥️ <a href="https://aistudio.tencent.com/"><b>官方网站</b></a> |
|
| 27 |
+
💬 <a href="https://github.com/Tencent-Hunyuan/Hy3-preview"><b>GitHub</b></a></p>
|
| 28 |
+
|
| 29 |
+
---
|
| 30 |
+
|
| 31 |
+
## 目录
|
| 32 |
+
|
| 33 |
+
- [模型介绍](#模型介绍)
|
| 34 |
+
- [亮点展示](#亮点展示)
|
| 35 |
+
- [评测结果](#评测结果)
|
| 36 |
+
- [复杂推理(STEM & Reasoning)](#复杂推理stem--reasoning)
|
| 37 |
+
- [上下文学习和指令遵循(Context Learning & Instruction Following)](#上下文学习和指令遵循context-learning--instruction-following)
|
| 38 |
+
- [代码和智能体(Code & Agent)](#代码和智能体code--agent)
|
| 39 |
+
- [新闻](#新闻)
|
| 40 |
+
- [模型链接](#模型链接)
|
| 41 |
+
- [快速开始](#快速开始)
|
| 42 |
+
- [推理和部署](#推理和部署)
|
| 43 |
+
- [vLLM](#使用-vllm-推理)
|
| 44 |
+
- [SGLang](#使用-sglang-推理)
|
| 45 |
+
- [模型训练](#模型训练)
|
| 46 |
+
- [量化工具](#量化工具)
|
| 47 |
+
- [许可证](#许可证)
|
| 48 |
+
- [联系我们](#联系我们)
|
| 49 |
+
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
## 模型介绍
|
| 53 |
+
|
| 54 |
+
**Hy3 preview** 是由腾讯混元团队研发的快慢思考融合的混合专家模型,总参数量 295B,激活参数 21B,MTP 层参数 3.8B。Hy3 preview 是我们重建后训练的第一个模型,也是混元迄今最智能的模型,在复杂推理、指令遵循、上下文学习、代码、智能体等能力及推理性能上实现了大幅的提升。
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
| 属性 | 值 |
|
| 58 |
+
|:---|:---|
|
| 59 |
+
| 架构 | 混合专家(MoE) |
|
| 60 |
+
| 总参数量 | 295B |
|
| 61 |
+
| 激活参数量 | 21B |
|
| 62 |
+
| MTP层参数量 | 3.8B |
|
| 63 |
+
| 层数(不含MTP层) | 80 |
|
| 64 |
+
| MTP层数 | 1 |
|
| 65 |
+
| 注意力头 | 64(GQA,8 个 KV 头,head dim 128) |
|
| 66 |
+
| 隐藏层维度 | 4096 |
|
| 67 |
+
| FFN 中间层维度 | 13312 |
|
| 68 |
+
| 上下文长度 | 256K |
|
| 69 |
+
| 词表大小 | 120832 |
|
| 70 |
+
| 专家数量 | 192 个专家,top-8 激活 |
|
| 71 |
+
| 支持精度 | BF16 |
|
| 72 |
+
|
| 73 |
+
## 亮点展示
|
| 74 |
+
|
| 75 |
+
- **复杂推理(STEM & Reasoning)** — 推理能力是模型解决各种问题的基础。在 FrontierScience-Olympiad、IMOAnswerBench 等高难度理工科推理任务中表现突出,并在最新的清华大学求真书院数学博资考(26春)和全国中学生生物学联赛(CHSBO 2025)中取得优异成绩,展现出可泛化的强推理能力。
|
| 76 |
+
|
| 77 |
+
- **上下文学习和指令遵循(Context Learning & Instruction Following)** — 在各种真实的生产与生活场景,理解杂乱冗长的上下文并遵从复杂多变的规则是模型的首要挑战。基于我们多种业务场景的灵感,我们提出了 CL-bench 和 CL-bench-Life 来创新性地评估模型的上下文学习能力,并在 Hy3 preview 显著地提升了模型上下文学习和指令遵循能力。
|
| 78 |
+
|
| 79 |
+
- **代码和智能体(Code & Agent)** — Hy3 preview 提升最为显著的方向。得益于预训练及强化学习框架的重建和强化学习任务规模的提升,我们以较快的速度在 SWE-Bench Verified、Terminal-Bench 2.0 等主流代码智能体基准以及 BrowseComp、WideSearch 等主流搜索智能体基准中取得了强竞争力的结果。
|
| 80 |
+
|
| 81 |
+
## 评测结果
|
| 82 |
+
|
| 83 |
+
### 预训练模型效果
|
| 84 |
+
|
| 85 |
+
| Category | Benchmark (Metric) | # Shots | Kimi-K2 BASE | DeepSeek-V3 BASE | GLM-4.5 BASE | Hy3 preview-Base |
|
| 86 |
+
|---|---|---|---|---|---|---|
|
| 87 |
+
| | #ActivatedParams | - | 32B | 37B | 32B | 21B |
|
| 88 |
+
| | #TotalParams | - | 1043B | 671B | 355B | 295B |
|
| 89 |
+
| **English** | MMLU | 5-shot | **88.24** | 87.68 | 87.73 | 87.42 |
|
| 90 |
+
| | MMLU-Pro | 5-shot | **65.98** | 63.98 | 63.67 | 65.76 |
|
| 91 |
+
| | MMLU-Redux | 5-shot | **87.18** | 86.81 | 86.56 | 86.86 |
|
| 92 |
+
| | ARC-Challenge | 0-shot | **96.66** | 94.65 | 96.32 | 95.99 |
|
| 93 |
+
| | DROP | 5-shot | 86.40 | **86.50** | 82.90 | 85.50 |
|
| 94 |
+
| | PIQA | 4-shot | **84.93** | 84.22 | 84.71 | 84.39 |
|
| 95 |
+
| | SuperGPQA | 5-shot | 51.10 | 46.17 | 49.64 | **51.60** |
|
| 96 |
+
| | SimpleQA | 5-shot | **34.37** | 26.15 | 29.26 | 26.47 |
|
| 97 |
+
| **Code** | MBPP-plus | 3-shot | **81.35** | 75.47 | 78.05 | 78.71 |
|
| 98 |
+
| | CRUXEval-I | 3-shot | 68.01 | 67.79 | 68.51 | **71.19** |
|
| 99 |
+
| | CRUXEval-O | 3-shot | 69.62 | **71.00** | 67.75 | 68.38 |
|
| 100 |
+
| | LiveCodeBench-v6 | 1-shot | 30.86 | 29.31 | 27.43 | **34.86** |
|
| 101 |
+
| **Math** | GSM8K | 4-shot | 93.46 | 88.15 | 90.06 | **95.37** |
|
| 102 |
+
| | MATH | 4-shot | 71.20 | 59.37 | 61.00 | **76.28** |
|
| 103 |
+
| | CMath | 4-shot | 90.83 | 85.50 | 89.33 | **91.17** |
|
| 104 |
+
| **Chinese** | C-Eval | 5-shot | **91.51** | 90.35 | 85.84 | 89.80 |
|
| 105 |
+
| | CMMLU | 5-shot | **90.72** | 87.90 | 86.46 | 89.61 |
|
| 106 |
+
| | Chinese-simpleQA | 5-shot | **74.58** | 68.72 | 68.49 | 69.73 |
|
| 107 |
+
| **Multilingual** | MMMLU | 5-shot | 77.63 | 79.54 | 79.26 | **80.15** |
|
| 108 |
+
| | INCLUDE | 5-shot | 75.66 | 77.86 | 76.27 | **78.64** |
|
| 109 |
+
|
| 110 |
+
### Instruct 模型效果
|
| 111 |
+
|
| 112 |
+
#### 复杂推理(STEM & Reasoning)
|
| 113 |
+
|
| 114 |
+
推理能力是模型解决各种问题的基础。Hy3 preview 在 FrontierScience-Olympiad、IMOAnswerBench 等高难度理工科推理任务中表现突出,并在最新的清华大学求真书院数学博资考(26春)和全国中学生生物学联赛(CHSBO 2025)中取得优异成绩,展现出可泛化的强推理能力。
|
| 115 |
+
|
| 116 |
+
<p align="center"><img src="assets/bench_stem.jpg" width="800" alt="STEM & Reasoning 评测结果"/></p>
|
| 117 |
+
|
| 118 |
+
#### 上下文学习和指令遵循(Context Learning & Instruction Following)
|
| 119 |
+
|
| 120 |
+
在各种真实的生产与生活场景,理解杂乱冗长的上下文并遵从复杂多变的规则是模型的首要挑战。基于我们多种业务场景的灵感,我们提出了 CL-bench 和 CL-bench-Life 来创新性地评估模型的上下文学习能力,并在 Hy3 preview 显著地提升了模型上下文学习和指令遵循能力。
|
| 121 |
+
|
| 122 |
+
<p align="center"><img src="assets/bench_context.jpg" width="800" alt="上下文学习和指令遵循评测结果"/></p>
|
| 123 |
+
|
| 124 |
+
#### 代码和智能体(Code & Agent)
|
| 125 |
+
|
| 126 |
+
代码和智能体是 Hy3 preview 提升最为显著的方向。得益于预训练及强化学习框架的重建和强化学习任务规模的提升,我们以较快的速度在 SWE-Bench Verified、Terminal-Bench 2.0 等主流代码智能体基准以及 BrowseComp、WideSearch 等主流搜索智能体基准中取得了强竞争力的结果。
|
| 127 |
+
|
| 128 |
+
<p align="center"><img src="assets/bench_agent_overview_v3.jpg" width="800" alt="Agent 评测总览"/></p>
|
| 129 |
+
|
| 130 |
+
在数字世界中,代码关注的是模型在开发环境中的执行能力,搜索则聚焦于开放信息空间中的检索、筛选与整合能力,两者共同决定了模型在复杂智能体场景(例如 OpenClaw)中是否真正具备可用性。Hy3 preview 在 ClawEval 和 WildClawBench 等评测中表现突出,进一步表明我们的智能体能力的全面与实用性。
|
| 131 |
+
|
| 132 |
+
<p align="center"><img src="assets/bench_claw_agent.png" width="800" alt="Claw Agent 评测"/></p>
|
| 133 |
+
|
| 134 |
+
除了公开榜单,我们进一步构建了多个内部的评测集,对模型在真实开发场景中的表现进行评估。结果表明,无论是在后端工程任务集 Hy-Backend,贴近真实用户开发交互的 Hy-Vibe Bench,还是高难度软件工程开发任务集 Hy-SWE Max 上,Hy3 preview 均体现出了强竞争力。
|
| 135 |
+
|
| 136 |
+
<p align="center"><img src="assets/bench_claw_agent2.jpg" width="800" alt="内部评测结果"/></p>
|
| 137 |
+
|
| 138 |
+
## 新闻
|
| 139 |
+
|
| 140 |
+
* **[2026-04-23]** 🔥 我们在 [Hugging Face](https://huggingface.co/tencent/Hy3-preview)、[ModelScope](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview) 和 [GitCode](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview) 开源了 **Hy3 preview** 模型权重。
|
| 141 |
+
|
| 142 |
+
## 模型链接
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
| 模型名 | 简介 | Hugging Face | ModelScope | GitCode |
|
| 146 |
+
|:---|:---|:---:|:---:|:---:|
|
| 147 |
+
| Hy3 preview | Instruct 模型 | 🤗 [Model](https://huggingface.co/tencent/Hy3-preview) | [Model](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview) | [Model](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview) |
|
| 148 |
+
| Hy3 preview-Base | 预训练基座模型 | 🤗 [Model](https://huggingface.co/tencent/Hy3-preview-Base) | [Model](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview-Base) | [Model](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview-Base) |
|
| 149 |
+
|
| 150 |
+
## 快速开始
|
| 151 |
+
|
| 152 |
+
建议先通过 [vLLM](#使用-vllm-推理) 或 [SGLang](#使用-sglang-推理) 部署服务,然后通过 OpenAI 兼容 API 调用:
|
| 153 |
+
|
| 154 |
+
```python
|
| 155 |
+
from openai import OpenAI
|
| 156 |
+
|
| 157 |
+
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
|
| 158 |
+
|
| 159 |
+
response = client.chat.completions.create(
|
| 160 |
+
model="tencent/Hy3-preview",
|
| 161 |
+
messages=[
|
| 162 |
+
{"role": "user", "content": "你好!请简单介绍一下你自己。"},
|
| 163 |
+
],
|
| 164 |
+
temperature=0.9,
|
| 165 |
+
top_p=1.0,
|
| 166 |
+
# reasoning_effort: "no_think"(默认,直接回复)、"low"、"high"(深度思维链)
|
| 167 |
+
extra_body={"chat_template_kwargs": {"reasoning_effort": "no_think"}},
|
| 168 |
+
)
|
| 169 |
+
print(response.choices[0].message.content)
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
> **推荐参数**:`temperature=0.9`,`top_p=1.0`。
|
| 173 |
+
>
|
| 174 |
+
> **推理模式**:复杂任务(数学、编程、推理)建议设置 `reasoning_effort="high"`,日常对话可使用默认的 `"no_think"` 直接回复。
|
| 175 |
+
|
| 176 |
+
具体部署方式请参考下方[推理和部署](#推理和部署)章节。
|
| 177 |
+
|
| 178 |
+
## 推理和部署
|
| 179 |
+
|
| 180 |
+
Hy3-preview 总参数量为 295B,当使用 8 张 GPU 时,建议使用 H20-3e 或其他有更大显存的卡型。
|
| 181 |
+
|
| 182 |
+
### vLLM
|
| 183 |
+
|
| 184 |
+
从源码构建 vLLM:
|
| 185 |
+
|
| 186 |
+
```bash
|
| 187 |
+
uv venv --python 3.12 --seed --managed-python
|
| 188 |
+
source .venv/bin/activate
|
| 189 |
+
git clone https://github.com/vllm-project/vllm.git
|
| 190 |
+
cd vllm
|
| 191 |
+
uv pip install --editable . --torch-backend=auto
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
启动 vLLM 服务,开启 MTP:
|
| 195 |
+
|
| 196 |
+
```bash
|
| 197 |
+
vllm serve tencent/Hy3-preview \
|
| 198 |
+
--tensor-parallel-size 8 \
|
| 199 |
+
--speculative-config.method mtp \
|
| 200 |
+
--speculative-config.num_speculative_tokens 1 \
|
| 201 |
+
--tool-call-parser hy_v3 \
|
| 202 |
+
--reasoning-parser hy_v3 \
|
| 203 |
+
--enable-auto-tool-choice \
|
| 204 |
+
--served-model-name hy3-preview
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
### SGLang
|
| 208 |
+
|
| 209 |
+
从源码构建 SGLang:
|
| 210 |
+
|
| 211 |
+
```bash
|
| 212 |
+
git clone https://github.com/sgl-project/sglang
|
| 213 |
+
cd sglang
|
| 214 |
+
pip3 install pip --upgrade
|
| 215 |
+
pip3 install "transformers>=5.6.0"
|
| 216 |
+
pip3 install -e "python"
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
启动 SGLang 服务,开启 MTP:
|
| 220 |
+
|
| 221 |
+
```bash
|
| 222 |
+
python3 -m sglang.launch_server \
|
| 223 |
+
--model tencent/Hy3-preview \
|
| 224 |
+
--tp 8 \
|
| 225 |
+
--tool-call-parser hunyuan \
|
| 226 |
+
--reasoning-parser hunyuan \
|
| 227 |
+
--speculative-num-steps 1 \
|
| 228 |
+
--speculative-eagle-topk 1 \
|
| 229 |
+
--speculative-num-draft-tokens 2 \
|
| 230 |
+
--speculative-algorithm EAGLE \
|
| 231 |
+
--served-model-name hy3-preview
|
| 232 |
+
```
|
| 233 |
+
|
| 234 |
+
## 模型训练
|
| 235 |
+
|
| 236 |
+
Hy3 preview 提供了完整的模型训练流程,支持全量微调和 LoRA 微调,同时支持 DeepSpeed ZeRO 多种配置以及 LLaMA-Factory 集成。
|
| 237 |
+
|
| 238 |
+
详细的训练文档请参考:[模型训练指南](./train/README_CN.md)
|
| 239 |
+
|
| 240 |
+
## 量化工具
|
| 241 |
+
|
| 242 |
+
我们提供了 [AngelSlim](https://github.com/tencent/AngelSlim)——一套易用、全面、高效的大模型压缩工具包,涵盖常用量化算法、低比特量化和投机采样等能力。
|
| 243 |
+
|
| 244 |
+
## 许可证
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
Hy3 preview 基于 **腾讯混元社区许可协议** 发布。详情请参阅 [LICENSE](./LICENSE)。
|
| 248 |
+
|
| 249 |
+
## 联系我们
|
| 250 |
+
|
| 251 |
+
如有问题或建议,欢迎通过邮件联系我们:
|
| 252 |
+
|
| 253 |
+
📧 **hunyuan_opensource@tencent.com**
|
| 254 |
+
|
| 255 |
+
---
|
| 256 |
+
|
| 257 |
+
<p align="center">
|
| 258 |
+
<i>Hy3 preview 由腾讯混元团队研发。</i>
|
| 259 |
+
</p>
|
assets/bench_agent_overview_v3.jpg
ADDED
|
Git LFS Details
|
assets/bench_claw_agent.png
ADDED
|
assets/bench_claw_agent2.jpg
ADDED
|
Git LFS Details
|
assets/bench_context.jpg
ADDED
|
Git LFS Details
|
assets/bench_stem.jpg
ADDED
|
Git LFS Details
|
assets/logo-en.png
ADDED
|
assets/logo-zh.png
ADDED
|
train/README.md
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="left">
|
| 2 |
+
<a href="README_CN.md">中文</a> | English
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
# Model Training
|
| 6 |
+
|
| 7 |
+
Hy3 preview provides processes related to model training. This section details how to process training data for model training purposes.
|
| 8 |
+
|
| 9 |
+
## Training Data Format and Processing
|
| 10 |
+
|
| 11 |
+
**Hy3 preview supports both "slow thinking" and "fast thinking" modes. By default, the model outputs in slow thinking mode. If you wish the model to use fast thinking, you can control it via the `reasoning_effort` parameter (options: `high`, `low`, `no_think`).**
|
| 12 |
+
|
| 13 |
+
The training data should be formatted as a list of messages. By default, the system prompt for both training and inference is empty, but you may customize it as needed.
|
| 14 |
+
|
| 15 |
+
```python
|
| 16 |
+
# Fast thinking pattern (no_think)
|
| 17 |
+
{"reasoning_effort": "no_think", "messages": [{"content": "You are a helpful assistant.\nThe current time is 2026-01-01 13:26:12 Thursday", "role": "system"}, {"content": "1+1=?", "role": "user"}, {"role": "assistant", "content": "1+1=2"}]}
|
| 18 |
+
|
| 19 |
+
# Slow thinking pattern (high)
|
| 20 |
+
{"reasoning_effort": "high", "messages": [{"content": "You are a helpful assistant.\nThe current time is 2026-01-01 13:26:12 Thursday", "role": "system"}, {"content": "1+1=?", "role": "user"}, {"role": "assistant", "content": "1+1=2", "reasoning_content": "The user is asking for the result of 1 + 1. In basic decimal arithmetic, 1 + 1 equals 2."}]}
|
| 21 |
+
|
| 22 |
+
from transformers import AutoTokenizer
|
| 23 |
+
tokenizer = AutoTokenizer.from_pretrained("./models", use_fast=False, trust_remote_code=True)
|
| 24 |
+
ids = tokenizer.apply_chat_template(messages, is_training=True)
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
## Checkpoint Format Conversion
|
| 28 |
+
|
| 29 |
+
The original Hy3 preview checkpoint stores each expert's weights independently. Before training, it needs to be converted to the HuggingFace-compatible format (fusing multiple experts per layer into 3D tensors with unified key naming) to improve loading and training speed. We provide a conversion script `convert_ckpt_to_outer.py` and a validation script `check_converted.py`, located in the `train/tools` directory.
|
| 30 |
+
|
| 31 |
+
### Conversion
|
| 32 |
+
|
| 33 |
+
```sh
|
| 34 |
+
python convert_ckpt_to_outer.py \
|
| 35 |
+
--input_dir <original_checkpoint_dir> \
|
| 36 |
+
--output_dir <output_dir> \
|
| 37 |
+
--workers 8
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
**Parameters:**
|
| 41 |
+
|
| 42 |
+
- `--input_dir`: Path to the original checkpoint directory (required)
|
| 43 |
+
- `--output_dir`: Path to the converted checkpoint output directory (required)
|
| 44 |
+
- `--workers`: Number of parallel worker processes, default is 8 (optional)
|
| 45 |
+
|
| 46 |
+
The conversion script performs the following steps:
|
| 47 |
+
1. Pre-scans `model.safetensors.index.json` to detect cross-shard expert groups
|
| 48 |
+
2. Converts weights shard-by-shard in parallel (key renaming + expert fusion)
|
| 49 |
+
3. Post-processes cross-shard expert groups (merges data from multiple shards)
|
| 50 |
+
4. Copies `config.json`, tokenizer, and other files
|
| 51 |
+
5. Rebuilds `model.safetensors.index.json`
|
| 52 |
+
|
| 53 |
+
### Validation
|
| 54 |
+
|
| 55 |
+
After conversion, it is recommended to validate the result using the validation script:
|
| 56 |
+
|
| 57 |
+
```sh
|
| 58 |
+
python check_converted.py <converted_checkpoint_dir> --spot-check 3
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
**Parameters:**
|
| 62 |
+
|
| 63 |
+
- First argument: Path to the converted checkpoint directory (required)
|
| 64 |
+
- `--spot-check`: Number of shard files to spot-check by loading tensors and verifying shape, dtype, NaN/Inf, etc. Default is 3 (optional)
|
| 65 |
+
|
| 66 |
+
The validation script checks the following:
|
| 67 |
+
1. Completeness of `config.json`
|
| 68 |
+
2. Whether all expected keys in `model.safetensors.index.json` are present (including regular layers and MTP layers)
|
| 69 |
+
3. Whether all referenced shard files exist and are non-empty
|
| 70 |
+
4. Spot-checks tensor shape, dtype, and NaN/Inf in selected shard files
|
| 71 |
+
5. Detects orphan empty shard files (cross-shard merge residues, safe to delete)
|
| 72 |
+
|
| 73 |
+
## Quick Start
|
| 74 |
+
|
| 75 |
+
You can quickly get started by following the instructions in the Quick Start Guide.
|
| 76 |
+
|
| 77 |
+
## Model Training
|
| 78 |
+
|
| 79 |
+
### Hardware Requirements
|
| 80 |
+
|
| 81 |
+
Based on testing, when `make_moe_param_leaf_module` and `zero3+offload` are disabled and `max_seq_length` is set to 4096, full fine-tuning with LoRA requires at least a single machine with 8 GPUs (each with at least 80GB of memory).
|
| 82 |
+
|
| 83 |
+
Without LoRA, at least 4 machines with 32 GPUs (each with at least 80GB of memory) are required.
|
| 84 |
+
|
| 85 |
+
### Launch Methods
|
| 86 |
+
|
| 87 |
+
Reference: [HuggingFace Transformers Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer)
|
| 88 |
+
|
| 89 |
+
#### Single-Machine Training
|
| 90 |
+
In the `train` directory, execute:
|
| 91 |
+
|
| 92 |
+
```sh
|
| 93 |
+
pip install -r requirements.txt
|
| 94 |
+
bash train.sh
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
#### Multi-Machine Training
|
| 98 |
+
|
| 99 |
+
To launch training across multiple machines, please follow the steps below and ensure all machines are within the same cluster.
|
| 100 |
+
|
| 101 |
+
##### Configure Passwordless SSH Login Between Machines
|
| 102 |
+
|
| 103 |
+
The following instructions use two machines as an example, with their IPs denoted as `${ip1}` and `${ip2}`. All steps should be performed inside the Docker container.
|
| 104 |
+
|
| 105 |
+
First, configure passwordless SSH for each container on every machine:
|
| 106 |
+
|
| 107 |
+
```sh
|
| 108 |
+
ssh-keygen # Generate id_rsa and id_rsa.pub for passwordless login
|
| 109 |
+
ssh-keygen -t rsa -A # Generate /etc/ssh/ssh_host_rsa_key and ssh_host_ecdsa_key for SSH listening
|
| 110 |
+
/usr/sbin/sshd -p 36005 -o ListenAddress=0.0.0.0 # Start SSH listening
|
| 111 |
+
echo "Port 36005" > ~/.ssh/config # Set SSH connection port to 36005
|
| 112 |
+
passwd root # Set the root password to avoid monitoring platform alerts
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
Note: `36005` is an example port. You may use any available port, but ensure it is **open** and **not occupied by other processes**.
|
| 116 |
+
|
| 117 |
+
Next, in each machine's container, execute:
|
| 118 |
+
|
| 119 |
+
```sh
|
| 120 |
+
cat ~/.ssh/id_rsa.pub
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
**Copy the output SSH public key and paste it into the `~/.ssh/authorized_keys` file, one key per line. This must be done on every machine.** In the end, the `~/.ssh/authorized_keys` file on each machine should be identical and contain the public keys of all machines.
|
| 124 |
+
|
| 125 |
+
Please note that for multi-node training, the code executed on each node must be identical. It is recommended to mount a shared network drive. If this is not possible, you must manually copy the dataset, scripts, and code to the same directory on each machine.
|
| 126 |
+
|
| 127 |
+
##### Launching Multi-Machine Training
|
| 128 |
+
|
| 129 |
+
Once the above preparations are complete and all dependencies are installed (if not, run `pip install -r requirements.txt`), add the following configuration at the beginning of `train.sh`:
|
| 130 |
+
|
| 131 |
+
```shell
|
| 132 |
+
export HOST_GPU_NUM=8
|
| 133 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 134 |
+
export IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
Note: If the `IP_LIST` environment variable is not set, replace `IP_LIST` with the IP list! The format is:
|
| 138 |
+
```
|
| 139 |
+
For a single IP:
|
| 140 |
+
IP_LIST=${ip_1}
|
| 141 |
+
|
| 142 |
+
For multiple IPs:
|
| 143 |
+
IP_LIST=${ip_1},${ip_2}
|
| 144 |
+
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
Replace `${ip_1}` and `${ip_2}` with the actual IP addresses.
|
| 148 |
+
|
| 149 |
+
Then, on the machine with `${ip1}`, execute `bash train.sh` in the `train/` directory. On first launch, you may see the following output:
|
| 150 |
+
|
| 151 |
+
```ssh
|
| 152 |
+
The authenticity of host '[ip]:36005 ([ip]:36005)' can't be established.
|
| 153 |
+
ECDSA key fingerprint is xxxxxx.
|
| 154 |
+
ECDSA key fingerprint is MD5:xxxxxx.
|
| 155 |
+
Are you sure you want to continue connecting (yes/no)?
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
Type `yes` to continue.
|
| 159 |
+
|
| 160 |
+
##### Key Parameters
|
| 161 |
+
|
| 162 |
+
The key parameters in the script are as follows:
|
| 163 |
+
|
| 164 |
+
- `--deepspeed`: Path to the DeepSpeed configuration file. Three default DeepSpeed configuration files are provided in the `train` folder: `ds_zero2_no_offload.json`, `ds_zero3_no_offload.json`, and `ds_zero3_offload.json`, with decreasing memory requirements in that order.
|
| 165 |
+
- `--model_name_or_path`: Path to the Hy3 preview HF pre-trained model weights to load, otherwise loading will fail.
|
| 166 |
+
- `--tokenizer_name_or_path`: Path to the tokenizer folder, otherwise loading will fail.
|
| 167 |
+
- `--train_data_file`: Path to the training file, which should be a jsonl file.
|
| 168 |
+
- `--output_dir`: Output directory where logs, tensorboard files, and model weights will be stored.
|
| 169 |
+
- `--per_device_train_batch_size`: Batch size per GPU.
|
| 170 |
+
- `--gradient_accumulation_steps`: Number of gradient accumulation steps. The global batch size is `per_device_train_batch_size * gradient_accumulation_steps * dp_size`.
|
| 171 |
+
- `--max_steps`: Total number of training steps.
|
| 172 |
+
- `--save_steps`: Number of steps between saving checkpoints.
|
| 173 |
+
- `--use_lora`: Whether to use LoRA training. Also accepts `--lora_rank`, `--lora_alpha`, and `--lora_dropout` parameters. By default, LoRA is applied to "q_proj", "k_proj", "v_proj", and "o_proj". To change this, modify the code. Note: ** When using LoRA training, only the LoRA weights are saved, not the base model weights. ** To merge LoRA weights, see the "LoRA Weight Merging" section below.
|
| 174 |
+
- `--make_moe_param_leaf_module`: When using zero3 and MoE training, treat the MoE module as a leaf module, i.e., its parameters are not partitioned by zero3. This option is expected to significantly increase memory usage.
|
| 175 |
+
- `--gradient_checkpointing`: Enable gradient checkpointing.
|
| 176 |
+
- `--train_attention_params_only`: Whether to train only attention parameters.
|
| 177 |
+
- `--learning_rate`: Maximum learning rate during training.
|
| 178 |
+
- `--min_lr`: Minimum learning rate during training.
|
| 179 |
+
- `--use_flash_attn`: Enable flash-attention for accelerated training.
|
| 180 |
+
|
| 181 |
+
**Notes:**
|
| 182 |
+
|
| 183 |
+
- To resume training from a previously saved checkpoint rather than loading pre-trained weights, specify `--resume_from_checkpoint` with the path to the checkpoint. Do not specify `--model_name_or_path`, this will load only the weights, not the training state.
|
| 184 |
+
- When resuming from a checkpoint, there may be minor differences in loss due to the randomness of some non-deterministic algorithms. This is normal. See: [HuggingFace Transformers Trainer Randomness](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#randomness)
|
| 185 |
+
- When `--model_name_or_path` is specified, all model-related parameters will be ignored.
|
| 186 |
+
- Samples within a batch are padded to the length of the longest sample in the batch, but the maximum length of each sample is `max_seq_length`. Any excess will be truncated.
|
| 187 |
+
- If you see a warning about bias weights not being loaded, you can ignore it. Hunyuan-Large does not use bias.
|
| 188 |
+
|
| 189 |
+
#### What if GPU Memory is Insufficient?
|
| 190 |
+
|
| 191 |
+
Reference: [DeepSpeed Configuration](https://www.deepspeed.ai/docs/config-json/)
|
| 192 |
+
|
| 193 |
+
You can try modifying the DeepSpeed configuration by removing the `auto` attribute from the following parameters and reducing their values:
|
| 194 |
+
|
| 195 |
+
- `stage3_param_persistence_threshold`
|
| 196 |
+
- `stage3_prefetch_bucket_size`
|
| 197 |
+
- `stage3_max_reuse_distance`
|
| 198 |
+
|
| 199 |
+
#### LoRA Weight Merging
|
| 200 |
+
|
| 201 |
+
LoRA weights saved during training cannot be merged into the zero3 model at runtime, as zero3 partitions model weights across data parallel ranks. To merge LoRA weights into the base model, you can do so offline to obtain a merged weight file. Run `merge_lora_weight.sh` to merge the LoRA and base model weights. The parameters are:
|
| 202 |
+
|
| 203 |
+
- `--base_model_path`: Directory of the base model weights
|
| 204 |
+
- `--adapter_model_path`: Directory of the LoRA weights
|
| 205 |
+
- `--output_path`: Directory to save the merged weights
|
| 206 |
+
- `--save_dtype`: Data type for saving the merged weights; options are: fp16, bf16, fp32
|
| 207 |
+
|
| 208 |
+
#### LLaMA-Factory Support
|
| 209 |
+
|
| 210 |
+
If you are familiar with LLaMA-Factory, you may use it for fine-tuning. All scripts, code, and configuration files are archived in the `./train/llama_factory_support` directory. Unless otherwise specified, all files mentioned below are located in this directory.
|
| 211 |
+
|
| 212 |
+
##### Installation
|
| 213 |
+
|
| 214 |
+
You can install LLaMA-Factory by downloading the source code from https://github.com/hiyouga/LLaMA-Factory/tree/main and following the instructions on the website.
|
| 215 |
+
|
| 216 |
+
##### Configuration Files
|
| 217 |
+
|
| 218 |
+
We provide sample LLaMA-Factory training configuration files: `hy_v3_lora_sft.yaml` and `hy_v3_full_sft.yaml`, corresponding to LoRA training and full fine-tuning respectively.
|
| 219 |
+
|
| 220 |
+
Key parameters in the configuration files are as follows:
|
| 221 |
+
|
| 222 |
+
**Model:**
|
| 223 |
+
|
| 224 |
+
- `model_name_or_path`: Path to the Hy3 preview HF format pre-trained model weights
|
| 225 |
+
- `trust_remote_code`: Whether to trust remote code; Hy3 preview requires this to be set to `true`
|
| 226 |
+
|
| 227 |
+
**Training Method:**
|
| 228 |
+
|
| 229 |
+
- `stage`: Training stage, currently `sft` (supervised fine-tuning)
|
| 230 |
+
- `finetuning_type`: Fine-tuning type, either `full` (full fine-tuning) or `lora` (LoRA fine-tuning)
|
| 231 |
+
- `deepspeed`: DeepSpeed configuration file path; `ds_zero3_offload_hy.json` is recommended for full fine-tuning, `ds_zero2_offload_lora.json` for LoRA fine-tuning
|
| 232 |
+
|
| 233 |
+
**LoRA Parameters (only effective during LoRA fine-tuning):**
|
| 234 |
+
|
| 235 |
+
- `lora_rank`: LoRA rank, default `64`
|
| 236 |
+
- `lora_alpha`: LoRA alpha coefficient, default `128`
|
| 237 |
+
- `lora_dropout`: LoRA dropout ratio, default `0.05`
|
| 238 |
+
- `lora_target`: Target modules for LoRA, default `q_proj,k_proj,v_proj,o_proj`
|
| 239 |
+
|
| 240 |
+
**Dataset:**
|
| 241 |
+
|
| 242 |
+
- `dataset_dir`: Dataset directory path
|
| 243 |
+
- `dataset`: Dataset name, must be registered in `dataset_info.json` under `dataset_dir`
|
| 244 |
+
- `template`: Chat template; Hy3 preview uses `hy_v3`
|
| 245 |
+
- `cutoff_len`: Maximum sequence length; sequences exceeding this will be truncated. For full fine-tuning, can be set to `262144` (262K); for LoRA fine-tuning, `8192` is recommended to save memory
|
| 246 |
+
- `max_samples`: Maximum number of samples per dataset
|
| 247 |
+
- `overwrite_cache`: Whether to overwrite cached preprocessed datasets
|
| 248 |
+
|
| 249 |
+
**Output:**
|
| 250 |
+
|
| 251 |
+
- `output_dir`: Output directory where logs, TensorBoard files, and weights will be stored
|
| 252 |
+
- `logging_steps`: Number of steps between logging
|
| 253 |
+
- `save_steps`: Number of steps between saving checkpoints
|
| 254 |
+
- `plot_loss`: Whether to plot the training loss curve
|
| 255 |
+
- `overwrite_output_dir`: Whether to overwrite the existing output directory
|
| 256 |
+
- `save_only_model`: Whether to save only model weights (excluding optimizer states, etc.)
|
| 257 |
+
- `report_to`: Logging tool, options: `none`, `wandb`, `tensorboard`, `swanlab`, `mlflow`
|
| 258 |
+
|
| 259 |
+
**Training Hyperparameters:**
|
| 260 |
+
|
| 261 |
+
- `per_device_train_batch_size`: Batch size per GPU
|
| 262 |
+
- `gradient_accumulation_steps`: Gradient accumulation steps; `per_device_train_batch_size * gradient_accumulation_steps * dp_size` equals the global batch size
|
| 263 |
+
- `learning_rate`: Maximum learning rate; `1.0e-5` recommended for full fine-tuning, `2.0e-4` for LoRA fine-tuning
|
| 264 |
+
- `num_train_epochs`: Number of training epochs
|
| 265 |
+
- `lr_scheduler_type`: Learning rate scheduler type; `cosine_with_min_lr` is recommended
|
| 266 |
+
- `lr_scheduler_kwargs.min_lr_rate`: Ratio of minimum to maximum learning rate; e.g., `0.1` means the minimum learning rate is 10% of the maximum
|
| 267 |
+
- `warmup_ratio`: Proportion of total training steps used for warmup
|
| 268 |
+
- `bf16`: Whether to use BFloat16 mixed precision training
|
| 269 |
+
- `gradient_checkpointing`: Whether to enable gradient checkpointing to save memory
|
| 270 |
+
- `ddp_timeout`: Distributed training timeout (milliseconds)
|
| 271 |
+
- `flash_attn`: Attention implementation; `fa2` (FlashAttention-2) is recommended, `sdpa` is also available; using `fa2` requires the flash-attn package
|
| 272 |
+
- `resume_from_checkpoint`: Resume training from a specified checkpoint path; set to `null` to start from scratch
|
| 273 |
+
|
| 274 |
+
##### Launch Training
|
| 275 |
+
|
| 276 |
+
Please first configure passwordless SSH login between machines following the [Configure Passwordless SSH Login Between Machines](#configure-passwordless-ssh-login-between-machines) section above.
|
| 277 |
+
|
| 278 |
+
Modify the following configuration at the beginning of `train_lf.sh`:
|
| 279 |
+
|
| 280 |
+
```shell
|
| 281 |
+
export HOST_GPU_NUM=8
|
| 282 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 283 |
+
export IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 284 |
+
```
|
| 285 |
+
|
| 286 |
+
Note: If the `IP_LIST` environment variable is not set, replace `IP_LIST` with the IP list! The format is:
|
| 287 |
+
```
|
| 288 |
+
For a single IP:
|
| 289 |
+
IP_LIST=${ip_1}
|
| 290 |
+
|
| 291 |
+
For multiple IPs:
|
| 292 |
+
IP_LIST=${ip_1},${ip_2}
|
| 293 |
+
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
Replace `${ip_1}` and `${ip_2}` with the actual IP addresses.
|
| 297 |
+
|
| 298 |
+
Then, on each machine, run `bash train_lf.sh` in the `train/llama_factory_support/` directory.
|
train/README_CN.md
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="left">
|
| 2 |
+
<a href="README.md">English</a> | 中文
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
# 模型训练
|
| 6 |
+
|
| 7 |
+
Hy3 preview 提供了模型训练相关流程,您可以在此章节对训练数据格式进行处理以供模型训练使用。
|
| 8 |
+
|
| 9 |
+
## 训练数据格式及处理
|
| 10 |
+
|
| 11 |
+
**Hy3 preview 同时支持慢思考与快思考两种模式,模型的默认输出是慢思考模式,若想让模型进行快思考,可通过 `reasoning_effort` 参数控制(可选值:`high`、`low`、`no_think`)。**
|
| 12 |
+
|
| 13 |
+
训练数据按照以下形式处理为 messages 格式,训练和推理的默认 system prompt 为空,可以根据自己的需求进行设定。
|
| 14 |
+
|
| 15 |
+
```python
|
| 16 |
+
# Fast thinking pattern (no_think)
|
| 17 |
+
{"reasoning_effort": "no_think", "messages": [{"content": "你是一个有用的人工智能助手。\n现在的时间是2026-01-01 13:26:12 周四", "role": "system"}, {"content": "1+1=?", "role": "user"}, {"role": "assistant", "content": "1+1=2"}]}
|
| 18 |
+
|
| 19 |
+
# Slow thinking pattern (high)
|
| 20 |
+
{"reasoning_effort": "high", "messages": [{"content": "你是一个有用的人工智能助手。\n现在的时间是2026-01-01 13:26:12 周四", "role": "system"}, {"content": "1+1=?", "role": "user"}, {"role": "assistant", "content": "1+1=2", "reasoning_content": "1+1=2"}]}
|
| 21 |
+
|
| 22 |
+
from transformers import AutoTokenizer
|
| 23 |
+
tokenizer = AutoTokenizer.from_pretrained("./models", use_fast=False, trust_remote_code=True)
|
| 24 |
+
ids = tokenizer.apply_chat_template(messages, is_training=True)
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
## 权重格式转换
|
| 28 |
+
|
| 29 |
+
Hy3 preview 的原始 checkpoint 采用每个 expert 独立存储的格式,在训练前需要转换为 expert 融合后的 HuggingFace 标准格式(将同一层的多个 expert 权重融合为 3D 张量,并统一 key 命名),用于提高加载和训练的速率。我们提供了转换脚本 `convert_ckpt_to_outer.py` 和校验脚本 `check_converted.py`,位于 `train/tools` 目录下。
|
| 30 |
+
|
| 31 |
+
### 转换
|
| 32 |
+
|
| 33 |
+
```sh
|
| 34 |
+
python convert_ckpt_to_outer.py \
|
| 35 |
+
--input_dir <原始checkpoint目录> \
|
| 36 |
+
--output_dir <输出目录> \
|
| 37 |
+
--workers 8
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
**参数说明:**
|
| 41 |
+
|
| 42 |
+
- `--input_dir`:原始 checkpoint 目录路径(必选)
|
| 43 |
+
- `--output_dir`:转换后的 checkpoint 输出目录路径(必选)
|
| 44 |
+
- `--workers`:并行转换的进程数,默认为 8(可选)
|
| 45 |
+
|
| 46 |
+
转换脚本会执行以下步骤:
|
| 47 |
+
1. 预扫描 `model.safetensors.index.json`,检测跨 shard 的 expert 分组
|
| 48 |
+
2. 逐 shard 并行转换权重(key 重命名 + expert 融合)
|
| 49 |
+
3. 后处理跨 shard 的 expert 分组(合并来自多个 shard 的数据)
|
| 50 |
+
4. 复制 `config.json`、tokenizer 等其他文件
|
| 51 |
+
5. 重建 `model.safetensors.index.json`
|
| 52 |
+
|
| 53 |
+
### 校验
|
| 54 |
+
|
| 55 |
+
转换完成后,建议使用校验脚本验证转换结果的完整性:
|
| 56 |
+
|
| 57 |
+
```sh
|
| 58 |
+
python check_converted.py <转换后的checkpoint目录> --spot-check 3
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
**参数说明:**
|
| 62 |
+
|
| 63 |
+
- 第一个参数:转换后的 checkpoint 目录路径(必选)
|
| 64 |
+
- `--spot-check`:随机抽检的 shard 文件数量,会加载 tensor 并检查 shape、dtype、NaN/Inf 等,默认为 3(可选)
|
| 65 |
+
|
| 66 |
+
校验脚本会检查以下内容:
|
| 67 |
+
1. `config.json` 的完整性
|
| 68 |
+
2. `model.safetensors.index.json` 中所有预期 key 是否齐全(包括常规层和 MTP 层)
|
| 69 |
+
3. 所有引用的 shard 文件是否存在且非空
|
| 70 |
+
4. 抽检 shard 文件中 tensor 的 shape、dtype 是否正确,是否存在 NaN/Inf
|
| 71 |
+
5. 检测孤立的空 shard 文件(跨 shard 合并残留,可安全删除)
|
| 72 |
+
|
| 73 |
+
## 快速开始
|
| 74 |
+
|
| 75 |
+
您可以参照快速开始文档中的内容进行快速上手。
|
| 76 |
+
|
| 77 |
+
## 模型训练
|
| 78 |
+
|
| 79 |
+
### 硬件需求
|
| 80 |
+
|
| 81 |
+
经过测试,不开 make_moe_param_leaf_module 以及 zero3+offload,max_seq_length 为 4096,使用LORA,全量微调最少需要单机 8 卡(显存至少80G)。
|
| 82 |
+
|
| 83 |
+
如果不使用LORA,最少需要4机32卡(显存至少80G)。
|
| 84 |
+
|
| 85 |
+
### 启动方式
|
| 86 |
+
|
| 87 |
+
参考:[HuggingFace Transformers Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer)
|
| 88 |
+
|
| 89 |
+
#### 单机启动训练
|
| 90 |
+
在`train`目录下,执行:
|
| 91 |
+
|
| 92 |
+
```sh
|
| 93 |
+
pip install -r requirements.txt
|
| 94 |
+
bash train.sh
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
#### 多机启动训练
|
| 98 |
+
|
| 99 |
+
如果要用多台机器启动训练,请按照以下步骤执行,并保证多台机器在一个集群内。
|
| 100 |
+
|
| 101 |
+
##### 配置机器间免密 ssh 登录
|
| 102 |
+
|
| 103 |
+
以下操作以两个机器为例,两台机器的 ip 分别以`${ip1}`和`${ip2}`标识,以下操作均在 docker container 内执行。
|
| 104 |
+
|
| 105 |
+
首先,配置多机container免密,在每台机器上执行。
|
| 106 |
+
|
| 107 |
+
```sh
|
| 108 |
+
ssh-keygen # 生成id_rsa和id_rsa.pub,用于免密登录
|
| 109 |
+
ssh-keygen -t rsa -A # 生成/etc/ssh/ssh_host_rsa_key和ssh_host_ecdsa_key, 用于后面启动ssh listen
|
| 110 |
+
/usr/sbin/sshd -p 36005 -o ListenAddress=0.0.0.0 # 启动 SSH 监听
|
| 111 |
+
echo "Port 36005" > ~/.ssh/config # ssh 连接端口修改为 36005
|
| 112 |
+
passwd root # 需要配置root密码,否则监测平台会报警
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
注意:这里的`36005`是一个示例端口,可以选用任意端口,但需要保证使用的端口**开放**且**不被其他的进程占用**。
|
| 116 |
+
|
| 117 |
+
接下来,在每台机器的 container 内,执行:
|
| 118 |
+
|
| 119 |
+
```sh
|
| 120 |
+
cat ~/.ssh/id_rsa.pub
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
**将输出的 ssh 公钥复制并粘贴到`~/.ssh/authorized_keys`文件中,每行一个公钥,每台机器上都要做这个操作**。最终每台机器上的`~/.ssh/authorized_keys`文件内容应当是一致的,并且包含了所有机器的公钥。
|
| 124 |
+
|
| 125 |
+
需要注意,多节点训练时,每个节点上执行的代码都得一致,建议挂载一个共享的网络盘,如果无法挂载共享网盘,则需要手动将数据集、脚本、代码复制在多台机器的相同目录下。
|
| 126 |
+
|
| 127 |
+
##### 启动多机训练
|
| 128 |
+
|
| 129 |
+
在以上准备步骤准备好了之后,以及确认依赖已经安装完成(如未安装,请执行`pip install -r requirements.txt`安装),就可以在`train.sh`中的开头增加以下配置:
|
| 130 |
+
|
| 131 |
+
```shell
|
| 132 |
+
export HOST_GPU_NUM=8
|
| 133 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 134 |
+
IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
注意:如果`IP_LIST`环境变量未设置,则将`IP_LIST`替换为IP列表!格式为:
|
| 138 |
+
```
|
| 139 |
+
如果只有一个IP:
|
| 140 |
+
IP_LIST=${ip_1}
|
| 141 |
+
|
| 142 |
+
如果有多个IP:
|
| 143 |
+
IP_LIST=${ip_1},${ip_2}
|
| 144 |
+
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
请将`${ip_1}`和`${ip_2}`替换为真实的IP地址。
|
| 148 |
+
|
| 149 |
+
然后,在`${ip1}`的机器上,在`train/`目录下,执行`bash train.sh`即可,注意第一次启动时可能会看见以下的输出:
|
| 150 |
+
|
| 151 |
+
```ssh
|
| 152 |
+
The authenticity of host '[ip]:36005 ([ip]:36005)' can't be established.
|
| 153 |
+
ECDSA key fingerprint is xxxxxx.
|
| 154 |
+
ECDSA key fingerprint is MD5:xxxxxx.
|
| 155 |
+
Are you sure you want to continue connecting (yes/no)?
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
此时输入`yes`即可继续。
|
| 159 |
+
|
| 160 |
+
##### 关键参数
|
| 161 |
+
|
| 162 |
+
脚本中的关键参数如下:
|
| 163 |
+
|
| 164 |
+
- `--deepspeed`: 此参数应当指向一个 deepspeed 的配置文件,`train`文件夹下提供了三种 DeepSpeed 的默认配置文件:`ds_zero2_no_offload.json`, `ds_zero3_no_offload.json`, `ds_zero3_offload.json`,这三个配置文件所需显存依次减少
|
| 165 |
+
- `--model_name_or_path`: 要加载的 Hy3 preview 的 HF 预训练模型权重,否则无法加载
|
| 166 |
+
- `--tokenizer_name_or_path`: tokenizer 文件夹路径, 否则无法加载
|
| 167 |
+
- `--train_data_file`: 训练文件路径,应该为一个 jsonl 文件
|
| 168 |
+
- `--output_dir`: 输出文件夹,log、tensorboard 和权重都会存储在这个路径下
|
| 169 |
+
- `--per_device_train_batch_size`: 每张卡上的 batch size
|
| 170 |
+
- `--gradient_accumulation_steps`: 梯度累计次数,`per_device_train_batch_size * gradient_accumulation_steps * dp_size`为 global_batch_size
|
| 171 |
+
- `--max_steps`: 训练的总步数
|
| 172 |
+
- `--save_steps`: 每多少个 step 存储一个 checkpoint
|
| 173 |
+
- `--use_lora`: 是否用 lora 训练,同时接收`--lora_rank`,`--lora_alpha`和`--lora_dropout`参数。lora 默认应用于 "q_proj", "k_proj", "v_proj", "o_proj" 四个参数,如果需要改变的话在代码中修改即可。注意:**使用 lora 训练时,只会保存 lora 的权重,而不会保存 base 模型的权重**,如果需要合并 lora 权重,看下面的“Lora 权重合并”一节
|
| 174 |
+
- `--make_moe_param_leaf_module`:当用 zero3 以及 MoE 训练时,将 MoE 模块视作一个 leaf module,即它的参数不进行 zero3 切分,这个选项预计会显著增加显存占用
|
| 175 |
+
- `--gradient_checkpointing`:开启梯度重计算
|
| 176 |
+
- `--train_attention_params_only`: 是否只训练 attention 参数
|
| 177 |
+
- `--learning_rate`: 训练时的最大学习率
|
| 178 |
+
- `--min_lr`: 训练时的最小学习率
|
| 179 |
+
- `--use_flash_attn`: 开启 flash-attention 进行训练加速
|
| 180 |
+
|
| 181 |
+
**注意:**
|
| 182 |
+
|
| 183 |
+
- 如果想从一个中途保存的 ckpt 继续训练,而不是加载一个预训练的权重,直接指定`--resume_from_checkpoint`为之前训练保存的 ckpt 路径,不要指定`--model_name_or_path`,这样只会加载权重,而不会加载训练状态
|
| 184 |
+
- 从 ckpt 继续训练时,loss 可能会有微小的偏差,这是由一些非确定性算法带来的随机性,是正常现象。参考:[HuggingFace Transformers Trainer Randomness](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#randomness)
|
| 185 |
+
- 当 `--model_name_or_path` 有效时,所有模型相关的参数都会被忽略
|
| 186 |
+
- 一个 batch 内的样本会通过 padding 对齐 batch 内最长的样本,而每条样本的长度最长为 max_seq_length,超出的部分会被裁剪
|
| 187 |
+
- 如果报出 bias 权重没有 load 的 warning,忽略即可,Hunyuan-Large 中不会用到 bias
|
| 188 |
+
|
| 189 |
+
#### 显存不足怎么办?
|
| 190 |
+
|
| 191 |
+
参考:[DeepSpeed Configuration](https://www.deepspeed.ai/docs/config-json/)
|
| 192 |
+
|
| 193 |
+
可以尝试修改 ds config,去掉这几个参数的 auto 属性,改小试试看:
|
| 194 |
+
|
| 195 |
+
- `stage3_param_persistence_threshold`
|
| 196 |
+
- `stage3_prefetch_bucket_size`
|
| 197 |
+
- `stage3_max_reuse_distance`
|
| 198 |
+
|
| 199 |
+
#### Lora 模型合并
|
| 200 |
+
|
| 201 |
+
保存下来的 lora 权重没法在训练运行时合并到 zero3 模型中,因为 zero3 开启时模型权重会切分到各 dp rank 上。因此如果想把 lora 权重合并到 base 模型上,可以通过离线的方式合并后得到权重文件。执行`merge_lora_weight.sh`即可完成 lora 权重和 base 模型权重的合并,其中的参数有:
|
| 202 |
+
|
| 203 |
+
- `--base_model_path`:base 模型的权重目录
|
| 204 |
+
- `--adapter_model_path`:lora 权重目录
|
| 205 |
+
- `--output_path`:合并后的权重保存目录
|
| 206 |
+
- `--save_dtype`: 以什么数据格式存储合并后的权重,可选值:fp16,bf16,fp32
|
| 207 |
+
|
| 208 |
+
#### LLaMA-Factory 支持
|
| 209 |
+
|
| 210 |
+
如果对 LLaMA-Factory 较为熟悉,可使用LLaMA-Factory进行微调。脚本、代码以及配置文件都归档在`./train/llama_factory_support`目录下。如果没有特别说明,接下来我们提到的文件都是该目录下的文件。
|
| 211 |
+
|
| 212 |
+
##### 安装
|
| 213 |
+
|
| 214 |
+
可以通过下载源码 https://github.com/hiyouga/LLaMA-Factory/tree/main ,根据网站的指引进行安装。
|
| 215 |
+
|
| 216 |
+
##### 配置文件
|
| 217 |
+
|
| 218 |
+
我们提供了 llama-factory 的训练示例配置文件 `hy_v3_lora_sft.yaml`和`hy_v3_full_sft.yaml`文件,分别对应LORA训练和非LORA训练。
|
| 219 |
+
|
| 220 |
+
脚本中的关键参数如下:
|
| 221 |
+
|
| 222 |
+
**模型相关:**
|
| 223 |
+
|
| 224 |
+
- `model_name_or_path`: Hy3 preview HF 格式预训练模型权重路径
|
| 225 |
+
- `trust_remote_code`: 是否信任远程代码, Hy3 preview 需要设置为 `true`
|
| 226 |
+
|
| 227 |
+
**训练方法:**
|
| 228 |
+
|
| 229 |
+
- `stage`: 训练阶段, 当前为 `sft`(监督微调)
|
| 230 |
+
- `finetuning_type`: 微调类型, 可选 `full`(全量微调) 或 `lora`(LoRA 微调)
|
| 231 |
+
- `deepspeed`: DeepSpeed 配置文件路径, 全量微调推荐 `ds_zero3_offload_hy.json`, LoRA 微调推荐 `ds_zero2_offload_lora.json`
|
| 232 |
+
|
| 233 |
+
**LoRA 参数(仅 LoRA 微调时生效):**
|
| 234 |
+
|
| 235 |
+
- `lora_rank`: LoRA 秩, 默认 `64`
|
| 236 |
+
- `lora_alpha`: LoRA alpha 系数, 默认 `128`
|
| 237 |
+
- `lora_dropout`: LoRA dropout 比率, 默认 `0.05`
|
| 238 |
+
- `lora_target`: LoRA 应用的目标模块, 默认为 `q_proj,k_proj,v_proj,o_proj`
|
| 239 |
+
|
| 240 |
+
**数据集:**
|
| 241 |
+
|
| 242 |
+
- `dataset_dir`: 数据集目录路径
|
| 243 |
+
- `dataset`: 数据集名称, 需要在 `dataset_dir` 下的 `dataset_info.json` 中注册
|
| 244 |
+
- `template`: 对话模板, Hy3 preview 使用 `hy_v3`
|
| 245 |
+
- `cutoff_len`: 最大序列长度, 超出部分会被截断; 全量微调可设为 `262144`(262K), LoRA 微调建议设为 `8192` 以节省显存
|
| 246 |
+
- `max_samples`: 每个数据集最多使用的样本数
|
| 247 |
+
- `overwrite_cache`: 是否覆盖已缓存的预处理数据集
|
| 248 |
+
|
| 249 |
+
**输出:**
|
| 250 |
+
|
| 251 |
+
- `output_dir`: 输出目录, 日志、TensorBoard 和权重都会存储在此路径下
|
| 252 |
+
- `logging_steps`: 每多少步记录一次日志
|
| 253 |
+
- `save_steps`: 每多少步保存一次 checkpoint
|
| 254 |
+
- `plot_loss`: 是否绘制训练 loss 曲线
|
| 255 |
+
- `overwrite_output_dir`: 是否覆盖已有的输出目录
|
| 256 |
+
- `save_only_model`: 是否只保存模型权重(不保存优化器状态等)
|
| 257 |
+
- `report_to`: 日志上报工具, 可选 `none`, `wandb`, `tensorboard`, `swanlab`, `mlflow`
|
| 258 |
+
|
| 259 |
+
**训练超参数:**
|
| 260 |
+
|
| 261 |
+
- `per_device_train_batch_size`: 每张卡上的 batch size
|
| 262 |
+
- `gradient_accumulation_steps`: 梯度累积步数, `per_device_train_batch_size * gradient_accumulation_steps * dp_size` 为 global batch size
|
| 263 |
+
- `learning_rate`: 最大学习率, 全量微调推荐 `1.0e-5`, LoRA 微调推荐 `2.0e-4`
|
| 264 |
+
- `num_train_epochs`: 训练轮数
|
| 265 |
+
- `lr_scheduler_type`: 学习率调度器类型, 推荐使用 `cosine_with_min_lr`
|
| 266 |
+
- `lr_scheduler_kwargs.min_lr_rate`: 最小学习率与最大学习率的比值, 例如 `0.1` 表示最小学习率为最大学习率的 10%
|
| 267 |
+
- `warmup_ratio`: 预热阶段占总训练步数的比例
|
| 268 |
+
- `bf16`: 是否使用 BFloat16 混合精度训练
|
| 269 |
+
- `gradient_checkpointing`: 是否开启梯度重计算以节省显存
|
| 270 |
+
- `ddp_timeout`: 分布式训练超时时间(毫秒)
|
| 271 |
+
- `flash_attn`: 注意力实现方式, 推荐 `fa2`(FlashAttention-2), 也可选 `sdpa`; 使用 `fa2` 需要安装 flash-attn 包
|
| 272 |
+
- `resume_from_checkpoint`: 从指定 checkpoint 路径恢复训练, 设为 `null` 表示从头开始训练
|
| 273 |
+
|
| 274 |
+
##### 启动训练
|
| 275 |
+
|
| 276 |
+
请先按照前面章节 [配置机器间免密 ssh 登录](#配置机器间免密-ssh-登录) 配置多机免密登录。
|
| 277 |
+
|
| 278 |
+
修改`train_lf.sh`中开头的以下配置:
|
| 279 |
+
|
| 280 |
+
```shell
|
| 281 |
+
export HOST_GPU_NUM=8
|
| 282 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 283 |
+
export IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 284 |
+
```
|
| 285 |
+
|
| 286 |
+
注意:如果`IP_LIST`环境变量未设置,则将`IP_LIST`替换为IP列表!格式为:
|
| 287 |
+
```
|
| 288 |
+
如果只有一个IP:
|
| 289 |
+
IP_LIST=${ip_1}
|
| 290 |
+
|
| 291 |
+
如果有多个IP:
|
| 292 |
+
IP_LIST=${ip_1},${ip_2}
|
| 293 |
+
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
请将`${ip_1}`和`${ip_2}`替换为真实的IP地址。
|
| 297 |
+
|
| 298 |
+
然后,在每一台机器上,在`train/llama_factory_support/`目录下执行`bash train_lf.sh`。
|
train/ds_zero2_no_offload.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 100,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1e-10
|
| 9 |
+
},
|
| 10 |
+
"zero_optimization": {
|
| 11 |
+
"stage": 2,
|
| 12 |
+
"allgather_partitions": true,
|
| 13 |
+
"allgather_bucket_size": 1e8,
|
| 14 |
+
"overlap_comm": true,
|
| 15 |
+
"reduce_scatter": true,
|
| 16 |
+
"reduce_bucket_size": 1e8,
|
| 17 |
+
"contiguous_gradients": true
|
| 18 |
+
},
|
| 19 |
+
"gradient_accumulation_steps": "auto",
|
| 20 |
+
"gradient_clipping": "auto",
|
| 21 |
+
"steps_per_print": 10,
|
| 22 |
+
"train_batch_size": "auto",
|
| 23 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 24 |
+
"wall_clock_breakdown": false
|
| 25 |
+
}
|
train/ds_zero3_no_offload.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
|
| 14 |
+
"zero_optimization": {
|
| 15 |
+
"stage": 3,
|
| 16 |
+
"offload_optimizer": {
|
| 17 |
+
"device": "none",
|
| 18 |
+
"pin_memory": true
|
| 19 |
+
},
|
| 20 |
+
"offload_param": {
|
| 21 |
+
"device": "none",
|
| 22 |
+
"pin_memory": true
|
| 23 |
+
},
|
| 24 |
+
"overlap_comm": true,
|
| 25 |
+
"contiguous_gradients": true,
|
| 26 |
+
"sub_group_size": 1e9,
|
| 27 |
+
"reduce_bucket_size": "auto",
|
| 28 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 29 |
+
"stage3_param_persistence_threshold": "auto",
|
| 30 |
+
"stage3_max_live_parameters": 1e9,
|
| 31 |
+
"stage3_max_reuse_distance": 1e9,
|
| 32 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 33 |
+
},
|
| 34 |
+
|
| 35 |
+
"gradient_accumulation_steps": "auto",
|
| 36 |
+
"gradient_clipping": "auto",
|
| 37 |
+
"steps_per_print": 10,
|
| 38 |
+
"train_batch_size": "auto",
|
| 39 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 40 |
+
"wall_clock_breakdown": false
|
| 41 |
+
}
|
train/ds_zero3_offload.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
|
| 14 |
+
"zero_optimization": {
|
| 15 |
+
"stage": 3,
|
| 16 |
+
"offload_optimizer": {
|
| 17 |
+
"device": "cpu",
|
| 18 |
+
"pin_memory": true
|
| 19 |
+
},
|
| 20 |
+
"offload_param": {
|
| 21 |
+
"device": "cpu",
|
| 22 |
+
"pin_memory": true
|
| 23 |
+
},
|
| 24 |
+
"overlap_comm": true,
|
| 25 |
+
"contiguous_gradients": true,
|
| 26 |
+
"sub_group_size": 1e9,
|
| 27 |
+
"reduce_bucket_size": "auto",
|
| 28 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 29 |
+
"stage3_param_persistence_threshold": "auto",
|
| 30 |
+
"stage3_max_live_parameters": 1e9,
|
| 31 |
+
"stage3_max_reuse_distance": 1e9,
|
| 32 |
+
"stage3_gather_16bit_weights_on_model_save": false
|
| 33 |
+
},
|
| 34 |
+
|
| 35 |
+
"gradient_accumulation_steps": "auto",
|
| 36 |
+
"gradient_clipping": "auto",
|
| 37 |
+
"steps_per_print": 10,
|
| 38 |
+
"train_batch_size": "auto",
|
| 39 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 40 |
+
"wall_clock_breakdown": false
|
| 41 |
+
}
|
train/ds_zero3_offload_no_auto.json
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": false,
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": true
|
| 12 |
+
},
|
| 13 |
+
|
| 14 |
+
"zero_optimization": {
|
| 15 |
+
"stage": 3,
|
| 16 |
+
"offload_optimizer": {
|
| 17 |
+
"device": "cpu",
|
| 18 |
+
"pin_memory": true
|
| 19 |
+
},
|
| 20 |
+
"overlap_comm": true,
|
| 21 |
+
"contiguous_gradients": true,
|
| 22 |
+
"sub_group_size": 1e9,
|
| 23 |
+
"reduce_bucket_size": 1e8,
|
| 24 |
+
"stage3_prefetch_bucket_size": 1e8,
|
| 25 |
+
"stage3_param_persistence_threshold": 0,
|
| 26 |
+
"stage3_max_live_parameters": 1e8,
|
| 27 |
+
"stage3_max_reuse_distance": 1e8,
|
| 28 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 29 |
+
},
|
| 30 |
+
|
| 31 |
+
"gradient_accumulation_steps": 1,
|
| 32 |
+
"gradient_clipping": "auto",
|
| 33 |
+
"steps_per_print": 10,
|
| 34 |
+
"train_micro_batch_size_per_gpu": 1,
|
| 35 |
+
"wall_clock_breakdown": false
|
| 36 |
+
}
|
train/example_data.jsonl
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "能否帮我撰写一个关于环保的议论文开头部分?\n主题:环保\n立场:重要性和我们的责任"}, {"role": "assistant", "content": "在我们的生活中,环保已经成为一个不可忽视的议题。随着科技的发展和人口的增长,我们的环境正在遭受前所未有的压力。森林被砍伐,河流被污染,动植物种类正在消失,全球气候正在发生变化。这些都是我们无法忽视的事实。然而,环保并不仅仅是一个环境问题,它更是一个关乎我们生存和发展的问题。环保的重要性已经超越了任何一个单一的领域,它关乎我们的健康、经济、社会和文化。因此,我们每一个人都有责任和义务去保护我们的环境,这不仅是为了我们自己,也是为了我们的后代和整个地球。"}]}
|
| 2 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "请描述小红书APP的用户画像,并用表格的形式展示结果。"}, {"role": "assistant", "content": "| 用户特征 | 描述 |\n| -------- | ------------------------------ |\n| 年龄段 | 18-35岁 |\n| 性别 | 女性用户较多(约占70%) |\n| 地域 | 一线城市和二线城市用户占比较高 |\n| 职业 | 白领、创业者、自由职业者等 |\n| 爱好 | 美容、时尚、旅行、美食、健康、生活方式等 |\n| 消费能力 | 中高消费能力用户 |\n| 使用目的 | 阅读内容,分享购物心得,结识相似兴趣的朋友 |\n\n注:以上用户画像仅供参考,实际情况可能会有所不同。"}]}
|
| 3 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "请根据以下两个句子的描述,判断它们是否对同一情况进行了描述。句子1:“外面下着大雨,人们纷纷打起了伞。” 句子2:“由于雨量较大,人们都撑起了雨伞。”"}, {"role": "assistant", "content": "是的,这两个句子描述的是同一情况。"}]}
|
| 4 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "你认为游戏开发者应该如何处理游戏中的社交网络集成?"}, {"role": "assistant", "content": "游戏开发者在处理游戏中的社交网络集成时可以考虑以下几点:\n\n1. 以用户为中心:确保集成社交网络功能是为了提高用户体验,让玩家更容易地与朋友互动、分享成就和游戏进度。\n\n2. 保护隐私:在集成社交网络时,务必保护用户的隐私。在收集和使用用户数据时,需要遵循相关法规和用户协议,同时提供明确的隐私政策。\n\n3. 适度推送:在推送游戏动态、邀请和通知时要适度,避免过度打扰用户或侵犯用户隐私。可提供自定义设置,让玩家选择性地接收通知。\n\n4. 社交功能设计:在游戏中加入社交功能,如聊天、组队、比赛等,以增进玩家之间的互动。同时,确保这些功能易用且不影响游戏体验。\n\n5. 社区建设:鼓励玩家在社交平台上建立游戏相关的社区,分享游戏心得、攻略和经验。这将有助于增加游戏的知名度和吸引力。\n\n6. 奖励机制:设定奖励机制,鼓励玩家在社交网络上分享游戏内容、邀请好友等。但要确保奖励机制不会破坏游戏平衡。\n\n7. 持续更新和优化:根据玩家的反馈,不断优化社交网络集成功能,确保其与游戏的其他方面保持一致。\n\n总之,游戏开发者应该在保护用户隐私的前提下,提供便捷的社交网络集成功能,以提高游戏的互动性和吸引力。同时,要关注玩家的反馈,不断优化功能。"}]}
|
| 5 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "将这段过去进行时的文字改为现在完成进行时:“她们在公园里跳舞。”"}, {"role": "assistant", "content": "她们已经在公园里跳舞了。"}]}
|
| 6 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "在godaddy上买了一台服务器怎么更改root账户"}, {"role": "assistant", "content": "在GoDaddy上购买的服务器可能是Linux系统的VPS或者是专用服务器。要更改root账户,你可以通过SSH连接到服务器并修改密码。以下是具体步骤:\n\n1. 打开终端(Terminal)或使用SSH客户端(如PuTTY)。\n2. 使用SSH连接到服务器。在终端中输入以下命令:\n\n ```\n ssh root@your_server_ip_address\n ```\n\n 将`your_server_ip_address`替换为你的服务器IP地址。如果需要,你可能还需要输入服务器的SSH端口号。\n \n3. 当系统提示输入密码时,输入你的当前root密码。如果连接成功,你将看到服务器的命令提示符。\n\n4. 要更改root密码,请输入以下命令:\n\n ```\n passwd\n ```\n\n5. 系统将提示你输入新的UNIX密码。输入你想要设置的新密码并按回车键。\n\n6. 系统将要求你重新输入新密码以进行确认。再次输入新密码并按回车键。\n\n7. 如果密码更改成功,系统将显示“密码已成功更新”的消息。\n\n8. 使用`exit`命令断开与服务器的连接。\n\n现在你已经成功更改了root账户的密码。下次使用SSH连接到服务器时,请使用新的密码进行身份验证。"}]}
|
| 7 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "PHP算法逻辑:如何计算年龄?"}, {"role": "assistant", "content": "要计算年龄,您可以使用以下PHP算法。此算法考虑了生日、当前日期以及闰年的计算。\n\n```php\nfunction calculateAge($birthdate) {\n // 格式为YYYY-MM-DD的生日\n $birthdate = strtotime($birthdate);\n \n // 获取当前日期\n $now = time();\n\n // 计算年龄\n $age = date('Y', $now) - date('Y', $birthdate);\n \n // 如果还没有过生日,减去一年\n if (date('m-d', $now) < date('m-d', $birthdate)) {\n $age--;\n }\n\n return $age;\n}\n\n// 使用\n$birthdate = \"2000-01-01\";\necho \"年龄是: \" . calculateAge($birthdate) . \" 岁\";\n```\n\n这个函数将生日作为输入,输出用户的年龄。首先,它将生日转换为时间戳,然后获取当前的时间戳。接下来,它计算年份之间的差异。如果当前日期仍未超过这一年的生日,这个函数将年龄减少1。\n\n在使用函数时,只需传入出生日期即可,函数会返回相应的年龄。"}]}
|
| 8 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "给我解释一下什么是胰岛素抵抗? 请详细描述胰岛素抵抗的症状、原因与治疗方法。 胰岛素抵抗会影响一个人的日常生活吗? 有哪些生活习惯改变可以帮助改善胰岛素抵抗?"}, {"role": "assistant", "content": "胰岛素抵抗是指身体对胰岛素的敏感性降低,胰岛素不能有效地发挥作用,导致血糖无法正常进入身体细胞。胰岛素是一种激素,由胰腺产生,其主要作用是帮助身体细胞吸收血糖并将其转化为能量。当胰岛素抵抗发生时,胰腺会分泌更多的胰岛素以维持正常的血糖水平,但长期的胰岛素抵抗可能导致2型糖尿病和其他健康问题的发生。\n\n胰岛素抵抗的症状可能包括:\n\n1. 高血糖:由于胰岛素不能正常工作,血糖无法进入细胞,导致血糖升高。\n2. 疲劳:由于身体无法有效利用血糖产生能量,患者可能会感到疲劳无力。\n3. 肥胖:胰岛素抵抗可能导致过量的胰岛素分泌,使得脂肪存储增加,进而导致体重增加。\n4. 高血压:胰岛素抵抗与高血压之间存在关联,可能导致血压升高。\n5. 黑棘皮病:皮肤出现黑色、粗糙的皮疹,通常出现在颈部、腋窝等皮肤皱褶处。\n\n胰岛素抵抗的原因可能包括:\n\n1. 遗传因素:遗传易感基因可能会增加一个人患胰岛素抵抗的风险。\n2. 肥胖:过重或肥胖会增加患胰岛素抵抗的风险。\n3. 缺乏运动:缺乏运动可能会导致肌肉对胰岛素的敏感性降低。\n4. 高糖饮食:长期摄入高糖食品可能导致胰岛素抵抗。\n5. 睡眠不足:长期睡眠不足可能会影响胰岛素的正常功能。\n\n治疗胰岛素抵抗的方法主要包括生活方式的改变和药物治疗:\n\n1. 减轻体重:通过改善饮食和增加运动来减轻体重,有助于提高身体对胰岛素的敏感性。\n2. 增加运动:增加有氧运动和力量训练,有助于提高肌肉对胰岛素的敏感性。\n3. 改善饮食:摄入低糖、高纤维、富含全谷物和蔬菜水果的食物,有助于控制血糖和减轻体重。\n4. 药物治疗:在必要时,医生可能会开处降糖药物,如二甲双胍,以帮助提高胰岛素敏感性。\n\n胰岛素抵抗会影响一个人的日常生活,包括疲劳、体重增加、血糖波动等问题。通过改变生活习惯,如减轻体重、增加运动、改善饮食和保证充足睡眠,可以帮助改善胰岛素抵抗,降低患2型糖尿病的风险。"}]}
|
train/llama_factory_support/dataset_info.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"hy_v3_demo": {
|
| 3 |
+
"file_name": "../example_data.jsonl",
|
| 4 |
+
"formatting": "sharegpt",
|
| 5 |
+
"columns": {
|
| 6 |
+
"messages": "messages"
|
| 7 |
+
},
|
| 8 |
+
"tags": {
|
| 9 |
+
"role_tag": "role",
|
| 10 |
+
"content_tag": "content",
|
| 11 |
+
"user_tag": "user",
|
| 12 |
+
"assistant_tag": "assistant",
|
| 13 |
+
"system_tag": "system"
|
| 14 |
+
}
|
| 15 |
+
}
|
| 16 |
+
}
|
train/llama_factory_support/ds_zero2_offload_lora.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": false,
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": true
|
| 12 |
+
},
|
| 13 |
+
|
| 14 |
+
"zero_optimization": {
|
| 15 |
+
"stage": 2,
|
| 16 |
+
"offload_optimizer": {
|
| 17 |
+
"device": "cpu",
|
| 18 |
+
"pin_memory": true
|
| 19 |
+
},
|
| 20 |
+
"allgather_partitions": true,
|
| 21 |
+
"allgather_bucket_size": 5e8,
|
| 22 |
+
"overlap_comm": true,
|
| 23 |
+
"reduce_scatter": true,
|
| 24 |
+
"reduce_bucket_size": 5e8,
|
| 25 |
+
"contiguous_gradients": true
|
| 26 |
+
},
|
| 27 |
+
|
| 28 |
+
"gradient_accumulation_steps": "auto",
|
| 29 |
+
"gradient_clipping": "auto",
|
| 30 |
+
"steps_per_print": 10,
|
| 31 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 32 |
+
"wall_clock_breakdown": false
|
| 33 |
+
}
|
train/llama_factory_support/ds_zero3_offload.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": false,
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": true
|
| 12 |
+
},
|
| 13 |
+
|
| 14 |
+
"zero_optimization": {
|
| 15 |
+
"stage": 3,
|
| 16 |
+
"offload_optimizer": {
|
| 17 |
+
"device": "cpu",
|
| 18 |
+
"pin_memory": true
|
| 19 |
+
},
|
| 20 |
+
"offload_param": {
|
| 21 |
+
"device": "cpu",
|
| 22 |
+
"pin_memory": true
|
| 23 |
+
},
|
| 24 |
+
"overlap_comm": true,
|
| 25 |
+
"contiguous_gradients": true,
|
| 26 |
+
"sub_group_size": 1e9,
|
| 27 |
+
"reduce_bucket_size": 1e8,
|
| 28 |
+
"stage3_prefetch_bucket_size": 5e8,
|
| 29 |
+
"stage3_param_persistence_threshold": 1e6,
|
| 30 |
+
"stage3_max_live_parameters": 5e9,
|
| 31 |
+
"stage3_max_reuse_distance": 5e9,
|
| 32 |
+
"stage3_gather_16bit_weights_on_model_save": false
|
| 33 |
+
},
|
| 34 |
+
|
| 35 |
+
"gradient_accumulation_steps": "auto",
|
| 36 |
+
"gradient_clipping": "auto",
|
| 37 |
+
"steps_per_print": 10,
|
| 38 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 39 |
+
"wall_clock_breakdown": false
|
| 40 |
+
}
|
train/llama_factory_support/hy_v3_full_sft.yaml
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### model
|
| 2 |
+
model_name_or_path: ../hf
|
| 3 |
+
trust_remote_code: true
|
| 4 |
+
|
| 5 |
+
### method
|
| 6 |
+
stage: sft
|
| 7 |
+
do_train: true
|
| 8 |
+
finetuning_type: full
|
| 9 |
+
deepspeed: ds_zero3_offload.json
|
| 10 |
+
|
| 11 |
+
### dataset
|
| 12 |
+
dataset_dir: .
|
| 13 |
+
dataset: hy_v3_demo
|
| 14 |
+
template: hy_v3
|
| 15 |
+
cutoff_len: 4096 # HYV3 supports 262k context length
|
| 16 |
+
max_samples: 1000
|
| 17 |
+
overwrite_cache: true
|
| 18 |
+
|
| 19 |
+
### output
|
| 20 |
+
output_dir: saves/hy_v3/full/sft
|
| 21 |
+
logging_steps: 1
|
| 22 |
+
save_steps: 500
|
| 23 |
+
plot_loss: true
|
| 24 |
+
overwrite_output_dir: true
|
| 25 |
+
save_only_model: false
|
| 26 |
+
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
| 27 |
+
|
| 28 |
+
### train
|
| 29 |
+
per_device_train_batch_size: 1
|
| 30 |
+
gradient_accumulation_steps: 1
|
| 31 |
+
learning_rate: 1.0e-5
|
| 32 |
+
num_train_epochs: 100.0
|
| 33 |
+
lr_scheduler_type: cosine_with_min_lr
|
| 34 |
+
lr_scheduler_kwargs:
|
| 35 |
+
min_lr_rate: 0.1 # min_lr / learning_rate = 1e-6 / 1e-5 = 0.1
|
| 36 |
+
warmup_ratio: 0.1
|
| 37 |
+
bf16: true
|
| 38 |
+
gradient_checkpointing: true
|
| 39 |
+
gradient_checkpointing_kwargs:
|
| 40 |
+
use_reentrant: true
|
| 41 |
+
ddp_timeout: 180000000
|
| 42 |
+
flash_attn: fa2
|
| 43 |
+
resume_from_checkpoint: null
|
| 44 |
+
|
| 45 |
+
### eval
|
| 46 |
+
# eval_dataset: alpaca_en_demo
|
| 47 |
+
# val_size: 0.1
|
| 48 |
+
# per_device_eval_batch_size: 1
|
| 49 |
+
# eval_strategy: steps
|
| 50 |
+
# eval_steps: 500
|
| 51 |
+
|
| 52 |
+
### HYV3 specific parameters (MoE architecture)
|
| 53 |
+
# moe_num_experts: 192
|
| 54 |
+
# moe_topk: 8
|
| 55 |
+
# moe_intermediate_size: 1536
|
| 56 |
+
# moe_layer_num_skipped: 1
|
| 57 |
+
# output_router_logits: true
|
train/llama_factory_support/hy_v3_lora_sft.yaml
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### model
|
| 2 |
+
model_name_or_path: ../hf
|
| 3 |
+
trust_remote_code: true
|
| 4 |
+
|
| 5 |
+
### method
|
| 6 |
+
stage: sft
|
| 7 |
+
do_train: true
|
| 8 |
+
finetuning_type: lora
|
| 9 |
+
deepspeed: ds_zero2_offload_lora.json
|
| 10 |
+
|
| 11 |
+
### LoRA parameters
|
| 12 |
+
lora_rank: 64
|
| 13 |
+
lora_alpha: 128
|
| 14 |
+
lora_dropout: 0.05
|
| 15 |
+
lora_target: q_proj,k_proj,v_proj,o_proj
|
| 16 |
+
|
| 17 |
+
### dataset
|
| 18 |
+
dataset_dir: .
|
| 19 |
+
dataset: hy_v3_demo
|
| 20 |
+
template: hy_v3
|
| 21 |
+
cutoff_len: 4096 # Use shorter context for LoRA to save memory; increase if needed
|
| 22 |
+
max_samples: 1000
|
| 23 |
+
overwrite_cache: true
|
| 24 |
+
|
| 25 |
+
### output
|
| 26 |
+
output_dir: saves/hy_v3/lora/sft
|
| 27 |
+
logging_steps: 10
|
| 28 |
+
save_steps: 500
|
| 29 |
+
plot_loss: true
|
| 30 |
+
overwrite_output_dir: true
|
| 31 |
+
save_only_model: false
|
| 32 |
+
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
| 33 |
+
|
| 34 |
+
### train
|
| 35 |
+
per_device_train_batch_size: 1
|
| 36 |
+
gradient_accumulation_steps: 4
|
| 37 |
+
learning_rate: 2.0e-4
|
| 38 |
+
num_train_epochs: 3.0
|
| 39 |
+
lr_scheduler_type: cosine_with_min_lr
|
| 40 |
+
lr_scheduler_kwargs:
|
| 41 |
+
min_lr_rate: 0.1 # min_lr / learning_rate = 2e-5 / 2e-4 = 0.1
|
| 42 |
+
warmup_ratio: 0.1
|
| 43 |
+
bf16: true
|
| 44 |
+
gradient_checkpointing: true
|
| 45 |
+
gradient_checkpointing_kwargs:
|
| 46 |
+
use_reentrant: true
|
| 47 |
+
ddp_timeout: 180000000
|
| 48 |
+
flash_attn: fa2
|
| 49 |
+
resume_from_checkpoint: null
|
| 50 |
+
|
| 51 |
+
### eval
|
| 52 |
+
# eval_dataset: alpaca_en_demo
|
| 53 |
+
# val_size: 0.1
|
| 54 |
+
# per_device_eval_batch_size: 1
|
| 55 |
+
# eval_strategy: steps
|
| 56 |
+
# eval_steps: 500
|
train/llama_factory_support/hy_v3_patches.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HYV3 monkey-patches for LLaMA Factory + DeepSpeed ZeRO-3 training.
|
| 3 |
+
|
| 4 |
+
This module applies all necessary runtime patches so that HYV3 (MoE)
|
| 5 |
+
can be trained correctly under LLaMA Factory with DeepSpeed ZeRO-3.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
Import this module **before** calling `llamafactory-cli train`:
|
| 9 |
+
|
| 10 |
+
import hy_v3_patches # applies patches on import
|
| 11 |
+
# ... then start training
|
| 12 |
+
|
| 13 |
+
Or add to the LLaMA Factory YAML via a custom entry-point wrapper.
|
| 14 |
+
|
| 15 |
+
Patches applied:
|
| 16 |
+
1. ZeRO-3 buffer loading (e_score_correction_bias etc.)
|
| 17 |
+
Key renaming + expert fusing is now done offline by convert_ckpt_to_outer.py
|
| 18 |
+
2. Router forward dtype fix (MoE router gate dtype alignment for ZeRO-3)
|
| 19 |
+
3. gradient_checkpointing (use_reentrant=True for ZeRO-3)
|
| 20 |
+
4. Tokenizer file copy (CustomSaveCallback)
|
| 21 |
+
5. (Removed) -- was per-expert ModuleList, now using native 3D Parameters
|
| 22 |
+
6. Save-time reverse key rename + 3D -> per-expert unfuse
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import re
|
| 27 |
+
import logging
|
| 28 |
+
import shutil
|
| 29 |
+
from typing import Optional
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn as _nn
|
| 33 |
+
import torch.nn.functional as _F
|
| 34 |
+
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
# ============================================================================
|
| 38 |
+
# Patch 1: Buffer loading for ZeRO-3
|
| 39 |
+
#
|
| 40 |
+
# The checkpoint is expected to be in outer format (pre-converted by
|
| 41 |
+
# convert_ckpt_to_outer.py). Key renaming and expert fusing are no longer
|
| 42 |
+
# needed at load time.
|
| 43 |
+
#
|
| 44 |
+
# However, ZeRO-3's _load_state_dict_into_zero3_model only handles
|
| 45 |
+
# named_parameters, not named_buffers (e.g. e_score_correction_bias).
|
| 46 |
+
# We still need to manually load buffers from the state_dict.
|
| 47 |
+
# ============================================================================
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _apply_buffer_loading_patch():
|
| 51 |
+
"""Patch the DeepSpeed ZeRO-3 state_dict loader to manually load buffers.
|
| 52 |
+
|
| 53 |
+
ZeRO-3's _load_state_dict_into_zero3_model only handles named_parameters.
|
| 54 |
+
Buffers like e_score_correction_bias must be loaded manually.
|
| 55 |
+
"""
|
| 56 |
+
try:
|
| 57 |
+
from transformers.integrations.deepspeed import (
|
| 58 |
+
_load_state_dict_into_zero3_model as _orig_load_zero3,
|
| 59 |
+
)
|
| 60 |
+
import transformers.integrations.deepspeed as _ds_mod
|
| 61 |
+
import transformers.modeling_utils as _mu_mod
|
| 62 |
+
except ImportError:
|
| 63 |
+
logger.warning(
|
| 64 |
+
"Could not import transformers.integrations.deepspeed; "
|
| 65 |
+
"buffer loading patch NOT applied."
|
| 66 |
+
)
|
| 67 |
+
return
|
| 68 |
+
|
| 69 |
+
def _patched_load_zero3(model_to_load, state_dict, *args, **kwargs):
|
| 70 |
+
# Load parameters via original ZeRO-3 loader
|
| 71 |
+
result = _orig_load_zero3(model_to_load, state_dict, *args, **kwargs)
|
| 72 |
+
|
| 73 |
+
# Manually load buffers (e.g. e_score_correction_bias) from state_dict.
|
| 74 |
+
# ZeRO-3's loader only handles named_parameters, not named_buffers.
|
| 75 |
+
buffers_loaded = 0
|
| 76 |
+
for name, buf in model_to_load.named_buffers():
|
| 77 |
+
if name in state_dict:
|
| 78 |
+
src_tensor = state_dict[name]
|
| 79 |
+
if isinstance(src_tensor, torch.Tensor):
|
| 80 |
+
buf.data.copy_(src_tensor.to(buf.dtype))
|
| 81 |
+
buffers_loaded += 1
|
| 82 |
+
if isinstance(result, tuple) and len(result) >= 2:
|
| 83 |
+
if isinstance(result[1], set):
|
| 84 |
+
result[1].discard(name)
|
| 85 |
+
if buffers_loaded > 0:
|
| 86 |
+
logger.info(
|
| 87 |
+
"HYV3 Patch 1: Manually loaded %d buffers into model.",
|
| 88 |
+
buffers_loaded
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
return result
|
| 92 |
+
|
| 93 |
+
_ds_mod._load_state_dict_into_zero3_model = _patched_load_zero3
|
| 94 |
+
_mu_mod._load_state_dict_into_zero3_model = _patched_load_zero3
|
| 95 |
+
logger.info(
|
| 96 |
+
"HYV3 patch applied: ZeRO-3 buffer loading for e_score_correction_bias."
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# ============================================================================
|
| 100 |
+
# Patch 2: Router forward dtype alignment for ZeRO-3
|
| 101 |
+
#
|
| 102 |
+
# The HYV3 MoE HYV3TopKRouter.forward() calls F.linear with .float().
|
| 103 |
+
# Under DeepSpeed ZeRO-3, F.linear is replaced by zero3_linear_wrap which
|
| 104 |
+
# internally does input.matmul(weight.t()) WITHOUT aligning dtypes.
|
| 105 |
+
# When ZeRO-3 stores the gate weight in bf16, the fp32 input causes a
|
| 106 |
+
# dtype mismatch RuntimeError.
|
| 107 |
+
#
|
| 108 |
+
# Fix: monkey-patch HYV3TopKRouter.forward to cast input to
|
| 109 |
+
# self.weight.dtype before F.linear, then cast the output back to float32.
|
| 110 |
+
# ============================================================================
|
| 111 |
+
|
| 112 |
+
_router_patch_applied = False
|
| 113 |
+
|
| 114 |
+
def _apply_router_dtype_patch():
|
| 115 |
+
"""Monkey-patch HYV3TopKRouter.forward to align gate input dtype with weight dtype."""
|
| 116 |
+
global _router_patch_applied
|
| 117 |
+
if _router_patch_applied:
|
| 118 |
+
return
|
| 119 |
+
|
| 120 |
+
try:
|
| 121 |
+
from transformers.models.hy_v3.modeling_hy_v3 import HYV3TopKRouter
|
| 122 |
+
except ImportError:
|
| 123 |
+
try:
|
| 124 |
+
from transformers.hy_v3.modeling_hy_v3 import HYV3TopKRouter
|
| 125 |
+
except ImportError:
|
| 126 |
+
logger.warning(
|
| 127 |
+
"Could not import HYV3TopKRouter; "
|
| 128 |
+
"router dtype patch NOT applied."
|
| 129 |
+
)
|
| 130 |
+
return
|
| 131 |
+
|
| 132 |
+
def _patched_router_forward(
|
| 133 |
+
self,
|
| 134 |
+
hidden_states: torch.Tensor,
|
| 135 |
+
e_score_correction_bias: torch.Tensor,
|
| 136 |
+
) -> tuple:
|
| 137 |
+
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
| 138 |
+
# Cast input to match weight dtype (bf16 under ZeRO-3)
|
| 139 |
+
# instead of hard-coding float32, to avoid matmul dtype mismatch.
|
| 140 |
+
weight_dtype = self.weight.dtype
|
| 141 |
+
router_logits = _F.linear(hidden_states.to(weight_dtype), self.weight.to(weight_dtype))
|
| 142 |
+
# Cast back to float32 for numerically stable sigmoid
|
| 143 |
+
router_logits = router_logits.to(torch.float32)
|
| 144 |
+
routing_weights = torch.sigmoid(router_logits)
|
| 145 |
+
|
| 146 |
+
scores_for_choice = routing_weights + e_score_correction_bias
|
| 147 |
+
_, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False)
|
| 148 |
+
top_k_weights = routing_weights.gather(1, top_k_index)
|
| 149 |
+
|
| 150 |
+
top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-20)
|
| 151 |
+
top_k_weights = top_k_weights * self.router_scaling_factor
|
| 152 |
+
|
| 153 |
+
return router_logits, top_k_weights, top_k_index
|
| 154 |
+
|
| 155 |
+
HYV3TopKRouter.forward = _patched_router_forward
|
| 156 |
+
_router_patch_applied = True
|
| 157 |
+
logger.info("HYV3 patch applied: HYV3TopKRouter.forward dtype alignment for ZeRO-3.")
|
| 158 |
+
|
| 159 |
+
# ============================================================================
|
| 160 |
+
# Patch 3: gradient_checkpointing use_reentrant=True
|
| 161 |
+
#
|
| 162 |
+
# PyTorch's torch.utils.checkpoint with use_reentrant=False (the default
|
| 163 |
+
# in transformers) performs strict metadata checks on recomputed tensors.
|
| 164 |
+
# Under ZeRO-3, parameters are all-gathered during the first forward pass
|
| 165 |
+
# but may be partitioned back when the checkpoint recomputes, causing a
|
| 166 |
+
# CheckpointError. Setting use_reentrant=True avoids this.
|
| 167 |
+
#
|
| 168 |
+
# This is applied via a Trainer callback that modifies training_args
|
| 169 |
+
# before training starts.
|
| 170 |
+
# ============================================================================
|
| 171 |
+
|
| 172 |
+
# ============================================================================
|
| 173 |
+
# Patch 4: Tokenizer file copy callback
|
| 174 |
+
#
|
| 175 |
+
# Ensures each checkpoint directory is self-contained for inference by
|
| 176 |
+
# copying all tokenizer-related files from the original tokenizer path.
|
| 177 |
+
# ============================================================================
|
| 178 |
+
|
| 179 |
+
# Tokenizer files that should be copied to each checkpoint
|
| 180 |
+
_TOKENIZER_FILES = [
|
| 181 |
+
"generation_config.json",
|
| 182 |
+
"hy.tiktoken",
|
| 183 |
+
"tokenizer_config.json",
|
| 184 |
+
"tokenization_hy.py",
|
| 185 |
+
"tokenizer.json",
|
| 186 |
+
"special_tokens_map.json",
|
| 187 |
+
"chat_template.jinja",
|
| 188 |
+
]
|
| 189 |
+
|
| 190 |
+
def _copy_tokenizer_to_checkpoint(tokenizer_dir: str, checkpoint_dir: str):
|
| 191 |
+
"""Copy tokenizer files from tokenizer_dir to checkpoint_dir."""
|
| 192 |
+
for fname in _TOKENIZER_FILES:
|
| 193 |
+
src = os.path.join(tokenizer_dir, fname)
|
| 194 |
+
if os.path.isfile(src):
|
| 195 |
+
shutil.copy(src, os.path.join(checkpoint_dir, fname))
|
| 196 |
+
|
| 197 |
+
# ============================================================================
|
| 198 |
+
# Patch 6: Save-time reverse key rename + 3D -> per-expert unfuse
|
| 199 |
+
#
|
| 200 |
+
# When saving checkpoints, the model state_dict uses:
|
| 201 |
+
# - 3D fused experts (experts.gate_up_proj, experts.down_proj)
|
| 202 |
+
# - New naming (mlp.gate, mlp.e_score_correction_bias, mlp.shared_experts)
|
| 203 |
+
#
|
| 204 |
+
# We need to reverse both for old checkpoint compatibility:
|
| 205 |
+
# - mlp.gate. -> mlp.router.gate.
|
| 206 |
+
# - mlp.e_score_correction_bias -> mlp.expert_bias
|
| 207 |
+
# - mlp.shared_experts. -> mlp.shared_mlp.
|
| 208 |
+
# - experts.gate_up_proj -> experts.{N}.gate_proj.weight + experts.{N}.up_proj.weight
|
| 209 |
+
# - experts.down_proj -> experts.{N}.down_proj.weight
|
| 210 |
+
# ============================================================================
|
| 211 |
+
|
| 212 |
+
# Reverse mapping: new model name -> old checkpoint name
|
| 213 |
+
_SAVE_KEY_RENAMES = [
|
| 214 |
+
("mlp.gate.", "mlp.router.gate."),
|
| 215 |
+
("mlp.e_score_correction_bias", "mlp.expert_bias"),
|
| 216 |
+
("mlp.shared_experts.", "mlp.shared_mlp."),
|
| 217 |
+
]
|
| 218 |
+
|
| 219 |
+
# Regex to match fused 3D expert keys
|
| 220 |
+
_FUSED_EXPERT_KEY_RE = re.compile(
|
| 221 |
+
r"^(.*\.mlp\.experts\.)(gate_up_proj|down_proj)$"
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
_save_patch_applied = False
|
| 225 |
+
|
| 226 |
+
def _apply_save_reverse_rename_patch():
|
| 227 |
+
"""Monkey-patch save_pretrained to reverse-rename keys and unfuse 3D experts."""
|
| 228 |
+
global _save_patch_applied
|
| 229 |
+
if _save_patch_applied:
|
| 230 |
+
return
|
| 231 |
+
|
| 232 |
+
try:
|
| 233 |
+
from transformers.models.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
|
| 234 |
+
except ImportError:
|
| 235 |
+
try:
|
| 236 |
+
from transformers.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
|
| 237 |
+
except ImportError:
|
| 238 |
+
logger.warning(
|
| 239 |
+
"Could not import HYV3ForCausalLM; "
|
| 240 |
+
"save reverse rename patch NOT applied."
|
| 241 |
+
)
|
| 242 |
+
return
|
| 243 |
+
|
| 244 |
+
_orig_save_pretrained = HYV3ForCausalLM.save_pretrained
|
| 245 |
+
|
| 246 |
+
def _patched_save_pretrained(self, *args, **kwargs):
|
| 247 |
+
state_dict = kwargs.get("state_dict", None)
|
| 248 |
+
if state_dict is not None:
|
| 249 |
+
reversed_sd = {}
|
| 250 |
+
|
| 251 |
+
for k, v in state_dict.items():
|
| 252 |
+
new_k = k
|
| 253 |
+
# Apply simple key renames
|
| 254 |
+
for new_sub, old_sub in _SAVE_KEY_RENAMES:
|
| 255 |
+
if new_sub in new_k:
|
| 256 |
+
new_k = new_k.replace(new_sub, old_sub)
|
| 257 |
+
break
|
| 258 |
+
|
| 259 |
+
# Check if this is a fused 3D expert key
|
| 260 |
+
m = _FUSED_EXPERT_KEY_RE.match(new_k)
|
| 261 |
+
if m:
|
| 262 |
+
prefix = m.group(1) # e.g. "model.layers.1.mlp.experts."
|
| 263 |
+
proj_type = m.group(2) # "gate_up_proj" or "down_proj"
|
| 264 |
+
|
| 265 |
+
if proj_type == "gate_up_proj":
|
| 266 |
+
# v shape: [num_experts, 2*intermediate, hidden]
|
| 267 |
+
num_experts = v.shape[0]
|
| 268 |
+
intermediate = v.shape[1] // 2
|
| 269 |
+
for i in range(num_experts):
|
| 270 |
+
gate = v[i, :intermediate, :]
|
| 271 |
+
up = v[i, intermediate:, :]
|
| 272 |
+
reversed_sd[f"{prefix}{i}.gate_proj.weight"] = gate
|
| 273 |
+
reversed_sd[f"{prefix}{i}.up_proj.weight"] = up
|
| 274 |
+
elif proj_type == "down_proj":
|
| 275 |
+
# v shape: [num_experts, hidden, intermediate]
|
| 276 |
+
num_experts = v.shape[0]
|
| 277 |
+
for i in range(num_experts):
|
| 278 |
+
reversed_sd[f"{prefix}{i}.down_proj.weight"] = v[i]
|
| 279 |
+
else:
|
| 280 |
+
reversed_sd[new_k] = v
|
| 281 |
+
|
| 282 |
+
kwargs["state_dict"] = reversed_sd
|
| 283 |
+
logger.info(
|
| 284 |
+
"HYV3 Patch 6: Reverse-renamed and unfused %d -> %d "
|
| 285 |
+
"state_dict keys for old checkpoint compatibility.",
|
| 286 |
+
len(state_dict), len(reversed_sd)
|
| 287 |
+
)
|
| 288 |
+
return _orig_save_pretrained(self, *args, **kwargs)
|
| 289 |
+
|
| 290 |
+
HYV3ForCausalLM.save_pretrained = _patched_save_pretrained
|
| 291 |
+
|
| 292 |
+
_save_patch_applied = True
|
| 293 |
+
logger.info(
|
| 294 |
+
"HYV3 patch applied: save-time reverse key rename + "
|
| 295 |
+
"3D -> per-expert unfuse for old ckpt compatibility."
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
# ============================================================================
|
| 299 |
+
# LLaMA Factory Callback: integrates patches 3, 4 into the training loop
|
| 300 |
+
# ============================================================================
|
| 301 |
+
|
| 302 |
+
try:
|
| 303 |
+
from transformers import TrainerCallback
|
| 304 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
| 305 |
+
|
| 306 |
+
class HYV3PatchCallback(TrainerCallback):
|
| 307 |
+
"""
|
| 308 |
+
LLaMA Factory compatible callback that applies HYV3-specific patches.
|
| 309 |
+
|
| 310 |
+
Add to your YAML or pass to Trainer:
|
| 311 |
+
callbacks: [hy_v3_patches.HYV3PatchCallback]
|
| 312 |
+
"""
|
| 313 |
+
|
| 314 |
+
def __init__(self, tokenizer_dir: Optional[str] = None):
|
| 315 |
+
"""
|
| 316 |
+
Args:
|
| 317 |
+
tokenizer_dir: Path to the original tokenizer directory.
|
| 318 |
+
If None, will try to use model_name_or_path from training args.
|
| 319 |
+
"""
|
| 320 |
+
self._tokenizer_dir = tokenizer_dir
|
| 321 |
+
|
| 322 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
| 323 |
+
# --- Patch 3: gradient_checkpointing use_reentrant ---
|
| 324 |
+
if getattr(args, "gradient_checkpointing", False) and getattr(args, "deepspeed", None):
|
| 325 |
+
if not hasattr(args, "gradient_checkpointing_kwargs") or not args.gradient_checkpointing_kwargs:
|
| 326 |
+
args.gradient_checkpointing_kwargs = {"use_reentrant": True}
|
| 327 |
+
elif "use_reentrant" not in args.gradient_checkpointing_kwargs:
|
| 328 |
+
args.gradient_checkpointing_kwargs["use_reentrant"] = True
|
| 329 |
+
logger.info("HYV3 patch applied: gradient_checkpointing use_reentrant=True.")
|
| 330 |
+
|
| 331 |
+
return control
|
| 332 |
+
|
| 333 |
+
def on_save(self, args, state, control, **kwargs):
|
| 334 |
+
# --- Patch 4: Copy tokenizer files ---
|
| 335 |
+
if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
|
| 336 |
+
return control
|
| 337 |
+
|
| 338 |
+
checkpoint_dir = os.path.join(
|
| 339 |
+
args.output_dir,
|
| 340 |
+
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
# Determine tokenizer directory
|
| 344 |
+
tokenizer_dir = self._tokenizer_dir
|
| 345 |
+
if tokenizer_dir is None:
|
| 346 |
+
# Try common locations
|
| 347 |
+
tokenizer_dir = getattr(args, "tokenizer_name_or_path", None)
|
| 348 |
+
if tokenizer_dir is None:
|
| 349 |
+
tokenizer_dir = getattr(args, "model_name_or_path", None)
|
| 350 |
+
|
| 351 |
+
if tokenizer_dir and os.path.isdir(tokenizer_dir):
|
| 352 |
+
_copy_tokenizer_to_checkpoint(tokenizer_dir, checkpoint_dir)
|
| 353 |
+
logger.info(
|
| 354 |
+
"HYV3: Copied tokenizer files from %s to %s",
|
| 355 |
+
tokenizer_dir, checkpoint_dir
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
return control
|
| 359 |
+
|
| 360 |
+
except ImportError:
|
| 361 |
+
logger.warning(
|
| 362 |
+
"transformers not available; HYV3PatchCallback not defined."
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
# ============================================================================
|
| 366 |
+
# Auto-apply patches on import
|
| 367 |
+
# ============================================================================
|
| 368 |
+
|
| 369 |
+
# Patch 1: ZeRO-3 buffer loading (key rename + fuse now done by preprocessing)
|
| 370 |
+
_apply_buffer_loading_patch()
|
| 371 |
+
|
| 372 |
+
# Patch 2: Router dtype fix
|
| 373 |
+
_apply_router_dtype_patch()
|
| 374 |
+
|
| 375 |
+
# Patch 6: Save-time reverse key rename + 3D -> per-expert unfuse
|
| 376 |
+
_apply_save_reverse_rename_patch()
|
| 377 |
+
|
| 378 |
+
# Patches 3, 4 are applied via HYV3PatchCallback during training.
|
| 379 |
+
# Users should add HYV3PatchCallback to their Trainer callbacks.
|
| 380 |
+
|
| 381 |
+
logger.info(
|
| 382 |
+
"HYV3 patches module loaded. Remember to add HYV3PatchCallback to "
|
| 383 |
+
"your Trainer callbacks for full compatibility."
|
| 384 |
+
)
|
train/llama_factory_support/hy_v3_template.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HYV3 chat template registration for LLaMA Factory.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
1. Copy this file's register_template block into LLaMA Factory's
|
| 6 |
+
src/llamafactory/data/template.py (for upstream MR).
|
| 7 |
+
2. Or import this module before training to register at runtime:
|
| 8 |
+
import hy_v3_template
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from llamafactory.data.template import ReasoningTemplate, register_template
|
| 12 |
+
from llamafactory.data.formatter import EmptyFormatter, StringFormatter
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
# HYV3 (MoE, pure text) chat template
|
| 17 |
+
#
|
| 18 |
+
# Token format (from chat_template.jinja & tokenizer_config.json):
|
| 19 |
+
# BOS: <|hy_begin▁of▁sentence|>
|
| 20 |
+
# System: {system_content} (directly after BOS, no role tag)
|
| 21 |
+
# User: <|hy_User|>{user_content}
|
| 22 |
+
# Assistant: <|hy_Assistant|>{assistant_content}<|hy_eos|>
|
| 23 |
+
# EOS: <|hy_eos|>
|
| 24 |
+
#
|
| 25 |
+
# Loss mask: only compute loss on assistant content (including <|hy_eos|>).
|
| 26 |
+
#
|
| 27 |
+
# Note: The system message has NO explicit role token -- it is placed right
|
| 28 |
+
# after BOS. The eos_token is <|hy_eos|>.
|
| 29 |
+
#
|
| 30 |
+
# Reasoning: Supports think tags via ReasoningTemplate.
|
| 31 |
+
# - thought_words: ("<think>", "</think>") matching jinja template
|
| 32 |
+
# - enable_thinking: set globally via data_args.enable_thinking (default True)
|
| 33 |
+
# - Training data always includes think tags (empty or with content)
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
register_template(
|
| 37 |
+
name="hy_v3",
|
| 38 |
+
template_class=ReasoningTemplate,
|
| 39 |
+
format_user=StringFormatter(slots=["<|hy_User|>{{content}}"]),
|
| 40 |
+
format_assistant=StringFormatter(slots=["<|hy_Assistant|>{{content}}", {"eos_token"}]),
|
| 41 |
+
format_system=StringFormatter(slots=["{{content}}"]),
|
| 42 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
| 43 |
+
thought_words=("<think>", "</think>"),
|
| 44 |
+
stop_words=["<|hy_eos|>"],
|
| 45 |
+
efficient_eos=True,
|
| 46 |
+
)
|
train/llama_factory_support/train_hy_v3.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLaMA Factory training entry-point wrapper for HYV3.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Registers the hy_v3 chat template
|
| 6 |
+
2. Applies all HYV3 monkey-patches (checkpoint key rename, dtype fix, etc.)
|
| 7 |
+
3. Injects HYV3PatchCallback into the training loop
|
| 8 |
+
4. Calls run_exp() to start LLaMA Factory training
|
| 9 |
+
|
| 10 |
+
How it works:
|
| 11 |
+
- train_lf.sh launches this script via torchrun directly:
|
| 12 |
+
torchrun ... train_hy_v3.py hy_v3_full_sft.yaml
|
| 13 |
+
- Each torchrun worker executes this script, so all patches are applied
|
| 14 |
+
in every worker process before training begins.
|
| 15 |
+
- We call run_exp() directly (not the CLI launcher) to avoid the
|
| 16 |
+
launcher re-spawning workers and losing our patches.
|
| 17 |
+
|
| 18 |
+
Usage:
|
| 19 |
+
# Via launch script (recommended):
|
| 20 |
+
bash train_lf.sh
|
| 21 |
+
|
| 22 |
+
# Direct single-node (8 GPUs):
|
| 23 |
+
torchrun --nproc_per_node 8 train_hy_v3.py hy_v3_full_sft.yaml
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import sys
|
| 27 |
+
import os
|
| 28 |
+
|
| 29 |
+
# Add current directory to path so patches can be imported
|
| 30 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 31 |
+
|
| 32 |
+
# Step 1: Register HYV3 template (must be before training starts)
|
| 33 |
+
import hy_v3_template # noqa: F401
|
| 34 |
+
|
| 35 |
+
# Step 2: Apply checkpoint key rename patch (must be before model loading)
|
| 36 |
+
import hy_v3_patches # noqa: F401
|
| 37 |
+
|
| 38 |
+
# Step 3: Inject HYV3PatchCallback into LLaMA Factory's training flow
|
| 39 |
+
from llamafactory.train.sft.workflow import run_sft as _orig_run_sft
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _patched_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=None):
|
| 43 |
+
"""Wrap run_sft to inject HYV3PatchCallback."""
|
| 44 |
+
if callbacks is None:
|
| 45 |
+
callbacks = []
|
| 46 |
+
|
| 47 |
+
# Determine tokenizer directory for the save callback
|
| 48 |
+
tokenizer_dir = getattr(model_args, "model_name_or_path", None)
|
| 49 |
+
callbacks.append(hy_v3_patches.HYV3PatchCallback(tokenizer_dir=tokenizer_dir))
|
| 50 |
+
|
| 51 |
+
return _orig_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=callbacks)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Monkey-patch the SFT workflow
|
| 55 |
+
import llamafactory.train.sft.workflow as _sft_wf
|
| 56 |
+
_sft_wf.run_sft = _patched_run_sft
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def main():
|
| 60 |
+
"""Entry point: called by torchrun in each worker process.
|
| 61 |
+
|
| 62 |
+
Since train_lf.sh launches us via torchrun directly, all patches
|
| 63 |
+
(template registration, checkpoint key rename, SFT callback injection)
|
| 64 |
+
are already applied in this process. We just call run_exp() to start
|
| 65 |
+
training — no need to go through the CLI launcher.
|
| 66 |
+
"""
|
| 67 |
+
from llamafactory.train.tuner import run_exp
|
| 68 |
+
run_exp()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if __name__ == "__main__":
|
| 72 |
+
main()
|
train/llama_factory_support/train_lf.sh
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# ============================================================================
|
| 3 |
+
# LLaMA Factory training launch script for HYV3
|
| 4 |
+
#
|
| 5 |
+
# This script sets up the environment and launches training via torchrun.
|
| 6 |
+
#
|
| 7 |
+
# We use train_hy_v3.py as the entry point (not llamafactory-cli)
|
| 8 |
+
# because we need to inject HYV3-specific monkey-patches and register
|
| 9 |
+
# the hy_v3 chat template BEFORE LLaMA Factory starts.
|
| 10 |
+
# train_hy_v3.py directly calls run_exp() in each torchrun worker,
|
| 11 |
+
# ensuring all patches are active.
|
| 12 |
+
#
|
| 13 |
+
# Usage:
|
| 14 |
+
# Single node: bash train_lf.sh
|
| 15 |
+
# Multi-node: Run this script on EACH node with the same IP_LIST.
|
| 16 |
+
# IP_LIST="10.0.0.1,10.0.0.2" bash train_lf.sh
|
| 17 |
+
# ============================================================================
|
| 18 |
+
|
| 19 |
+
set -euo pipefail
|
| 20 |
+
|
| 21 |
+
# -------------------- Network Configuration --------------------
|
| 22 |
+
NET_TYPE="high"
|
| 23 |
+
export NCCL_DEBUG=WARN
|
| 24 |
+
export NCCL_P2P_LEVEL=NVL
|
| 25 |
+
export NCCL_IB_TIMEOUT=24
|
| 26 |
+
export NCCL_NVLS_ENABLE=0
|
| 27 |
+
export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
|
| 28 |
+
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
| 29 |
+
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
|
| 30 |
+
if [[ "${NET_TYPE}" = "low" ]]; then
|
| 31 |
+
export NCCL_SOCKET_IFNAME=eth1
|
| 32 |
+
export NCCL_IB_GID_INDEX=3
|
| 33 |
+
export NCCL_IB_HCA=mlx5_2:1
|
| 34 |
+
export NCCL_IB_SL=3
|
| 35 |
+
export NCCL_CHECK_DISABLE=1
|
| 36 |
+
export NCCL_P2P_DISABLE=0
|
| 37 |
+
export NCCL_LL_THRESHOLD=16384
|
| 38 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 39 |
+
else
|
| 40 |
+
export NCCL_IB_GID_INDEX=3
|
| 41 |
+
export NCCL_IB_SL=3
|
| 42 |
+
export NCCL_CHECK_DISABLE=1
|
| 43 |
+
export NCCL_P2P_DISABLE=0
|
| 44 |
+
export NCCL_IB_DISABLE=0
|
| 45 |
+
export NCCL_LL_THRESHOLD=16384
|
| 46 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 47 |
+
export NCCL_SOCKET_IFNAME=bond1
|
| 48 |
+
export UCX_NET_DEVICES=bond1
|
| 49 |
+
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
|
| 50 |
+
export NCCL_COLLNET_ENABLE=0
|
| 51 |
+
export SHARP_COLL_ENABLE_SAT=0
|
| 52 |
+
export NCCL_NET_GDR_LEVEL=2
|
| 53 |
+
export NCCL_IB_QPS_PER_CONNECTION=4
|
| 54 |
+
export NCCL_IB_TC=160
|
| 55 |
+
export NCCL_PXN_DISABLE=1
|
| 56 |
+
fi
|
| 57 |
+
|
| 58 |
+
# Skip LLaMA Factory version check (we use a newer transformers branch)
|
| 59 |
+
export DISABLE_VERSION_CHECK=1
|
| 60 |
+
|
| 61 |
+
# -------------------- Node Configuration --------------------
|
| 62 |
+
export HOST_GPU_NUM=8
|
| 63 |
+
# IP list, comma separated. e.g. "10.0.0.1,10.0.0.2" or single node "127.0.0.1"
|
| 64 |
+
export IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 65 |
+
|
| 66 |
+
MASTER_PORT=${MASTER_PORT:-29500}
|
| 67 |
+
|
| 68 |
+
IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
|
| 69 |
+
NODES=${#IP_ARRAY[@]}
|
| 70 |
+
MASTER_ADDR=${IP_ARRAY[0]}
|
| 71 |
+
|
| 72 |
+
# -------------------- Paths --------------------
|
| 73 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 74 |
+
YAML_FILE="${SCRIPT_DIR}/hy_v3_full_sft.yaml"
|
| 75 |
+
ENTRY_SCRIPT="${SCRIPT_DIR}/train_hy_v3.py"
|
| 76 |
+
|
| 77 |
+
# -------------------- Distributed Environment --------------------
|
| 78 |
+
export MASTER_ADDR="${MASTER_ADDR}"
|
| 79 |
+
export MASTER_PORT="${MASTER_PORT}"
|
| 80 |
+
export NNODES="${NODES}"
|
| 81 |
+
|
| 82 |
+
if [ ${NODES} -gt 1 ]; then
|
| 83 |
+
# Determine local node rank by matching local IP against IP_LIST
|
| 84 |
+
LOCAL_IP=$(hostname -i | awk '{print $1}')
|
| 85 |
+
NODE_RANK=0
|
| 86 |
+
for i in "${!IP_ARRAY[@]}"; do
|
| 87 |
+
if [[ "${IP_ARRAY[$i]}" == "${LOCAL_IP}" ]]; then
|
| 88 |
+
NODE_RANK=$i
|
| 89 |
+
break
|
| 90 |
+
fi
|
| 91 |
+
done
|
| 92 |
+
export RANK="${NODE_RANK}"
|
| 93 |
+
else
|
| 94 |
+
export RANK=0
|
| 95 |
+
fi
|
| 96 |
+
|
| 97 |
+
echo "============================================"
|
| 98 |
+
echo " HYV3 LLaMA Factory Training"
|
| 99 |
+
echo " Nodes: ${NNODES}, Rank: ${RANK}"
|
| 100 |
+
echo " Master: ${MASTER_ADDR}:${MASTER_PORT}"
|
| 101 |
+
echo " GPUs per node: ${HOST_GPU_NUM}"
|
| 102 |
+
echo " Total GPUs: $((NODES * HOST_GPU_NUM))"
|
| 103 |
+
echo "============================================"
|
| 104 |
+
|
| 105 |
+
# -------------------- Launch --------------------
|
| 106 |
+
# We launch torchrun directly (instead of FORCE_TORCHRUN) so that each
|
| 107 |
+
# worker process runs train_hy_v3.py with all HYV3 patches applied.
|
| 108 |
+
torchrun \
|
| 109 |
+
--nnodes "${NNODES}" \
|
| 110 |
+
--node_rank "${RANK}" \
|
| 111 |
+
--nproc_per_node "${HOST_GPU_NUM}" \
|
| 112 |
+
--master_addr "${MASTER_ADDR}" \
|
| 113 |
+
--master_port "${MASTER_PORT}" \
|
| 114 |
+
"${ENTRY_SCRIPT}" "${YAML_FILE}"
|
train/merge_lora_weight.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 导入所需的库
|
| 2 |
+
from transformers import AutoModelForCausalLM # 用于加载预训练的语言模型
|
| 3 |
+
from peft import LoraConfig, get_peft_model, PeftModel # 用于处理LoRA权重
|
| 4 |
+
import argparse # 用于解析命令行参数
|
| 5 |
+
import shutil # 用于文件操作,如复制
|
| 6 |
+
import os # 用于文件路径操作
|
| 7 |
+
import torch # 用于深度学习操作
|
| 8 |
+
|
| 9 |
+
def main():
|
| 10 |
+
# 创建参数解析器
|
| 11 |
+
parser = argparse.ArgumentParser()
|
| 12 |
+
# 添加命令行参数
|
| 13 |
+
parser.add_argument("--base_model_path", type=str, required=True,
|
| 14 |
+
help="Path to pretrained model or model identifier from huggingface.co/models")
|
| 15 |
+
parser.add_argument("--adapter_model_path", type=str, required=True, help="Path to adapter model")
|
| 16 |
+
parser.add_argument("--output_path", type=str, required=True, help="Path to save the output model")
|
| 17 |
+
parser.add_argument("--save_dtype", type=str, choices=['bf16', 'fp32', 'fp16'],
|
| 18 |
+
default='fp32', help="In which dtype to save, fp32, bf16 or fp16.")
|
| 19 |
+
# 解析命令行参数
|
| 20 |
+
args = parser.parse_args()
|
| 21 |
+
|
| 22 |
+
name2dtype = {'bf16': torch.bfloat16, 'fp32': torch.float32, 'fp16': torch.float16}
|
| 23 |
+
# 加载基座模型
|
| 24 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 25 |
+
args.base_model_path, device_map='cpu',
|
| 26 |
+
trust_remote_code=True, torch_dtype=name2dtype[args.save_dtype]
|
| 27 |
+
)
|
| 28 |
+
# 在基座模型的基础上加载 adapter 权重
|
| 29 |
+
model = PeftModel.from_pretrained(model, args.adapter_model_path, trust_remote_code=True)
|
| 30 |
+
# 融合模型和 adapter
|
| 31 |
+
model = model.merge_and_unload()
|
| 32 |
+
# 保存融合后的模型权重
|
| 33 |
+
model.save_pretrained(args.output_path, safe_serialization=False)
|
| 34 |
+
|
| 35 |
+
# Copy tokenizer, config and other non-weight files from base model
|
| 36 |
+
# Skip model weight files (.safetensors, .bin, .pt) and index files
|
| 37 |
+
_SKIP_SUFFIXES = ('.safetensors', '.bin', '.pt', '.pth')
|
| 38 |
+
_SKIP_NAMES = {'model.safetensors.index.json', 'pytorch_model.bin.index.json'}
|
| 39 |
+
|
| 40 |
+
for fname in os.listdir(args.base_model_path):
|
| 41 |
+
src = os.path.join(args.base_model_path, fname)
|
| 42 |
+
if not os.path.isfile(src):
|
| 43 |
+
continue
|
| 44 |
+
if fname in _SKIP_NAMES or fname.endswith(_SKIP_SUFFIXES):
|
| 45 |
+
continue
|
| 46 |
+
dst = os.path.join(args.output_path, fname)
|
| 47 |
+
if not os.path.exists(dst):
|
| 48 |
+
shutil.copy(src, dst)
|
| 49 |
+
print(f'Copied {fname}')
|
| 50 |
+
|
| 51 |
+
print(f'Merged model weight is saved to {args.output_path}')
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
main()
|
train/merge_lora_weight.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python3 merge_lora_weight.py --base_model_path /xxx/hy_l_train/checkpoint-200 --adapter_model_path /xxx/runs/hy_l_lora_train/checkpoint-200 --output_path /xxx/ckpts/merged_hy_lora_weight --save_dtype bf16
|
train/requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers>=5.6.0
|
| 2 |
+
torch>=2.10.0
|
| 3 |
+
torchvision
|
| 4 |
+
torchaudio
|
| 5 |
+
accelerate>=1.11.0
|
| 6 |
+
peft>=0.18.1
|
| 7 |
+
deepspeed>=0.18.7
|
| 8 |
+
sentencepiece
|
| 9 |
+
protobuf
|
| 10 |
+
ninja
|
| 11 |
+
flash-attn
|
| 12 |
+
tensorboard
|
train/tools/check_converted.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Quick validation script for converted HYV3 outer-format checkpoint.
|
| 4 |
+
|
| 5 |
+
Checks:
|
| 6 |
+
1. model.safetensors.index.json structure and completeness
|
| 7 |
+
2. All expected weight keys exist (dense layer 0, MoE layers 1-79)
|
| 8 |
+
3. Expert tensor shapes (fused 3D format)
|
| 9 |
+
4. All referenced shard files exist and are non-empty
|
| 10 |
+
5. Spot-check: load a few shards and verify tensor shapes/dtypes
|
| 11 |
+
6. No duplicate or orphan keys
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python check_converted.py <output_dir> [--spot-check N]
|
| 15 |
+
|
| 16 |
+
Example:
|
| 17 |
+
python check_converted.py pretrain_base/hf_outer
|
| 18 |
+
python check_converted.py pretrain_base/hf_outer --spot-check 5
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import sys
|
| 25 |
+
import time
|
| 26 |
+
from collections import defaultdict
|
| 27 |
+
|
| 28 |
+
# ============================================================================
|
| 29 |
+
# Expected key patterns for HYV3 outer format
|
| 30 |
+
# ============================================================================
|
| 31 |
+
|
| 32 |
+
# Dense layer (layer 0) expected suffixes
|
| 33 |
+
DENSE_SUFFIXES = [
|
| 34 |
+
"input_layernorm.weight",
|
| 35 |
+
"post_attention_layernorm.weight",
|
| 36 |
+
"self_attn.q_proj.weight",
|
| 37 |
+
"self_attn.k_proj.weight",
|
| 38 |
+
"self_attn.v_proj.weight",
|
| 39 |
+
"self_attn.o_proj.weight",
|
| 40 |
+
"self_attn.q_norm.weight",
|
| 41 |
+
"self_attn.k_norm.weight",
|
| 42 |
+
"mlp.gate_proj.weight",
|
| 43 |
+
"mlp.up_proj.weight",
|
| 44 |
+
"mlp.down_proj.weight",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
# MoE layer (layers 1-79) expected suffixes
|
| 48 |
+
MOE_SUFFIXES = [
|
| 49 |
+
"input_layernorm.weight",
|
| 50 |
+
"post_attention_layernorm.weight",
|
| 51 |
+
"self_attn.q_proj.weight",
|
| 52 |
+
"self_attn.k_proj.weight",
|
| 53 |
+
"self_attn.v_proj.weight",
|
| 54 |
+
"self_attn.o_proj.weight",
|
| 55 |
+
"self_attn.q_norm.weight",
|
| 56 |
+
"self_attn.k_norm.weight",
|
| 57 |
+
# MoE-specific
|
| 58 |
+
"mlp.gate.weight",
|
| 59 |
+
"mlp.e_score_correction_bias",
|
| 60 |
+
"mlp.experts.gate_up_proj",
|
| 61 |
+
"mlp.experts.down_proj",
|
| 62 |
+
"mlp.shared_experts.gate_proj.weight",
|
| 63 |
+
"mlp.shared_experts.up_proj.weight",
|
| 64 |
+
"mlp.shared_experts.down_proj.weight",
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
# MTP (Multi-Token Prediction) layer expected suffixes
|
| 68 |
+
# MTP layers share MoE structure but have additional projection/norm keys
|
| 69 |
+
MTP_EXTRA_SUFFIXES = [
|
| 70 |
+
"eh_proj.weight",
|
| 71 |
+
"enorm.weight",
|
| 72 |
+
"final_layernorm.weight",
|
| 73 |
+
"hnorm.weight",
|
| 74 |
+
]
|
| 75 |
+
|
| 76 |
+
# Global keys (not per-layer)
|
| 77 |
+
GLOBAL_KEYS = [
|
| 78 |
+
"model.embed_tokens.weight",
|
| 79 |
+
"model.norm.weight",
|
| 80 |
+
"lm_head.weight",
|
| 81 |
+
]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def load_config(output_dir):
|
| 85 |
+
"""Load config.json and extract model parameters."""
|
| 86 |
+
config_path = os.path.join(output_dir, "config.json")
|
| 87 |
+
if not os.path.exists(config_path):
|
| 88 |
+
print(f"[ERROR] config.json not found in {output_dir}")
|
| 89 |
+
return None
|
| 90 |
+
with open(config_path) as f:
|
| 91 |
+
return json.load(f)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def check_index_json(output_dir):
|
| 95 |
+
"""Check model.safetensors.index.json for structure and completeness."""
|
| 96 |
+
index_path = os.path.join(output_dir, "model.safetensors.index.json")
|
| 97 |
+
if not os.path.exists(index_path):
|
| 98 |
+
print(f"[ERROR] model.safetensors.index.json not found")
|
| 99 |
+
return None, []
|
| 100 |
+
|
| 101 |
+
with open(index_path) as f:
|
| 102 |
+
index = json.load(f)
|
| 103 |
+
|
| 104 |
+
errors = []
|
| 105 |
+
|
| 106 |
+
# Check structure
|
| 107 |
+
if "metadata" not in index:
|
| 108 |
+
errors.append("Missing 'metadata' in index.json")
|
| 109 |
+
elif "total_size" not in index["metadata"]:
|
| 110 |
+
errors.append("Missing 'total_size' in metadata")
|
| 111 |
+
|
| 112 |
+
if "weight_map" not in index:
|
| 113 |
+
errors.append("Missing 'weight_map' in index.json")
|
| 114 |
+
return index, errors
|
| 115 |
+
|
| 116 |
+
weight_map = index["weight_map"]
|
| 117 |
+
total_size = index.get("metadata", {}).get("total_size", 0)
|
| 118 |
+
|
| 119 |
+
print(f" Index keys : {len(weight_map)}")
|
| 120 |
+
print(f" Total size : {total_size / 1e9:.2f} GB")
|
| 121 |
+
|
| 122 |
+
# Check for empty weight_map
|
| 123 |
+
if len(weight_map) == 0:
|
| 124 |
+
errors.append("weight_map is empty")
|
| 125 |
+
|
| 126 |
+
return index, errors
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def check_expected_keys(weight_map, config):
|
| 130 |
+
"""Check that all expected keys exist in the weight_map."""
|
| 131 |
+
errors = []
|
| 132 |
+
warnings = []
|
| 133 |
+
|
| 134 |
+
num_layers = config.get("num_hidden_layers", 80)
|
| 135 |
+
first_k_dense = config.get("first_k_dense_replace", 1)
|
| 136 |
+
num_experts = config.get("num_experts", 192)
|
| 137 |
+
num_mtp_layers = config.get("num_nextn_predict_layers", 0)
|
| 138 |
+
|
| 139 |
+
# Check global keys
|
| 140 |
+
for key in GLOBAL_KEYS:
|
| 141 |
+
if key not in weight_map:
|
| 142 |
+
errors.append(f"Missing global key: {key}")
|
| 143 |
+
|
| 144 |
+
# Check per-layer keys (regular layers)
|
| 145 |
+
missing_by_type = defaultdict(list)
|
| 146 |
+
for layer_idx in range(num_layers):
|
| 147 |
+
prefix = f"model.layers.{layer_idx}."
|
| 148 |
+
if layer_idx < first_k_dense:
|
| 149 |
+
# Dense layer
|
| 150 |
+
suffixes = DENSE_SUFFIXES
|
| 151 |
+
else:
|
| 152 |
+
# MoE layer
|
| 153 |
+
suffixes = MOE_SUFFIXES
|
| 154 |
+
|
| 155 |
+
for suffix in suffixes:
|
| 156 |
+
full_key = prefix + suffix
|
| 157 |
+
if full_key not in weight_map:
|
| 158 |
+
missing_by_type[suffix].append(layer_idx)
|
| 159 |
+
|
| 160 |
+
# Check MTP layers (layer num_layers .. num_layers + num_mtp_layers - 1)
|
| 161 |
+
mtp_missing_by_type = defaultdict(list)
|
| 162 |
+
for mtp_idx in range(num_mtp_layers):
|
| 163 |
+
layer_idx = num_layers + mtp_idx
|
| 164 |
+
prefix = f"model.layers.{layer_idx}."
|
| 165 |
+
# MTP layers use MoE structure + extra projection/norm keys
|
| 166 |
+
mtp_suffixes = MOE_SUFFIXES + MTP_EXTRA_SUFFIXES
|
| 167 |
+
for suffix in mtp_suffixes:
|
| 168 |
+
full_key = prefix + suffix
|
| 169 |
+
if full_key not in weight_map:
|
| 170 |
+
mtp_missing_by_type[suffix].append(layer_idx)
|
| 171 |
+
|
| 172 |
+
for suffix, layers in sorted(mtp_missing_by_type.items()):
|
| 173 |
+
layer_str = str(layers)
|
| 174 |
+
errors.append(f"Missing MTP key '{suffix}' in layers: {layer_str}")
|
| 175 |
+
|
| 176 |
+
for suffix, layers in sorted(missing_by_type.items()):
|
| 177 |
+
if len(layers) <= 5:
|
| 178 |
+
layer_str = str(layers)
|
| 179 |
+
else:
|
| 180 |
+
layer_str = f"{layers[:3]}...({len(layers)} total)"
|
| 181 |
+
errors.append(f"Missing '{suffix}' in layers: {layer_str}")
|
| 182 |
+
|
| 183 |
+
# Check for unexpected keys (not matching any known pattern)
|
| 184 |
+
known_prefixes = set()
|
| 185 |
+
# Regular layers + MTP layers
|
| 186 |
+
for layer_idx in range(num_layers + num_mtp_layers):
|
| 187 |
+
known_prefixes.add(f"model.layers.{layer_idx}.")
|
| 188 |
+
known_prefixes.add("model.embed_tokens.")
|
| 189 |
+
known_prefixes.add("model.norm.")
|
| 190 |
+
known_prefixes.add("lm_head.")
|
| 191 |
+
# Alternative MTP prefix (some models use this)
|
| 192 |
+
known_prefixes.add("model.mtp_layers.")
|
| 193 |
+
|
| 194 |
+
unexpected = []
|
| 195 |
+
for key in weight_map:
|
| 196 |
+
if not any(key.startswith(p) for p in known_prefixes):
|
| 197 |
+
unexpected.append(key)
|
| 198 |
+
|
| 199 |
+
if unexpected:
|
| 200 |
+
if len(unexpected) <= 5:
|
| 201 |
+
for k in unexpected:
|
| 202 |
+
warnings.append(f"Unexpected key: {k}")
|
| 203 |
+
else:
|
| 204 |
+
warnings.append(f"{len(unexpected)} unexpected keys found (first 3: {unexpected[:3]})")
|
| 205 |
+
|
| 206 |
+
return errors, warnings
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def check_shard_files(output_dir, weight_map):
|
| 210 |
+
"""Check that all referenced shard files exist and are non-empty."""
|
| 211 |
+
errors = []
|
| 212 |
+
warnings = []
|
| 213 |
+
|
| 214 |
+
# Get unique shard files
|
| 215 |
+
shard_files = sorted(set(weight_map.values()))
|
| 216 |
+
print(f" Shard files : {len(shard_files)}")
|
| 217 |
+
|
| 218 |
+
missing = []
|
| 219 |
+
empty = []
|
| 220 |
+
total_disk_size = 0
|
| 221 |
+
|
| 222 |
+
for sf in shard_files:
|
| 223 |
+
path = os.path.join(output_dir, sf)
|
| 224 |
+
if not os.path.exists(path):
|
| 225 |
+
missing.append(sf)
|
| 226 |
+
else:
|
| 227 |
+
size = os.path.getsize(path)
|
| 228 |
+
if size == 0:
|
| 229 |
+
empty.append(sf)
|
| 230 |
+
total_disk_size += size
|
| 231 |
+
|
| 232 |
+
print(f" Disk size : {total_disk_size / 1e9:.2f} GB")
|
| 233 |
+
|
| 234 |
+
if missing:
|
| 235 |
+
errors.append(f"Missing shard files ({len(missing)}): {missing[:5]}")
|
| 236 |
+
if empty:
|
| 237 |
+
errors.append(f"Empty shard files ({len(empty)}): {empty[:5]}")
|
| 238 |
+
|
| 239 |
+
# Check for orphan shard files (exist on disk but not in index)
|
| 240 |
+
all_safetensors = set(
|
| 241 |
+
f for f in os.listdir(output_dir)
|
| 242 |
+
if f.endswith(".safetensors")
|
| 243 |
+
)
|
| 244 |
+
referenced = set(shard_files)
|
| 245 |
+
orphans = all_safetensors - referenced
|
| 246 |
+
if orphans:
|
| 247 |
+
# Distinguish between empty residue files (cross-shard merge artifacts)
|
| 248 |
+
# and real orphan files with actual data
|
| 249 |
+
EMPTY_SHARD_THRESHOLD = 128 # bytes; empty safetensors header is ~16 bytes
|
| 250 |
+
residue_orphans = []
|
| 251 |
+
real_orphans = []
|
| 252 |
+
for o in sorted(orphans):
|
| 253 |
+
sz = os.path.getsize(os.path.join(output_dir, o))
|
| 254 |
+
if sz <= EMPTY_SHARD_THRESHOLD:
|
| 255 |
+
residue_orphans.append(o)
|
| 256 |
+
else:
|
| 257 |
+
real_orphans.append(o)
|
| 258 |
+
|
| 259 |
+
if residue_orphans:
|
| 260 |
+
warnings.append(
|
| 261 |
+
f"{len(residue_orphans)} empty residue shard(s) from cross-shard merge "
|
| 262 |
+
f"(<=128 bytes each, safe to delete)"
|
| 263 |
+
)
|
| 264 |
+
if real_orphans:
|
| 265 |
+
errors.append(
|
| 266 |
+
f"Orphan shard files with data (not in index): {real_orphans[:5]}"
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
return errors, warnings
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def check_key_distribution(weight_map):
|
| 273 |
+
"""Check the distribution of keys across shards."""
|
| 274 |
+
shard_key_count = defaultdict(int)
|
| 275 |
+
for key, shard in weight_map.items():
|
| 276 |
+
shard_key_count[shard] += 1
|
| 277 |
+
|
| 278 |
+
counts = sorted(shard_key_count.values())
|
| 279 |
+
print(f" Keys/shard : min={counts[0]}, max={counts[-1]}, "
|
| 280 |
+
f"median={counts[len(counts)//2]}")
|
| 281 |
+
|
| 282 |
+
# Check for shards with 0 keys (should not happen if they are in weight_map)
|
| 283 |
+
zero_shards = [s for s, c in shard_key_count.items() if c == 0]
|
| 284 |
+
if zero_shards:
|
| 285 |
+
return [f"Shards with 0 keys: {zero_shards}"]
|
| 286 |
+
return []
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def spot_check_shards(output_dir, weight_map, config, num_checks=3):
|
| 290 |
+
"""Spot-check a few shards by loading and verifying tensor shapes."""
|
| 291 |
+
errors = []
|
| 292 |
+
|
| 293 |
+
try:
|
| 294 |
+
from safetensors import safe_open
|
| 295 |
+
except ImportError:
|
| 296 |
+
print(" [SKIP] safetensors not installed, skipping spot-check")
|
| 297 |
+
return errors
|
| 298 |
+
|
| 299 |
+
num_experts = config.get("num_experts", 192)
|
| 300 |
+
expert_hidden = config.get("expert_hidden_dim", config.get("moe_intermediate_size", 1536))
|
| 301 |
+
hidden_size = config.get("hidden_size", 4096)
|
| 302 |
+
|
| 303 |
+
# Find shards that contain expert tensors (most interesting to check)
|
| 304 |
+
expert_shards = set()
|
| 305 |
+
for key, shard in weight_map.items():
|
| 306 |
+
if "experts.gate_up_proj" in key or "experts.down_proj" in key:
|
| 307 |
+
expert_shards.add(shard)
|
| 308 |
+
|
| 309 |
+
# Pick a few shards to check
|
| 310 |
+
check_shards = sorted(expert_shards)[:num_checks]
|
| 311 |
+
if not check_shards:
|
| 312 |
+
check_shards = sorted(set(weight_map.values()))[:num_checks]
|
| 313 |
+
|
| 314 |
+
print(f"\n Spot-checking {len(check_shards)} shard(s)...")
|
| 315 |
+
|
| 316 |
+
for shard_file in check_shards:
|
| 317 |
+
shard_path = os.path.join(output_dir, shard_file)
|
| 318 |
+
t0 = time.time()
|
| 319 |
+
|
| 320 |
+
try:
|
| 321 |
+
with safe_open(shard_path, framework="pt", device="cpu") as f:
|
| 322 |
+
keys_in_shard = list(f.keys())
|
| 323 |
+
for key in keys_in_shard:
|
| 324 |
+
tensor = f.get_tensor(key)
|
| 325 |
+
|
| 326 |
+
# Check expert shapes
|
| 327 |
+
if key.endswith("experts.gate_up_proj"):
|
| 328 |
+
expected_shape = (num_experts, expert_hidden * 2, hidden_size)
|
| 329 |
+
if tuple(tensor.shape) != expected_shape:
|
| 330 |
+
errors.append(
|
| 331 |
+
f"{shard_file}/{key}: shape {tuple(tensor.shape)} "
|
| 332 |
+
f"!= expected {expected_shape}"
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
elif key.endswith("experts.down_proj"):
|
| 336 |
+
expected_shape = (num_experts, hidden_size, expert_hidden)
|
| 337 |
+
if tuple(tensor.shape) != expected_shape:
|
| 338 |
+
errors.append(
|
| 339 |
+
f"{shard_file}/{key}: shape {tuple(tensor.shape)} "
|
| 340 |
+
f"!= expected {expected_shape}"
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
# Check for NaN/Inf
|
| 344 |
+
if tensor.is_floating_point():
|
| 345 |
+
if tensor.isnan().any():
|
| 346 |
+
errors.append(f"{shard_file}/{key}: contains NaN values")
|
| 347 |
+
if tensor.isinf().any():
|
| 348 |
+
errors.append(f"{shard_file}/{key}: contains Inf values")
|
| 349 |
+
|
| 350 |
+
elapsed = time.time() - t0
|
| 351 |
+
print(f" {shard_file}: {len(keys_in_shard)} keys, OK ({elapsed:.1f}s)")
|
| 352 |
+
|
| 353 |
+
except Exception as e:
|
| 354 |
+
errors.append(f"Failed to load {shard_file}: {e}")
|
| 355 |
+
|
| 356 |
+
return errors
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def main():
|
| 360 |
+
parser = argparse.ArgumentParser(
|
| 361 |
+
description="Validate converted HYV3 outer-format checkpoint."
|
| 362 |
+
)
|
| 363 |
+
parser.add_argument(
|
| 364 |
+
"output_dir", type=str,
|
| 365 |
+
help="Path to the converted outer-format checkpoint directory.",
|
| 366 |
+
)
|
| 367 |
+
parser.add_argument(
|
| 368 |
+
"--spot-check", type=int, default=3, dest="spot_check",
|
| 369 |
+
help="Number of shards to spot-check by loading tensors (default: 3).",
|
| 370 |
+
)
|
| 371 |
+
args = parser.parse_args()
|
| 372 |
+
|
| 373 |
+
output_dir = os.path.abspath(args.output_dir)
|
| 374 |
+
print(f"Validating: {output_dir}\n")
|
| 375 |
+
|
| 376 |
+
if not os.path.isdir(output_dir):
|
| 377 |
+
print(f"[ERROR] Directory not found: {output_dir}")
|
| 378 |
+
sys.exit(1)
|
| 379 |
+
|
| 380 |
+
all_errors = []
|
| 381 |
+
all_warnings = []
|
| 382 |
+
|
| 383 |
+
# 1. Load config
|
| 384 |
+
print("[1/5] Loading config.json...")
|
| 385 |
+
config = load_config(output_dir)
|
| 386 |
+
if config is None:
|
| 387 |
+
print("[ERROR] Cannot proceed without config.json")
|
| 388 |
+
sys.exit(1)
|
| 389 |
+
|
| 390 |
+
num_layers = config.get("num_hidden_layers", 0)
|
| 391 |
+
num_experts = config.get("num_experts", 0)
|
| 392 |
+
first_k_dense = config.get("first_k_dense_replace", 0)
|
| 393 |
+
num_mtp = config.get("num_nextn_predict_layers", 0)
|
| 394 |
+
print(f" Layers : {num_layers} ({first_k_dense} dense, {num_layers - first_k_dense} MoE)")
|
| 395 |
+
print(f" MTP layers : {num_mtp}")
|
| 396 |
+
print(f" Experts/layer : {num_experts}")
|
| 397 |
+
print(f" Hidden size : {config.get('hidden_size', '?')}")
|
| 398 |
+
print(f" Expert hidden : {config.get('expert_hidden_dim', config.get('moe_intermediate_size', '?'))}")
|
| 399 |
+
|
| 400 |
+
# 2. Check index.json
|
| 401 |
+
print("\n[2/5] Checking model.safetensors.index.json...")
|
| 402 |
+
index, idx_errors = check_index_json(output_dir)
|
| 403 |
+
all_errors.extend(idx_errors)
|
| 404 |
+
|
| 405 |
+
if index is None or "weight_map" not in index:
|
| 406 |
+
print("[ERROR] Cannot proceed without valid index.json")
|
| 407 |
+
sys.exit(1)
|
| 408 |
+
|
| 409 |
+
weight_map = index["weight_map"]
|
| 410 |
+
|
| 411 |
+
# 3. Check expected keys
|
| 412 |
+
print("\n[3/5] Checking expected keys...")
|
| 413 |
+
key_errors, key_warnings = check_expected_keys(weight_map, config)
|
| 414 |
+
all_errors.extend(key_errors)
|
| 415 |
+
all_warnings.extend(key_warnings)
|
| 416 |
+
|
| 417 |
+
# Also check key distribution
|
| 418 |
+
dist_errors = check_key_distribution(weight_map)
|
| 419 |
+
all_errors.extend(dist_errors)
|
| 420 |
+
|
| 421 |
+
# 4. Check shard files
|
| 422 |
+
print("\n[4/5] Checking shard files on disk...")
|
| 423 |
+
shard_errors, shard_warnings = check_shard_files(output_dir, weight_map)
|
| 424 |
+
all_errors.extend(shard_errors)
|
| 425 |
+
all_warnings.extend(shard_warnings)
|
| 426 |
+
|
| 427 |
+
# 5. Spot-check
|
| 428 |
+
if args.spot_check > 0:
|
| 429 |
+
print(f"\n[5/5] Spot-checking tensors (loading {args.spot_check} shard(s))...")
|
| 430 |
+
spot_errors = spot_check_shards(output_dir, weight_map, config, args.spot_check)
|
| 431 |
+
all_errors.extend(spot_errors)
|
| 432 |
+
else:
|
| 433 |
+
print("\n[5/5] Spot-check skipped (--spot-check 0)")
|
| 434 |
+
|
| 435 |
+
# Summary
|
| 436 |
+
print(f"\n{'=' * 60}")
|
| 437 |
+
if all_warnings:
|
| 438 |
+
print(f"WARNINGS ({len(all_warnings)}):")
|
| 439 |
+
for w in all_warnings:
|
| 440 |
+
print(f" [WARN] {w}")
|
| 441 |
+
|
| 442 |
+
if all_errors:
|
| 443 |
+
print(f"ERRORS ({len(all_errors)}):")
|
| 444 |
+
for e in all_errors:
|
| 445 |
+
print(f" [ERROR] {e}")
|
| 446 |
+
print(f"\nResult: FAILED ({len(all_errors)} error(s), {len(all_warnings)} warning(s))")
|
| 447 |
+
sys.exit(1)
|
| 448 |
+
else:
|
| 449 |
+
print(f"Result: PASSED (0 errors, {len(all_warnings)} warning(s))")
|
| 450 |
+
print(f"{'=' * 60}")
|
| 451 |
+
sys.exit(0)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
if __name__ == "__main__":
|
| 455 |
+
main()
|
train/tools/convert_ckpt_to_outer.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Memory-friendly checkpoint converter: inner -> outer format (v2).
|
| 4 |
+
|
| 5 |
+
Converts the HYV3 checkpoint from inner format (per-expert keys, old naming)
|
| 6 |
+
to outer format (fused 3D experts, new naming) shard by shard.
|
| 7 |
+
|
| 8 |
+
Handles the case where a single layer's experts may be split across
|
| 9 |
+
multiple shards (cross-shard experts) by deferring their fusion to a
|
| 10 |
+
post-processing step.
|
| 11 |
+
|
| 12 |
+
v2 improvements over v1:
|
| 13 |
+
- Post-processing is shard-centric (each shard read/written only once)
|
| 14 |
+
instead of prefix-centric (same shard read/written multiple times).
|
| 15 |
+
This fixes Bus error (core dump) when there are many cross-shard groups.
|
| 16 |
+
- Explicit memory management with gc.collect() to prevent memory bloat.
|
| 17 |
+
- Better progress reporting during post-processing.
|
| 18 |
+
|
| 19 |
+
Supports multi-process parallelism for faster conversion.
|
| 20 |
+
|
| 21 |
+
Usage:
|
| 22 |
+
# Default 8 workers
|
| 23 |
+
python convert_ckpt_to_outer.py \\
|
| 24 |
+
--input_dir pretrain_base/hf \\
|
| 25 |
+
--output_dir pretrain_base/hf_outer
|
| 26 |
+
|
| 27 |
+
# Custom worker count
|
| 28 |
+
python convert_ckpt_to_outer.py \\
|
| 29 |
+
--input_dir pretrain_base/hf \\
|
| 30 |
+
--output_dir pretrain_base/hf_outer \\
|
| 31 |
+
--workers 16
|
| 32 |
+
|
| 33 |
+
The script will:
|
| 34 |
+
1. Pre-scan index.json to detect cross-shard expert groups
|
| 35 |
+
2. Convert weights shard-by-shard in parallel (key rename + expert fuse)
|
| 36 |
+
3. Post-process cross-shard expert groups (merge from multiple shards)
|
| 37 |
+
- v2: shard-centric approach, each shard read/written only once
|
| 38 |
+
4. Copy config.json as-is (already in outer format)
|
| 39 |
+
5. Copy all other files (tokenizer, etc.)
|
| 40 |
+
6. Rebuild model.safetensors.index.json
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
import argparse
|
| 44 |
+
import gc
|
| 45 |
+
import json
|
| 46 |
+
import os
|
| 47 |
+
import re
|
| 48 |
+
import signal
|
| 49 |
+
import shutil
|
| 50 |
+
import sys
|
| 51 |
+
import time
|
| 52 |
+
import traceback
|
| 53 |
+
from collections import OrderedDict, defaultdict
|
| 54 |
+
from multiprocessing import Pool
|
| 55 |
+
|
| 56 |
+
import torch
|
| 57 |
+
|
| 58 |
+
try:
|
| 59 |
+
from safetensors import safe_open
|
| 60 |
+
from safetensors.torch import save_file
|
| 61 |
+
except ImportError:
|
| 62 |
+
raise ImportError("Please install safetensors: pip install safetensors")
|
| 63 |
+
|
| 64 |
+
# ============================================================================
|
| 65 |
+
# Signal handling for Bus error (SIGBUS) and other fatal signals
|
| 66 |
+
# ============================================================================
|
| 67 |
+
|
| 68 |
+
def _fatal_signal_handler(signum, frame):
|
| 69 |
+
"""Handle fatal signals (SIGBUS, SIGSEGV) by logging before exit.
|
| 70 |
+
|
| 71 |
+
These signals cannot be caught by try/except. This handler ensures
|
| 72 |
+
the error message is written to stderr (captured by nohup redirection)
|
| 73 |
+
before the process terminates.
|
| 74 |
+
"""
|
| 75 |
+
sig_name = signal.Signals(signum).name if hasattr(signal, 'Signals') else str(signum)
|
| 76 |
+
pid = os.getpid()
|
| 77 |
+
msg = (
|
| 78 |
+
f"\n[FATAL] Process {pid} received {sig_name} (signal {signum}).\n"
|
| 79 |
+
f"This typically indicates an out-of-memory condition during mmap I/O.\n"
|
| 80 |
+
f"Stack trace at time of signal:\n"
|
| 81 |
+
)
|
| 82 |
+
sys.stderr.write(msg)
|
| 83 |
+
traceback.print_stack(frame, file=sys.stderr)
|
| 84 |
+
sys.stderr.flush()
|
| 85 |
+
# Re-raise with default handler to get proper exit code
|
| 86 |
+
signal.signal(signum, signal.SIG_DFL)
|
| 87 |
+
os.kill(pid, signum)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _install_signal_handlers():
|
| 91 |
+
"""Install handlers for SIGBUS and SIGSEGV in the current process."""
|
| 92 |
+
for sig in (signal.SIGBUS, signal.SIGSEGV):
|
| 93 |
+
try:
|
| 94 |
+
signal.signal(sig, _fatal_signal_handler)
|
| 95 |
+
except (OSError, ValueError):
|
| 96 |
+
# Some signals may not be available on all platforms
|
| 97 |
+
pass
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _pool_worker_init():
|
| 101 |
+
"""Initializer for multiprocessing pool workers.
|
| 102 |
+
|
| 103 |
+
Installs signal handlers so that Bus errors in worker processes
|
| 104 |
+
are also logged before the process dies.
|
| 105 |
+
"""
|
| 106 |
+
_install_signal_handlers()
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ============================================================================
|
| 110 |
+
# Key rename mapping (inner -> outer)
|
| 111 |
+
# ============================================================================
|
| 112 |
+
|
| 113 |
+
_KEY_RENAMES = [
|
| 114 |
+
("mlp.router.gate.", "mlp.gate."),
|
| 115 |
+
("mlp.expert_bias", "mlp.e_score_correction_bias"),
|
| 116 |
+
("mlp.shared_mlp.", "mlp.shared_experts."),
|
| 117 |
+
]
|
| 118 |
+
|
| 119 |
+
# Regex to match per-expert keys
|
| 120 |
+
_EXPERT_KEY_RE = re.compile(
|
| 121 |
+
r"^(.*\.mlp\.experts\.)(\d+)\.(gate_proj|up_proj|down_proj)\.weight$"
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
def rename_key(key: str) -> str:
|
| 125 |
+
"""Rename a single key from inner to outer format."""
|
| 126 |
+
for old_sub, new_sub in _KEY_RENAMES:
|
| 127 |
+
if old_sub in key:
|
| 128 |
+
key = key.replace(old_sub, new_sub)
|
| 129 |
+
break
|
| 130 |
+
return key
|
| 131 |
+
|
| 132 |
+
def scan_cross_shard_experts(index_path: str):
|
| 133 |
+
"""Pre-scan index.json to find expert groups that span multiple shards.
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
cross_shard_prefixes: set of expert prefixes that span multiple shards
|
| 137 |
+
e.g. {"model.layers.80.mlp.experts."}
|
| 138 |
+
"""
|
| 139 |
+
with open(index_path) as f:
|
| 140 |
+
index = json.load(f)
|
| 141 |
+
wm = index["weight_map"]
|
| 142 |
+
|
| 143 |
+
# prefix -> set of shards
|
| 144 |
+
prefix_shards = defaultdict(set)
|
| 145 |
+
for key in wm:
|
| 146 |
+
m = _EXPERT_KEY_RE.match(key)
|
| 147 |
+
if m:
|
| 148 |
+
prefix = m.group(1)
|
| 149 |
+
prefix_shards[prefix].add(wm[key])
|
| 150 |
+
|
| 151 |
+
cross_shard_prefixes = set()
|
| 152 |
+
for prefix, shards in prefix_shards.items():
|
| 153 |
+
if len(shards) > 1:
|
| 154 |
+
cross_shard_prefixes.add(prefix)
|
| 155 |
+
|
| 156 |
+
return cross_shard_prefixes
|
| 157 |
+
|
| 158 |
+
def convert_shard(shard_path: str, cross_shard_prefixes: set = None):
|
| 159 |
+
"""Load a single shard, rename keys, and fuse experts.
|
| 160 |
+
|
| 161 |
+
For expert groups in cross_shard_prefixes, the per-expert keys are
|
| 162 |
+
kept as-is (just renamed) and returned separately as deferred items,
|
| 163 |
+
to be merged later in a post-processing step.
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
result: OrderedDict of converted tensors (ready to save)
|
| 167 |
+
deferred_expert_keys: list of original expert keys that were deferred
|
| 168 |
+
(these are kept in result with their original per-expert naming
|
| 169 |
+
but with the outer rename applied, to be post-processed later)
|
| 170 |
+
"""
|
| 171 |
+
if cross_shard_prefixes is None:
|
| 172 |
+
cross_shard_prefixes = set()
|
| 173 |
+
|
| 174 |
+
tensors = OrderedDict()
|
| 175 |
+
with safe_open(shard_path, framework="pt", device="cpu") as f:
|
| 176 |
+
for key in f.keys():
|
| 177 |
+
tensors[key] = f.get_tensor(key)
|
| 178 |
+
|
| 179 |
+
# Separate expert keys from non-expert keys
|
| 180 |
+
expert_groups = {} # prefix -> {expert_idx -> {proj_name -> tensor}}
|
| 181 |
+
deferred_expert_keys = [] # keys that belong to cross-shard experts
|
| 182 |
+
result = OrderedDict()
|
| 183 |
+
|
| 184 |
+
for key, tensor in tensors.items():
|
| 185 |
+
m = _EXPERT_KEY_RE.match(key)
|
| 186 |
+
if m:
|
| 187 |
+
prefix = m.group(1)
|
| 188 |
+
expert_idx = int(m.group(2))
|
| 189 |
+
proj_name = m.group(3)
|
| 190 |
+
|
| 191 |
+
if prefix in cross_shard_prefixes:
|
| 192 |
+
# Defer: keep the key as-is (with rename) for post-processing
|
| 193 |
+
new_key = rename_key(key)
|
| 194 |
+
result[new_key] = tensor
|
| 195 |
+
deferred_expert_keys.append(new_key)
|
| 196 |
+
else:
|
| 197 |
+
# Normal: collect for fusion within this shard
|
| 198 |
+
if prefix not in expert_groups:
|
| 199 |
+
expert_groups[prefix] = {}
|
| 200 |
+
if expert_idx not in expert_groups[prefix]:
|
| 201 |
+
expert_groups[prefix][expert_idx] = {}
|
| 202 |
+
expert_groups[prefix][expert_idx][proj_name] = tensor
|
| 203 |
+
else:
|
| 204 |
+
# Non-expert key: just rename
|
| 205 |
+
new_key = rename_key(key)
|
| 206 |
+
result[new_key] = tensor
|
| 207 |
+
|
| 208 |
+
# Fuse expert weights for each non-cross-shard layer prefix
|
| 209 |
+
for prefix in sorted(expert_groups.keys()):
|
| 210 |
+
experts = expert_groups[prefix]
|
| 211 |
+
num_experts = max(experts.keys()) + 1
|
| 212 |
+
|
| 213 |
+
gate_up_list = []
|
| 214 |
+
down_list = []
|
| 215 |
+
for i in range(num_experts):
|
| 216 |
+
if i not in experts:
|
| 217 |
+
raise ValueError(
|
| 218 |
+
f"Missing expert {i} in {prefix}. "
|
| 219 |
+
f"Found: {sorted(experts.keys())}"
|
| 220 |
+
)
|
| 221 |
+
exp = experts[i]
|
| 222 |
+
gate_up = torch.cat([exp["gate_proj"], exp["up_proj"]], dim=0)
|
| 223 |
+
gate_up_list.append(gate_up)
|
| 224 |
+
down_list.append(exp["down_proj"])
|
| 225 |
+
|
| 226 |
+
fused_gate_up = torch.stack(gate_up_list, dim=0)
|
| 227 |
+
fused_down = torch.stack(down_list, dim=0)
|
| 228 |
+
|
| 229 |
+
for exp in experts.values():
|
| 230 |
+
exp.clear()
|
| 231 |
+
gate_up_list.clear()
|
| 232 |
+
down_list.clear()
|
| 233 |
+
|
| 234 |
+
result[f"{prefix}gate_up_proj"] = fused_gate_up
|
| 235 |
+
result[f"{prefix}down_proj"] = fused_down
|
| 236 |
+
|
| 237 |
+
return result, deferred_expert_keys
|
| 238 |
+
|
| 239 |
+
def _process_one_shard(args_tuple):
|
| 240 |
+
"""Worker function: convert a single shard and save to output dir.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
args_tuple: (idx, num_shards, shard_file, input_dir, output_dir, cross_shard_prefixes)
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
(shard_file, key_list, shard_size, elapsed, deferred_keys)
|
| 247 |
+
"""
|
| 248 |
+
idx, num_shards, shard_file, input_dir, output_dir, cross_shard_prefixes = args_tuple
|
| 249 |
+
shard_path = os.path.join(input_dir, shard_file)
|
| 250 |
+
t0 = time.time()
|
| 251 |
+
|
| 252 |
+
converted, deferred_keys = convert_shard(shard_path, cross_shard_prefixes)
|
| 253 |
+
|
| 254 |
+
shard_size = sum(t.numel() * t.element_size() for t in converted.values())
|
| 255 |
+
|
| 256 |
+
out_shard_path = os.path.join(output_dir, shard_file)
|
| 257 |
+
save_file(converted, out_shard_path)
|
| 258 |
+
|
| 259 |
+
elapsed = time.time() - t0
|
| 260 |
+
num_keys = len(converted)
|
| 261 |
+
key_list = list(converted.keys())
|
| 262 |
+
|
| 263 |
+
del converted
|
| 264 |
+
|
| 265 |
+
deferred_info = ""
|
| 266 |
+
if deferred_keys:
|
| 267 |
+
deferred_info = f", Deferred={len(deferred_keys)}"
|
| 268 |
+
|
| 269 |
+
print(
|
| 270 |
+
f" [{idx + 1}/{num_shards}] {shard_file}: "
|
| 271 |
+
f"Keys={num_keys}, Size={shard_size / 1e9:.2f} GB, "
|
| 272 |
+
f"Time={elapsed:.1f}s{deferred_info}",
|
| 273 |
+
flush=True,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
return shard_file, key_list, shard_size, elapsed, deferred_keys
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def post_process_cross_shard_experts(output_dir, cross_shard_prefixes, all_deferred):
|
| 280 |
+
"""Merge cross-shard expert groups (v2: shard-centric approach).
|
| 281 |
+
|
| 282 |
+
Instead of iterating per-prefix (which causes the same shard to be
|
| 283 |
+
loaded/saved multiple times), this v2 approach:
|
| 284 |
+
1. Builds a mapping of which prefixes each shard is involved in
|
| 285 |
+
2. Collects all expert tensors from all involved shards in ONE pass
|
| 286 |
+
3. Fuses all prefixes
|
| 287 |
+
4. Writes each shard only ONCE with all its updates applied
|
| 288 |
+
|
| 289 |
+
This avoids the Bus error (core dump) caused by repeated mmap of
|
| 290 |
+
large files and memory bloat.
|
| 291 |
+
|
| 292 |
+
Args:
|
| 293 |
+
output_dir: path to output directory
|
| 294 |
+
cross_shard_prefixes: set of expert prefixes that span multiple shards
|
| 295 |
+
all_deferred: dict of {shard_file: [deferred_key, ...]}
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
updated_shards: dict of {shard_file: (key_list, shard_size)} for updated shards
|
| 299 |
+
"""
|
| 300 |
+
if not cross_shard_prefixes:
|
| 301 |
+
return {}
|
| 302 |
+
|
| 303 |
+
print(f"\n Post-processing {len(cross_shard_prefixes)} cross-shard expert group(s)...",
|
| 304 |
+
flush=True)
|
| 305 |
+
|
| 306 |
+
# ----------------------------------------------------------------
|
| 307 |
+
# Step 1: Build mappings
|
| 308 |
+
# ----------------------------------------------------------------
|
| 309 |
+
# prefix -> ordered list of shards that contain its experts
|
| 310 |
+
prefix_to_shards = defaultdict(set)
|
| 311 |
+
# shard -> set of prefixes it is involved in
|
| 312 |
+
shard_to_prefixes = defaultdict(set)
|
| 313 |
+
|
| 314 |
+
for shard_file, deferred_keys in all_deferred.items():
|
| 315 |
+
for key in deferred_keys:
|
| 316 |
+
m = _EXPERT_KEY_RE.match(key)
|
| 317 |
+
if m:
|
| 318 |
+
prefix = m.group(1)
|
| 319 |
+
if prefix in cross_shard_prefixes:
|
| 320 |
+
prefix_to_shards[prefix].add(shard_file)
|
| 321 |
+
shard_to_prefixes[shard_file].add(prefix)
|
| 322 |
+
|
| 323 |
+
# For each prefix, decide which shard will hold the fused result
|
| 324 |
+
# (use the first shard alphabetically)
|
| 325 |
+
prefix_to_target_shard = {}
|
| 326 |
+
for prefix in sorted(prefix_to_shards.keys()):
|
| 327 |
+
target = sorted(prefix_to_shards[prefix])[0]
|
| 328 |
+
prefix_to_target_shard[prefix] = target
|
| 329 |
+
|
| 330 |
+
# All shards that need to be updated
|
| 331 |
+
all_involved_shards = set()
|
| 332 |
+
for shards in prefix_to_shards.values():
|
| 333 |
+
all_involved_shards.update(shards)
|
| 334 |
+
|
| 335 |
+
print(f" Involved shards: {len(all_involved_shards)}", flush=True)
|
| 336 |
+
print(f" Expert groups: {len(prefix_to_shards)}", flush=True)
|
| 337 |
+
|
| 338 |
+
# ----------------------------------------------------------------
|
| 339 |
+
# Step 2: Collect all expert tensors from all involved shards
|
| 340 |
+
# (one pass per shard)
|
| 341 |
+
# ----------------------------------------------------------------
|
| 342 |
+
# prefix -> {expert_idx -> {proj_name -> tensor}}
|
| 343 |
+
all_expert_data = defaultdict(dict)
|
| 344 |
+
# shard -> OrderedDict of non-expert keys (to be re-saved)
|
| 345 |
+
shard_non_expert = {}
|
| 346 |
+
|
| 347 |
+
sorted_involved = sorted(all_involved_shards)
|
| 348 |
+
for si, shard_file in enumerate(sorted_involved):
|
| 349 |
+
shard_path = os.path.join(output_dir, shard_file)
|
| 350 |
+
prefixes_in_shard = shard_to_prefixes[shard_file]
|
| 351 |
+
|
| 352 |
+
print(f" [{si+1}/{len(sorted_involved)}] Reading {shard_file} "
|
| 353 |
+
f"({len(prefixes_in_shard)} prefix(es))...", flush=True)
|
| 354 |
+
|
| 355 |
+
non_expert = OrderedDict()
|
| 356 |
+
with safe_open(shard_path, framework="pt", device="cpu") as f:
|
| 357 |
+
for key in f.keys():
|
| 358 |
+
m = _EXPERT_KEY_RE.match(key)
|
| 359 |
+
if m and m.group(1) in prefixes_in_shard:
|
| 360 |
+
# This is a deferred expert key
|
| 361 |
+
prefix = m.group(1)
|
| 362 |
+
expert_idx = int(m.group(2))
|
| 363 |
+
proj_name = m.group(3)
|
| 364 |
+
if expert_idx not in all_expert_data[prefix]:
|
| 365 |
+
all_expert_data[prefix][expert_idx] = {}
|
| 366 |
+
all_expert_data[prefix][expert_idx][proj_name] = f.get_tensor(key)
|
| 367 |
+
else:
|
| 368 |
+
# Non-expert key: keep as-is
|
| 369 |
+
non_expert[key] = f.get_tensor(key)
|
| 370 |
+
|
| 371 |
+
shard_non_expert[shard_file] = non_expert
|
| 372 |
+
gc.collect()
|
| 373 |
+
|
| 374 |
+
# ----------------------------------------------------------------
|
| 375 |
+
# Step 3: Fuse all expert groups
|
| 376 |
+
# ----------------------------------------------------------------
|
| 377 |
+
# prefix -> {"gate_up_proj": tensor, "down_proj": tensor}
|
| 378 |
+
fused_results = {}
|
| 379 |
+
|
| 380 |
+
for pi, prefix in enumerate(sorted(all_expert_data.keys())):
|
| 381 |
+
expert_data = all_expert_data[prefix]
|
| 382 |
+
num_experts = max(expert_data.keys()) + 1
|
| 383 |
+
|
| 384 |
+
print(f" Fusing {prefix} ({num_experts} experts)...", flush=True)
|
| 385 |
+
|
| 386 |
+
gate_up_list = []
|
| 387 |
+
down_list = []
|
| 388 |
+
for i in range(num_experts):
|
| 389 |
+
if i not in expert_data:
|
| 390 |
+
raise ValueError(
|
| 391 |
+
f"Missing expert {i} in {prefix} after cross-shard merge. "
|
| 392 |
+
f"Found: {sorted(expert_data.keys())}"
|
| 393 |
+
)
|
| 394 |
+
exp = expert_data[i]
|
| 395 |
+
if "gate_proj" not in exp or "up_proj" not in exp:
|
| 396 |
+
raise ValueError(
|
| 397 |
+
f"Expert {i} in {prefix} missing gate_proj/up_proj. "
|
| 398 |
+
f"Has: {sorted(exp.keys())}"
|
| 399 |
+
)
|
| 400 |
+
if "down_proj" not in exp:
|
| 401 |
+
raise ValueError(
|
| 402 |
+
f"Expert {i} in {prefix} missing down_proj. "
|
| 403 |
+
f"Has: {sorted(exp.keys())}"
|
| 404 |
+
)
|
| 405 |
+
gate_up = torch.cat([exp["gate_proj"], exp["up_proj"]], dim=0)
|
| 406 |
+
gate_up_list.append(gate_up)
|
| 407 |
+
down_list.append(exp["down_proj"])
|
| 408 |
+
|
| 409 |
+
fused_gate_up = torch.stack(gate_up_list, dim=0)
|
| 410 |
+
fused_down = torch.stack(down_list, dim=0)
|
| 411 |
+
|
| 412 |
+
fused_results[prefix] = {
|
| 413 |
+
"gate_up_proj": fused_gate_up,
|
| 414 |
+
"down_proj": fused_down,
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
# Free per-expert data for this prefix
|
| 418 |
+
del gate_up_list, down_list
|
| 419 |
+
for exp in expert_data.values():
|
| 420 |
+
exp.clear()
|
| 421 |
+
del all_expert_data[prefix]
|
| 422 |
+
gc.collect()
|
| 423 |
+
|
| 424 |
+
del all_expert_data
|
| 425 |
+
gc.collect()
|
| 426 |
+
|
| 427 |
+
# ----------------------------------------------------------------
|
| 428 |
+
# Step 4: Write each involved shard ONCE with all updates applied
|
| 429 |
+
# ----------------------------------------------------------------
|
| 430 |
+
updated_shards = {}
|
| 431 |
+
|
| 432 |
+
for si, shard_file in enumerate(sorted_involved):
|
| 433 |
+
shard_path = os.path.join(output_dir, shard_file)
|
| 434 |
+
non_expert = shard_non_expert[shard_file]
|
| 435 |
+
|
| 436 |
+
# Add fused tensors for prefixes that target this shard
|
| 437 |
+
fused_added = []
|
| 438 |
+
for prefix, target_shard in prefix_to_target_shard.items():
|
| 439 |
+
if target_shard == shard_file and prefix in fused_results:
|
| 440 |
+
non_expert[f"{prefix}gate_up_proj"] = fused_results[prefix]["gate_up_proj"]
|
| 441 |
+
non_expert[f"{prefix}down_proj"] = fused_results[prefix]["down_proj"]
|
| 442 |
+
fused_added.append(prefix)
|
| 443 |
+
|
| 444 |
+
save_file(non_expert, shard_path)
|
| 445 |
+
shard_size = sum(t.numel() * t.element_size() for t in non_expert.values())
|
| 446 |
+
updated_shards[shard_file] = (list(non_expert.keys()), shard_size)
|
| 447 |
+
|
| 448 |
+
fused_info = ""
|
| 449 |
+
if fused_added:
|
| 450 |
+
fused_info = f", Fused {len(fused_added)} group(s)"
|
| 451 |
+
|
| 452 |
+
print(f" [{si+1}/{len(sorted_involved)}] Wrote {shard_file}: "
|
| 453 |
+
f"{len(non_expert)} keys, {shard_size / 1e9:.2f} GB{fused_info}",
|
| 454 |
+
flush=True)
|
| 455 |
+
|
| 456 |
+
# Free memory for this shard
|
| 457 |
+
del shard_non_expert[shard_file]
|
| 458 |
+
for prefix in fused_added:
|
| 459 |
+
del fused_results[prefix]
|
| 460 |
+
del non_expert
|
| 461 |
+
gc.collect()
|
| 462 |
+
|
| 463 |
+
return updated_shards
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def main():
|
| 467 |
+
parser = argparse.ArgumentParser(
|
| 468 |
+
description="Convert HYV3 checkpoint from inner to outer format (v2, shard-centric post-processing)."
|
| 469 |
+
)
|
| 470 |
+
parser.add_argument(
|
| 471 |
+
"--input_dir", type=str, required=True,
|
| 472 |
+
help="Path to the inner-format checkpoint directory.",
|
| 473 |
+
)
|
| 474 |
+
parser.add_argument(
|
| 475 |
+
"--output_dir", type=str, required=True,
|
| 476 |
+
help="Path to the output outer-format checkpoint directory.",
|
| 477 |
+
)
|
| 478 |
+
parser.add_argument(
|
| 479 |
+
"--workers", type=int, default=8,
|
| 480 |
+
help="Number of parallel worker processes (default: 8).",
|
| 481 |
+
)
|
| 482 |
+
args = parser.parse_args()
|
| 483 |
+
|
| 484 |
+
input_dir = os.path.abspath(args.input_dir)
|
| 485 |
+
output_dir = os.path.abspath(args.output_dir)
|
| 486 |
+
num_workers = args.workers
|
| 487 |
+
|
| 488 |
+
if not os.path.isdir(input_dir):
|
| 489 |
+
raise FileNotFoundError(f"Input directory not found: {input_dir}")
|
| 490 |
+
|
| 491 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 492 |
+
|
| 493 |
+
# Pre-scan for cross-shard expert groups
|
| 494 |
+
index_path = os.path.join(input_dir, "model.safetensors.index.json")
|
| 495 |
+
cross_shard_prefixes = set()
|
| 496 |
+
if os.path.exists(index_path):
|
| 497 |
+
cross_shard_prefixes = scan_cross_shard_experts(index_path)
|
| 498 |
+
if cross_shard_prefixes:
|
| 499 |
+
print(f"Detected {len(cross_shard_prefixes)} cross-shard expert group(s):")
|
| 500 |
+
for p in sorted(cross_shard_prefixes):
|
| 501 |
+
print(f" - {p}")
|
| 502 |
+
print()
|
| 503 |
+
|
| 504 |
+
# Get all safetensors files
|
| 505 |
+
shard_files = sorted(
|
| 506 |
+
f for f in os.listdir(input_dir) if f.endswith(".safetensors")
|
| 507 |
+
)
|
| 508 |
+
if not shard_files:
|
| 509 |
+
raise FileNotFoundError(f"No .safetensors files found in {input_dir}")
|
| 510 |
+
|
| 511 |
+
# Skip already-converted shards (for resumability)
|
| 512 |
+
# NOTE: if there are cross-shard experts, we cannot skip shards that
|
| 513 |
+
# contain deferred keys (they need post-processing). For simplicity,
|
| 514 |
+
# when cross-shard experts exist, we re-process all shards.
|
| 515 |
+
remaining = []
|
| 516 |
+
skipped = []
|
| 517 |
+
if cross_shard_prefixes:
|
| 518 |
+
# Re-process all shards when cross-shard experts exist
|
| 519 |
+
remaining = list(shard_files)
|
| 520 |
+
else:
|
| 521 |
+
for sf in shard_files:
|
| 522 |
+
out_path = os.path.join(output_dir, sf)
|
| 523 |
+
if os.path.exists(out_path) and os.path.getsize(out_path) > 0:
|
| 524 |
+
skipped.append(sf)
|
| 525 |
+
else:
|
| 526 |
+
remaining.append(sf)
|
| 527 |
+
|
| 528 |
+
num_shards = len(shard_files)
|
| 529 |
+
num_workers = min(num_workers, len(remaining)) if remaining else 1
|
| 530 |
+
|
| 531 |
+
print(f"=" * 60)
|
| 532 |
+
print(f"HYV3 Checkpoint Converter (inner -> outer, v2)")
|
| 533 |
+
print(f" Input : {input_dir}")
|
| 534 |
+
print(f" Output : {output_dir}")
|
| 535 |
+
print(f" Shards : {num_shards} total, {len(skipped)} already done, {len(remaining)} to process")
|
| 536 |
+
print(f" Workers: {num_workers}")
|
| 537 |
+
if cross_shard_prefixes:
|
| 538 |
+
print(f" Cross-shard experts: {len(cross_shard_prefixes)} group(s) (will post-process)")
|
| 539 |
+
print(f"=" * 60)
|
| 540 |
+
|
| 541 |
+
t_start = time.time()
|
| 542 |
+
|
| 543 |
+
# Build task list for remaining shards
|
| 544 |
+
tasks = [
|
| 545 |
+
(i, len(remaining), sf, input_dir, output_dir, cross_shard_prefixes)
|
| 546 |
+
for i, sf in enumerate(remaining)
|
| 547 |
+
]
|
| 548 |
+
|
| 549 |
+
# Process in parallel
|
| 550 |
+
results = []
|
| 551 |
+
if tasks:
|
| 552 |
+
with Pool(processes=num_workers, initializer=_pool_worker_init) as pool:
|
| 553 |
+
results = pool.map(_process_one_shard, tasks)
|
| 554 |
+
|
| 555 |
+
# Collect deferred keys info
|
| 556 |
+
all_deferred = {} # shard_file -> [deferred_keys]
|
| 557 |
+
for shard_file, key_list, shard_size, elapsed, deferred_keys in results:
|
| 558 |
+
if deferred_keys:
|
| 559 |
+
all_deferred[shard_file] = deferred_keys
|
| 560 |
+
|
| 561 |
+
# Post-process cross-shard expert groups (v2: shard-centric)
|
| 562 |
+
updated_shards = {}
|
| 563 |
+
if cross_shard_prefixes and all_deferred:
|
| 564 |
+
updated_shards = post_process_cross_shard_experts(
|
| 565 |
+
output_dir, cross_shard_prefixes, all_deferred
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
# Build weight_map and total_size
|
| 569 |
+
weight_map = OrderedDict()
|
| 570 |
+
total_size = 0
|
| 571 |
+
|
| 572 |
+
# For skipped shards, read their keys from the output files
|
| 573 |
+
for sf in skipped:
|
| 574 |
+
out_path = os.path.join(output_dir, sf)
|
| 575 |
+
with safe_open(out_path, framework="pt", device="cpu") as f:
|
| 576 |
+
keys = list(f.keys())
|
| 577 |
+
for key in keys:
|
| 578 |
+
weight_map[key] = sf
|
| 579 |
+
t = f.get_tensor(key)
|
| 580 |
+
total_size += t.numel() * t.element_size()
|
| 581 |
+
|
| 582 |
+
# Collect results from newly converted shards
|
| 583 |
+
for shard_file, key_list, shard_size, elapsed, deferred_keys in results:
|
| 584 |
+
if shard_file in updated_shards:
|
| 585 |
+
# This shard was updated by post-processing
|
| 586 |
+
updated_key_list, updated_size = updated_shards[shard_file]
|
| 587 |
+
for key in updated_key_list:
|
| 588 |
+
weight_map[key] = shard_file
|
| 589 |
+
total_size += updated_size
|
| 590 |
+
else:
|
| 591 |
+
for key in key_list:
|
| 592 |
+
weight_map[key] = shard_file
|
| 593 |
+
total_size += shard_size
|
| 594 |
+
|
| 595 |
+
# Build and save index
|
| 596 |
+
sorted_weight_map = OrderedDict(sorted(weight_map.items()))
|
| 597 |
+
index = {
|
| 598 |
+
"metadata": {"total_size": total_size},
|
| 599 |
+
"weight_map": sorted_weight_map,
|
| 600 |
+
}
|
| 601 |
+
index_path_out = os.path.join(output_dir, "model.safetensors.index.json")
|
| 602 |
+
with open(index_path_out, "w") as f:
|
| 603 |
+
json.dump(index, f, indent=2)
|
| 604 |
+
f.write("\n")
|
| 605 |
+
print(f"\nSaved {index_path_out}")
|
| 606 |
+
|
| 607 |
+
# Copy non-safetensors files (config, tokenizer, etc.)
|
| 608 |
+
skip_suffixes = {".safetensors"}
|
| 609 |
+
skip_names = {"model.safetensors.index.json"}
|
| 610 |
+
copied = []
|
| 611 |
+
for fname in os.listdir(input_dir):
|
| 612 |
+
if fname in skip_names:
|
| 613 |
+
continue
|
| 614 |
+
if any(fname.endswith(s) for s in skip_suffixes):
|
| 615 |
+
continue
|
| 616 |
+
src = os.path.join(input_dir, fname)
|
| 617 |
+
dst = os.path.join(output_dir, fname)
|
| 618 |
+
if os.path.isfile(src):
|
| 619 |
+
shutil.copy2(src, dst)
|
| 620 |
+
copied.append(fname)
|
| 621 |
+
elif os.path.isdir(src):
|
| 622 |
+
if os.path.exists(dst):
|
| 623 |
+
shutil.rmtree(dst)
|
| 624 |
+
shutil.copytree(src, dst)
|
| 625 |
+
copied.append(fname + "/")
|
| 626 |
+
|
| 627 |
+
if copied:
|
| 628 |
+
print(f"\nCopied files: {', '.join(copied)}")
|
| 629 |
+
|
| 630 |
+
t_total = time.time() - t_start
|
| 631 |
+
print(f"\n{'=' * 60}")
|
| 632 |
+
print(f"Conversion complete!")
|
| 633 |
+
print(f" Total keys : {len(weight_map)}")
|
| 634 |
+
print(f" Total size : {total_size / 1e9:.2f} GB")
|
| 635 |
+
print(f" Total time : {t_total:.1f}s ({t_total / 60:.1f} min)")
|
| 636 |
+
print(f" Output dir : {output_dir}")
|
| 637 |
+
print(f"{'=' * 60}")
|
| 638 |
+
|
| 639 |
+
if __name__ == "__main__":
|
| 640 |
+
_install_signal_handlers()
|
| 641 |
+
main()
|
train/train.py
ADDED
|
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Tencent Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 16 |
+
#
|
| 17 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 18 |
+
# and OPT implementations in this library. It has been modified from its
|
| 19 |
+
# original forms to accommodate minor architectural differences compared
|
| 20 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 21 |
+
#
|
| 22 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 23 |
+
# you may not use this file except in compliance with the License.
|
| 24 |
+
# You may obtain a copy of the License at
|
| 25 |
+
#
|
| 26 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 27 |
+
#
|
| 28 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 29 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 30 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 31 |
+
# See the License for the specific language governing permissions and
|
| 32 |
+
# limitations under the License.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
import os
|
| 36 |
+
import re
|
| 37 |
+
import sys
|
| 38 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 39 |
+
import json
|
| 40 |
+
import torch
|
| 41 |
+
import shutil
|
| 42 |
+
import logging
|
| 43 |
+
from dataclasses import dataclass, field
|
| 44 |
+
import deepspeed
|
| 45 |
+
from typing import Optional, Dict
|
| 46 |
+
|
| 47 |
+
import transformers
|
| 48 |
+
from torch.utils.data import Dataset
|
| 49 |
+
from transformers import Trainer, TrainerCallback
|
| 50 |
+
from peft import LoraConfig, get_peft_model, PeftModel
|
| 51 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
| 52 |
+
from transformers.modeling_utils import unwrap_model
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def print_args(args, name='arguments'):
|
| 56 |
+
"""Print arguments."""
|
| 57 |
+
if torch.distributed.get_rank() == 0:
|
| 58 |
+
print(f'------------------------ {name} ------------------------', flush=True)
|
| 59 |
+
str_list = []
|
| 60 |
+
for arg in vars(args):
|
| 61 |
+
dots = '.' * (48 - len(arg))
|
| 62 |
+
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
|
| 63 |
+
for arg in sorted(str_list, key=lambda x: x.lower()):
|
| 64 |
+
print(arg, flush=True)
|
| 65 |
+
print(f'-------------------- end of {name} ---------------------', flush=True)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dataclass
|
| 69 |
+
class ModelArguments:
|
| 70 |
+
use_flash_attn: bool = field(
|
| 71 |
+
default=False,
|
| 72 |
+
metadata={"help": "Enable FlashAttention-2 for faster training."}
|
| 73 |
+
)
|
| 74 |
+
use_lora: bool = field(default=False, metadata={"help": "Enable Lora for faster training."})
|
| 75 |
+
hidden_size: int = field(default=2048, metadata={"help": "The hidden size of the model."})
|
| 76 |
+
num_layers: int = field(default=24, metadata={"help": "The number of layers of the model."})
|
| 77 |
+
num_attention_heads: int = field(default=16, metadata={"help": "The number of attention heads of the model."})
|
| 78 |
+
intermediate_size: int = field(default=8192, metadata={"help": "The intermediate size of the model."})
|
| 79 |
+
max_position_embeddings: int = field(
|
| 80 |
+
default=2048,
|
| 81 |
+
metadata={"help": "The maximum sequence length that this model might ever be used with."}
|
| 82 |
+
)
|
| 83 |
+
vocab_size: int = field(default=50257, metadata={"help": "The vocabulary size of the model."})
|
| 84 |
+
type_vocab_size: int = field(default=1, metadata={"help": "The vocabulary size of the model."})
|
| 85 |
+
layer_norm_eps: float = field(
|
| 86 |
+
default=1e-5,
|
| 87 |
+
metadata={"help": "The epsilon used by the layer normalization layers of the model."}
|
| 88 |
+
)
|
| 89 |
+
moe_topk: int = field(default=4, metadata={"help": "The topk for MOE."})
|
| 90 |
+
num_experts: int = field(default=8, metadata={"help": "The number of experts for MOE."})
|
| 91 |
+
num_key_value_heads: int = field(default=16, metadata={"help": "The number of key-value heads in GQA."})
|
| 92 |
+
moe_intermediate_size: int = field(default=1536, metadata={"help": "The intermediate size of each MoE expert."})
|
| 93 |
+
use_mixed_mlp_moe: bool = field(
|
| 94 |
+
default=False,
|
| 95 |
+
metadata={"help": "Whether to use mixed MoE with shared expert."}
|
| 96 |
+
)
|
| 97 |
+
num_shared_expert: int = field(default=1, metadata={"help": "Number of shared experts."})
|
| 98 |
+
use_qk_norm: bool = field(default=False, metadata={"help": "Whether to use qk norm."})
|
| 99 |
+
moe_layer_num_skipped: int = field(default=1, metadata={"help": "Number of initial dense layers before MoE layers."})
|
| 100 |
+
tie_word_embeddings: bool = field(
|
| 101 |
+
default=True,
|
| 102 |
+
metadata={"help": "Whether to tie the word embeddings of the encoder and the decoder."}
|
| 103 |
+
)
|
| 104 |
+
lora_rank: int = field(default=64, metadata={"help": "The rank of lora."})
|
| 105 |
+
lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"})
|
| 106 |
+
lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"})
|
| 107 |
+
train_attention_params_only: bool = field(default=False, metadata={
|
| 108 |
+
"help": "Whether to train attention parameters only."}
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@dataclass
|
| 113 |
+
class DataArguments:
|
| 114 |
+
train_data_file: str = field(default=None, metadata={"help": "Path to the training data."})
|
| 115 |
+
max_seq_length: int = field(
|
| 116 |
+
default=2048,
|
| 117 |
+
metadata={"help": "The max sequence length of the model inputs after tokenization."}
|
| 118 |
+
)
|
| 119 |
+
complex_data: Optional[str] = field(default=None)
|
| 120 |
+
use_dummy_data: bool = field(default=False, metadata={"help": "Use dummy data."})
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@dataclass
|
| 124 |
+
class TrainingArguments(transformers.TrainingArguments):
|
| 125 |
+
cache_dir: Optional[str] = field(default=None)
|
| 126 |
+
optim: str = field(default="adamw_torch")
|
| 127 |
+
model_max_length: int = field(
|
| 128 |
+
default=2048,
|
| 129 |
+
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
|
| 130 |
+
)
|
| 131 |
+
tokenizer_name_or_path: Optional[str] = field(default=None)
|
| 132 |
+
model_name_or_path: Optional[str] = field(default=None)
|
| 133 |
+
min_lr: float = field(
|
| 134 |
+
default=0.01,
|
| 135 |
+
metadata={"help": "The final learning rate at the end of the decay will be learning_rate * min_lr"}
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
IGNORE_INDEX = -100
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class DummyDataset(Dataset):
|
| 143 |
+
def __init__(self, tokenizer, max_seq_length=512, length=1000):
|
| 144 |
+
self.tokenizer = tokenizer
|
| 145 |
+
self.max_seq_length = max_seq_length
|
| 146 |
+
self.length = length
|
| 147 |
+
|
| 148 |
+
def __len__(self):
|
| 149 |
+
return self.length
|
| 150 |
+
|
| 151 |
+
def __getitem__(self, index):
|
| 152 |
+
tokens = torch.randint(0, self.tokenizer.vocab_size, (self.max_seq_length, ))
|
| 153 |
+
return {'input_ids': tokens, 'labels': tokens}
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class SFTDataset(Dataset):
|
| 157 |
+
def __init__(self, data_file, tokenizer, max_seq_length = 2048, prompt_format = 'mplus'):
|
| 158 |
+
self.tokenizer = tokenizer
|
| 159 |
+
self.prompt_format = prompt_format
|
| 160 |
+
self.max_seq_length = max_seq_length
|
| 161 |
+
|
| 162 |
+
self.data_list = self.load_data(data_file)
|
| 163 |
+
|
| 164 |
+
def __len__(self):
|
| 165 |
+
return len(self.data_list)
|
| 166 |
+
|
| 167 |
+
def load_data(self, data_file):
|
| 168 |
+
logging.info('Loading data: {}'.format(data_file))
|
| 169 |
+
with open(data_file, 'r', encoding='utf8') as f:
|
| 170 |
+
data_list = f.readlines()
|
| 171 |
+
logging.info("there are {} data in dataset".format(len(data_list)))
|
| 172 |
+
return data_list
|
| 173 |
+
|
| 174 |
+
def encode_data(self, data_dict):
|
| 175 |
+
model_inputs = {}
|
| 176 |
+
reasoning_effort = data_dict.get('reasoning_effort', None)
|
| 177 |
+
if reasoning_effort is None:
|
| 178 |
+
reasoning_effort = 'no_think'
|
| 179 |
+
template_output = self.tokenizer.apply_chat_template(data_dict['messages'], tokenize=True, return_dict=False, is_training=True, reasoning_effort=reasoning_effort)
|
| 180 |
+
if isinstance(template_output, list) and len(template_output) > 0 and isinstance(template_output[0], list):
|
| 181 |
+
template_output = template_output[0]
|
| 182 |
+
message_tokens = torch.tensor(template_output, dtype=torch.long)
|
| 183 |
+
|
| 184 |
+
# Use new HunYuan tokenizer special tokens
|
| 185 |
+
assistant_token_id = self.tokenizer.convert_tokens_to_ids('<|hy_Assistant|>')
|
| 186 |
+
eos_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.eos_token)
|
| 187 |
+
pad_token_id = self.tokenizer.pad_token_id
|
| 188 |
+
|
| 189 |
+
# Find assistant reply boundaries: starts at <|hy_Assistant|>, ends at eos_token
|
| 190 |
+
loss_token_begins = (message_tokens == assistant_token_id).nonzero(as_tuple=True)[0].tolist()
|
| 191 |
+
loss_token_ends = (message_tokens == eos_token_id).nonzero(as_tuple=True)[0].tolist()
|
| 192 |
+
message_labels = torch.tensor([IGNORE_INDEX] * message_tokens.shape[0])
|
| 193 |
+
for begin_idx, end_idx in zip(loss_token_begins, loss_token_ends):
|
| 194 |
+
# Compute loss from the token after <|hy_Assistant|> to eos_token (inclusive)
|
| 195 |
+
message_labels[begin_idx + 1:end_idx + 1] = message_tokens[begin_idx + 1:end_idx + 1]
|
| 196 |
+
input_ids = message_tokens.to(torch.long)
|
| 197 |
+
labels = message_labels.to(torch.long)
|
| 198 |
+
|
| 199 |
+
input_ids = input_ids[:self.max_seq_length]
|
| 200 |
+
labels = labels[:self.max_seq_length]
|
| 201 |
+
attention_mask = [1 if val != pad_token_id else 0 for val in input_ids]
|
| 202 |
+
model_inputs["input_ids"] = input_ids
|
| 203 |
+
model_inputs["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
|
| 204 |
+
model_inputs["labels"] = labels
|
| 205 |
+
|
| 206 |
+
return model_inputs
|
| 207 |
+
|
| 208 |
+
def __getitem__(self, index):
|
| 209 |
+
data = self.data_list[index]
|
| 210 |
+
data = json.loads(data)
|
| 211 |
+
model_inputs = self.encode_data(data)
|
| 212 |
+
|
| 213 |
+
return model_inputs
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@dataclass
|
| 217 |
+
class DataCollatorForSupervisedDataset(object):
|
| 218 |
+
"""Collate examples for supervised fine-tuning."""
|
| 219 |
+
|
| 220 |
+
tokenizer: transformers.PreTrainedTokenizer
|
| 221 |
+
|
| 222 |
+
def __call__(self, instances):
|
| 223 |
+
input_ids = [instance['input_ids'] for instance in instances]
|
| 224 |
+
labels = [instance['labels'] for instance in instances]
|
| 225 |
+
pad_token_id = self.tokenizer.pad_token_id
|
| 226 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
|
| 227 |
+
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
| 228 |
+
return dict(
|
| 229 |
+
input_ids=input_ids,
|
| 230 |
+
labels=labels,
|
| 231 |
+
attention_mask=input_ids.ne(pad_token_id),
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def make_supervised_data_module(tokenizer, data_args) -> Dict:
|
| 236 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
| 237 |
+
if data_args.use_dummy_data:
|
| 238 |
+
train_dataset = DummyDataset(tokenizer, data_args.max_seq_length)
|
| 239 |
+
else:
|
| 240 |
+
train_dataset = SFTDataset(
|
| 241 |
+
tokenizer=tokenizer,
|
| 242 |
+
data_file=data_args.train_data_file,
|
| 243 |
+
max_seq_length=data_args.max_seq_length
|
| 244 |
+
)
|
| 245 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
| 246 |
+
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# for full model training, change the config.json, copy the model and configuration to support Auto load
|
| 250 |
+
class CustomSaveCallback(TrainerCallback):
|
| 251 |
+
def on_save(self, args, state, control, **kwargs):
|
| 252 |
+
if torch.distributed.get_rank() == 0:
|
| 253 |
+
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
| 254 |
+
|
| 255 |
+
# Copy tokenizer files to checkpoint directory
|
| 256 |
+
tokenizer_files = [
|
| 257 |
+
'generation_config.json',
|
| 258 |
+
'hy.tiktoken',
|
| 259 |
+
'tokenizer_config.json',
|
| 260 |
+
'tokenization_hy.py',
|
| 261 |
+
'tokenizer.json',
|
| 262 |
+
'special_tokens_map.json',
|
| 263 |
+
'chat_template.jinja',
|
| 264 |
+
]
|
| 265 |
+
for fname in tokenizer_files:
|
| 266 |
+
src = os.path.join(args.tokenizer_name_or_path, fname)
|
| 267 |
+
if os.path.isfile(src):
|
| 268 |
+
shutil.copy(src, os.path.join(output_dir, fname))
|
| 269 |
+
|
| 270 |
+
return control
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def train():
|
| 274 |
+
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
| 275 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 276 |
+
print_args(model_args, 'model arguments')
|
| 277 |
+
print_args(data_args, 'data arguments')
|
| 278 |
+
print_args(training_args, 'training arguments')
|
| 279 |
+
|
| 280 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 281 |
+
training_args.tokenizer_name_or_path,
|
| 282 |
+
trust_remote_code = True
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
init_kwargs = {}
|
| 286 |
+
if model_args.use_flash_attn:
|
| 287 |
+
init_kwargs["attn_implementation"] = "flash_attention_2"
|
| 288 |
+
# Workaround: transformers >= 5.x uses importlib.metadata.packages_distributions()
|
| 289 |
+
# to verify flash-attn package name, which fails when the package is installed under
|
| 290 |
+
# a custom distribution name (e.g. ptm-flash-attn). Patch the check to skip it.
|
| 291 |
+
try:
|
| 292 |
+
from transformers.modeling_flash_attention_utils import FLASH_ATTENTION_COMPATIBILITY_MATRIX
|
| 293 |
+
_orig_pkg_check = FLASH_ATTENTION_COMPATIBILITY_MATRIX[2]["pkg_availability_check"]
|
| 294 |
+
FLASH_ATTENTION_COMPATIBILITY_MATRIX[2]["pkg_availability_check"] = lambda *a, **kw: True
|
| 295 |
+
print("[Patch] Bypassed flash_attn package distribution name check for FA2.")
|
| 296 |
+
except Exception as e:
|
| 297 |
+
print(f"[Patch] Could not patch FA2 pkg check (non-fatal): {e}")
|
| 298 |
+
if training_args.bf16:
|
| 299 |
+
init_kwargs["dtype"] = torch.bfloat16
|
| 300 |
+
elif training_args.fp16:
|
| 301 |
+
init_kwargs["dtype"] = torch.float16
|
| 302 |
+
|
| 303 |
+
# Check if model weights exist (not just the directory)
|
| 304 |
+
_has_weights = (
|
| 305 |
+
training_args.model_name_or_path is not None
|
| 306 |
+
and os.path.isdir(training_args.model_name_or_path)
|
| 307 |
+
and any(
|
| 308 |
+
os.path.isfile(os.path.join(training_args.model_name_or_path, f))
|
| 309 |
+
for f in ("model.safetensors", "pytorch_model.bin", "model.safetensors.index.json", "pytorch_model.bin.index.json")
|
| 310 |
+
)
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# -----------------------------------------------------------------------
|
| 314 |
+
# Fix: Rename checkpoint keys so that old-style weight names (e.g.
|
| 315 |
+
# self_attn.q_norm) are mapped to the current model attribute names
|
| 316 |
+
# (e.g. self_attn.query_layernorm). The model's
|
| 317 |
+
# _fix_state_dict_key_on_load hook is NOT invoked on the DeepSpeed
|
| 318 |
+
# ZeRO-3 loading path, so we monkey-patch the ZeRO-3 loader instead.
|
| 319 |
+
# -----------------------------------------------------------------------
|
| 320 |
+
_CKPT_KEY_RENAMES = [
|
| 321 |
+
("mlp.gate.wg.", "mlp.router.gate."),
|
| 322 |
+
]
|
| 323 |
+
|
| 324 |
+
from transformers.integrations.deepspeed import (
|
| 325 |
+
_load_state_dict_into_zero3_model as _orig_load_zero3,
|
| 326 |
+
)
|
| 327 |
+
import transformers.integrations.deepspeed as _ds_mod
|
| 328 |
+
import transformers.modeling_utils as _mu_mod
|
| 329 |
+
|
| 330 |
+
def _patched_load_zero3(model_to_load, state_dict, load_config=None):
|
| 331 |
+
new_sd = {}
|
| 332 |
+
for k, v in state_dict.items():
|
| 333 |
+
new_k = k
|
| 334 |
+
for old_sub, new_sub in _CKPT_KEY_RENAMES:
|
| 335 |
+
if old_sub in new_k:
|
| 336 |
+
new_k = new_k.replace(old_sub, new_sub)
|
| 337 |
+
break
|
| 338 |
+
new_sd[new_k] = v
|
| 339 |
+
|
| 340 |
+
# Call original ZeRO-3 loader for parameters
|
| 341 |
+
result = _orig_load_zero3(model_to_load, new_sd, load_config)
|
| 342 |
+
|
| 343 |
+
# -------------------------------------------------------------------
|
| 344 |
+
# Patch: Manually load buffers (e.g. e_score_correction_bias).
|
| 345 |
+
# ZeRO-3's loader only handles named_parameters, not named_buffers.
|
| 346 |
+
# -------------------------------------------------------------------
|
| 347 |
+
buffers_loaded = 0
|
| 348 |
+
for name, buf in model_to_load.named_buffers():
|
| 349 |
+
if name in new_sd:
|
| 350 |
+
src_tensor = new_sd[name]
|
| 351 |
+
if isinstance(src_tensor, torch.Tensor):
|
| 352 |
+
buf.data.copy_(src_tensor.to(buf.dtype))
|
| 353 |
+
buffers_loaded += 1
|
| 354 |
+
# Remove from unexpected keys if tracked
|
| 355 |
+
if isinstance(result, tuple) and len(result) >= 2:
|
| 356 |
+
if isinstance(result[1], set):
|
| 357 |
+
result[1].discard(name)
|
| 358 |
+
if buffers_loaded > 0:
|
| 359 |
+
print(f"[HYV3 Patch] Manually loaded {buffers_loaded} buffers "
|
| 360 |
+
f"(e.g. e_score_correction_bias) into model.")
|
| 361 |
+
|
| 362 |
+
return result
|
| 363 |
+
|
| 364 |
+
_ds_mod._load_state_dict_into_zero3_model = _patched_load_zero3
|
| 365 |
+
_mu_mod._load_state_dict_into_zero3_model = _patched_load_zero3
|
| 366 |
+
# -----------------------------------------------------------------------
|
| 367 |
+
|
| 368 |
+
# -------------------------------------------------------------------
|
| 369 |
+
# Patch: Save-time reverse key rename + 3D -> per-expert unfuse.
|
| 370 |
+
#
|
| 371 |
+
# When saving checkpoints, the model state_dict uses 3D fused experts
|
| 372 |
+
# and new naming. We reverse both for old checkpoint compatibility:
|
| 373 |
+
# - mlp.gate. -> mlp.router.gate.
|
| 374 |
+
# - mlp.e_score_correction_bias -> mlp.expert_bias
|
| 375 |
+
# - mlp.shared_experts. -> mlp.shared_mlp.
|
| 376 |
+
# - experts.gate_up_proj -> experts.{N}.gate_proj.weight + up_proj
|
| 377 |
+
# - experts.down_proj -> experts.{N}.down_proj.weight
|
| 378 |
+
# -------------------------------------------------------------------
|
| 379 |
+
_SAVE_KEY_RENAMES = [
|
| 380 |
+
("mlp.gate.", "mlp.router.gate."),
|
| 381 |
+
("mlp.e_score_correction_bias", "mlp.expert_bias"),
|
| 382 |
+
("mlp.shared_experts.", "mlp.shared_mlp."),
|
| 383 |
+
]
|
| 384 |
+
_FUSED_EXPERT_KEY_RE = re.compile(
|
| 385 |
+
r"^(.*\.mlp\.experts\.)(gate_up_proj|down_proj)$"
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
def _apply_save_reverse_rename_patch():
|
| 389 |
+
try:
|
| 390 |
+
from transformers.models.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
|
| 391 |
+
except ImportError:
|
| 392 |
+
try:
|
| 393 |
+
from transformers.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
|
| 394 |
+
except ImportError:
|
| 395 |
+
print("[HYV3 Patch] Could not import HYV3ForCausalLM; "
|
| 396 |
+
"save reverse rename patch NOT applied.")
|
| 397 |
+
return
|
| 398 |
+
|
| 399 |
+
_orig_save_pretrained = HYV3ForCausalLM.save_pretrained
|
| 400 |
+
|
| 401 |
+
def _patched_save_pretrained(self, *args, **kwargs):
|
| 402 |
+
state_dict = kwargs.get("state_dict", None)
|
| 403 |
+
if state_dict is not None:
|
| 404 |
+
reversed_sd = {}
|
| 405 |
+
for k, v in state_dict.items():
|
| 406 |
+
new_k = k
|
| 407 |
+
# Apply simple key renames
|
| 408 |
+
for new_sub, old_sub in _SAVE_KEY_RENAMES:
|
| 409 |
+
if new_sub in new_k:
|
| 410 |
+
new_k = new_k.replace(new_sub, old_sub)
|
| 411 |
+
break
|
| 412 |
+
|
| 413 |
+
# Check if this is a fused 3D expert key
|
| 414 |
+
m = _FUSED_EXPERT_KEY_RE.match(new_k)
|
| 415 |
+
if m:
|
| 416 |
+
prefix = m.group(1) # e.g. "model.layers.1.mlp.experts."
|
| 417 |
+
proj_type = m.group(2) # "gate_up_proj" or "down_proj"
|
| 418 |
+
|
| 419 |
+
if proj_type == "gate_up_proj":
|
| 420 |
+
# v shape: [num_experts, 2*intermediate, hidden]
|
| 421 |
+
num_experts = v.shape[0]
|
| 422 |
+
intermediate = v.shape[1] // 2
|
| 423 |
+
for i in range(num_experts):
|
| 424 |
+
gate = v[i, :intermediate, :]
|
| 425 |
+
up = v[i, intermediate:, :]
|
| 426 |
+
reversed_sd[f"{prefix}{i}.gate_proj.weight"] = gate
|
| 427 |
+
reversed_sd[f"{prefix}{i}.up_proj.weight"] = up
|
| 428 |
+
elif proj_type == "down_proj":
|
| 429 |
+
# v shape: [num_experts, hidden, intermediate]
|
| 430 |
+
num_experts = v.shape[0]
|
| 431 |
+
for i in range(num_experts):
|
| 432 |
+
reversed_sd[f"{prefix}{i}.down_proj.weight"] = v[i]
|
| 433 |
+
else:
|
| 434 |
+
reversed_sd[new_k] = v
|
| 435 |
+
|
| 436 |
+
kwargs["state_dict"] = reversed_sd
|
| 437 |
+
print(f"[HYV3 Patch] Reverse-renamed and unfused "
|
| 438 |
+
f"{len(state_dict)} -> {len(reversed_sd)} "
|
| 439 |
+
f"state_dict keys for old checkpoint compatibility.")
|
| 440 |
+
return _orig_save_pretrained(self, *args, **kwargs)
|
| 441 |
+
|
| 442 |
+
HYV3ForCausalLM.save_pretrained = _patched_save_pretrained
|
| 443 |
+
print("[HYV3 Patch] Applied: save-time reverse key rename + "
|
| 444 |
+
"3D -> per-expert unfuse for old ckpt compatibility.")
|
| 445 |
+
|
| 446 |
+
_apply_save_reverse_rename_patch()
|
| 447 |
+
# -------------------------------------------------------------------
|
| 448 |
+
|
| 449 |
+
if _has_weights:
|
| 450 |
+
print(f"Initializing model from local file: {training_args.model_name_or_path}")
|
| 451 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
| 452 |
+
training_args.model_name_or_path,
|
| 453 |
+
trust_remote_code=True,
|
| 454 |
+
**init_kwargs
|
| 455 |
+
)
|
| 456 |
+
else:
|
| 457 |
+
from transformers import HYV3Config
|
| 458 |
+
from transformers import HYV3ForCausalLM
|
| 459 |
+
print(f"Model weights not found at: {training_args.model_name_or_path}, "
|
| 460 |
+
f"using random initialized HYV3 model instead.")
|
| 461 |
+
# Use len(tokenizer) to include added special tokens; tokenizer.vocab_size
|
| 462 |
+
# may only return the base vocabulary size and miss special tokens whose
|
| 463 |
+
# IDs exceed that range, causing index-out-of-bounds in the embedding layer.
|
| 464 |
+
config = HYV3Config(
|
| 465 |
+
vocab_size=len(tokenizer),
|
| 466 |
+
hidden_size=model_args.hidden_size,
|
| 467 |
+
intermediate_size=model_args.intermediate_size,
|
| 468 |
+
max_position_embeddings=training_args.model_max_length,
|
| 469 |
+
moe_topk=model_args.moe_topk,
|
| 470 |
+
num_experts=model_args.num_experts,
|
| 471 |
+
num_attention_heads=model_args.num_attention_heads,
|
| 472 |
+
num_key_value_heads=model_args.num_key_value_heads,
|
| 473 |
+
num_hidden_layers=model_args.num_layers,
|
| 474 |
+
moe_intermediate_size=model_args.moe_intermediate_size,
|
| 475 |
+
use_mixed_mlp_moe=model_args.use_mixed_mlp_moe,
|
| 476 |
+
num_shared_expert=model_args.num_shared_expert,
|
| 477 |
+
use_qk_norm=model_args.use_qk_norm,
|
| 478 |
+
moe_layer_num_skipped=model_args.moe_layer_num_skipped,
|
| 479 |
+
tie_word_embeddings=model_args.tie_word_embeddings,
|
| 480 |
+
)
|
| 481 |
+
with deepspeed.zero.Init(dtype=init_kwargs.get("torch_dtype", torch.bfloat16), config_dict_or_path=training_args.deepspeed):
|
| 482 |
+
model = HYV3ForCausalLM(config)
|
| 483 |
+
|
| 484 |
+
if model_args.train_attention_params_only:
|
| 485 |
+
for name, param in model.named_parameters():
|
| 486 |
+
if 'self_attn' not in name:
|
| 487 |
+
param.requires_grad = False
|
| 488 |
+
|
| 489 |
+
if model_args.use_lora:
|
| 490 |
+
# define Lora configuration
|
| 491 |
+
lora_config = LoraConfig(
|
| 492 |
+
r=model_args.lora_rank,
|
| 493 |
+
lora_alpha=model_args.lora_alpha,
|
| 494 |
+
lora_dropout=model_args.lora_dropout,
|
| 495 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 496 |
+
bias="none",
|
| 497 |
+
task_type="CAUSAL_LM",
|
| 498 |
+
)
|
| 499 |
+
model = get_peft_model(model, lora_config)
|
| 500 |
+
|
| 501 |
+
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
|
| 502 |
+
# Tell Trainer not to attempt DataParallel
|
| 503 |
+
model.is_parallelizable = True
|
| 504 |
+
model.model_parallel = True
|
| 505 |
+
|
| 506 |
+
training_args.lr_scheduler_kwargs = {
|
| 507 |
+
'min_lr_rate': training_args.min_lr / training_args.learning_rate,
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
# -----------------------------------------------------------------------
|
| 511 |
+
# Fix: DeepSpeed ZeRO-3 + gradient checkpointing compatibility.
|
| 512 |
+
#
|
| 513 |
+
# PyTorch's torch.utils.checkpoint with use_reentrant=False (the default
|
| 514 |
+
# in transformers) performs strict metadata checks on recomputed tensors
|
| 515 |
+
# during backward. Under ZeRO-3, parameters are all-gathered during the
|
| 516 |
+
# first forward pass (shape=[full_size]) but may be partitioned back
|
| 517 |
+
# (shape=[0]) when the checkpoint recomputes, causing a CheckpointError.
|
| 518 |
+
#
|
| 519 |
+
# Setting use_reentrant=True avoids this strict metadata check.
|
| 520 |
+
# -----------------------------------------------------------------------
|
| 521 |
+
if training_args.gradient_checkpointing and training_args.deepspeed:
|
| 522 |
+
training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
|
| 523 |
+
|
| 524 |
+
trainer = Trainer(
|
| 525 |
+
model=model,
|
| 526 |
+
processing_class=tokenizer,
|
| 527 |
+
args=training_args,
|
| 528 |
+
callbacks=[CustomSaveCallback],
|
| 529 |
+
**data_module
|
| 530 |
+
)
|
| 531 |
+
model.config.use_cache = False
|
| 532 |
+
|
| 533 |
+
# -----------------------------------------------------------------------
|
| 534 |
+
# Monkey-patch: fix dtype mismatch in DeepSpeed ZeRO-3 linear wrapper.
|
| 535 |
+
#
|
| 536 |
+
# By this point the DeepSpeed engine has been initialised by the Trainer
|
| 537 |
+
# and torch.nn.functional.linear has been replaced with
|
| 538 |
+
# zero3_linear_wrap. That wrapper does NOT auto-align input/weight
|
| 539 |
+
# dtypes before the matmul, causing "expected mat1 and mat2 to have the
|
| 540 |
+
# same dtype" errors in mixed-precision paths (MoE router gate in fp32
|
| 541 |
+
# with bf16 weights, expert FFN receiving fp32 routing-weighted input
|
| 542 |
+
# with bf16 weights, etc.).
|
| 543 |
+
#
|
| 544 |
+
# We wrap F.linear HERE (after DeepSpeed init) so that:
|
| 545 |
+
# 1. We are sure to capture the already-replaced function.
|
| 546 |
+
# 2. The dtype cast happens *outside* the autograd.Function, so
|
| 547 |
+
# gradient-checkpointing recompute sees identical tensor metadata.
|
| 548 |
+
# -----------------------------------------------------------------------
|
| 549 |
+
import torch.nn.functional as _F
|
| 550 |
+
_orig_F_linear = _F.linear
|
| 551 |
+
|
| 552 |
+
def _dtype_safe_linear(input, weight, bias=None):
|
| 553 |
+
if input.dtype != weight.dtype:
|
| 554 |
+
input = input.to(weight.dtype)
|
| 555 |
+
return _orig_F_linear(input, weight, bias)
|
| 556 |
+
|
| 557 |
+
_F.linear = _dtype_safe_linear
|
| 558 |
+
# -----------------------------------------------------------------------
|
| 559 |
+
|
| 560 |
+
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
if __name__ == "__main__":
|
| 564 |
+
train()
|
train/train.sh
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
NET_TYPE="high"
|
| 4 |
+
export NCCL_DEBUG=WARN
|
| 5 |
+
export NCCL_P2P_LEVEL=NVL
|
| 6 |
+
export NCCL_IB_TIMEOUT=24
|
| 7 |
+
export NCCL_NVLS_ENABLE=0
|
| 8 |
+
export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
|
| 9 |
+
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
| 10 |
+
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
|
| 11 |
+
if [[ "${NET_TYPE}" = "low" ]]; then
|
| 12 |
+
export NCCL_SOCKET_IFNAME=eth1
|
| 13 |
+
export NCCL_IB_GID_INDEX=3
|
| 14 |
+
export NCCL_IB_HCA=mlx5_2:1
|
| 15 |
+
export NCCL_IB_SL=3
|
| 16 |
+
export NCCL_CHECK_DISABLE=1
|
| 17 |
+
export NCCL_P2P_DISABLE=0
|
| 18 |
+
export NCCL_LL_THRESHOLD=16384
|
| 19 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 20 |
+
else
|
| 21 |
+
export NCCL_IB_GID_INDEX=3
|
| 22 |
+
export NCCL_IB_SL=3
|
| 23 |
+
export NCCL_CHECK_DISABLE=1
|
| 24 |
+
export NCCL_P2P_DISABLE=0
|
| 25 |
+
export NCCL_IB_DISABLE=0
|
| 26 |
+
export NCCL_LL_THRESHOLD=16384
|
| 27 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 28 |
+
export NCCL_SOCKET_IFNAME=bond1
|
| 29 |
+
export UCX_NET_DEVICES=bond1
|
| 30 |
+
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
|
| 31 |
+
export NCCL_COLLNET_ENABLE=0
|
| 32 |
+
export SHARP_COLL_ENABLE_SAT=0
|
| 33 |
+
export NCCL_NET_GDR_LEVEL=2
|
| 34 |
+
export NCCL_IB_QPS_PER_CONNECTION=4
|
| 35 |
+
export NCCL_IB_TC=160
|
| 36 |
+
export NCCL_PXN_DISABLE=1
|
| 37 |
+
fi
|
| 38 |
+
|
| 39 |
+
export HOST_GPU_NUM=8
|
| 40 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 41 |
+
IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 42 |
+
|
| 43 |
+
IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
|
| 44 |
+
export NODES=${#IP_ARRAY[@]}
|
| 45 |
+
export LOCAL_IP=${IP_ARRAY[0]}
|
| 46 |
+
NODE_IP_LIST=""
|
| 47 |
+
for ip in "${IP_ARRAY[@]}"; do
|
| 48 |
+
if [ -n "$NODE_IP_LIST" ]; then
|
| 49 |
+
NODE_IP_LIST="${NODE_IP_LIST},"
|
| 50 |
+
fi
|
| 51 |
+
NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
|
| 52 |
+
done
|
| 53 |
+
export NODE_IP_LIST
|
| 54 |
+
export NODE_NUM=$((${NODES} * ${HOST_GPU_NUM}))
|
| 55 |
+
|
| 56 |
+
model_path=path_to_model_weight
|
| 57 |
+
tokenizer_path=../models
|
| 58 |
+
train_data_file=example_data.jsonl
|
| 59 |
+
|
| 60 |
+
# ds_config_file=ds_zero2_no_offload.json
|
| 61 |
+
# ds_config_file=ds_zero3_no_offload.json
|
| 62 |
+
ds_config_file=ds_zero3_offload_no_auto.json
|
| 63 |
+
|
| 64 |
+
output_path=/root/hf_train_output_full
|
| 65 |
+
|
| 66 |
+
mkdir -p ${output_path}
|
| 67 |
+
|
| 68 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
| 69 |
+
log_file=${output_path}/"log_${current_time}.txt"
|
| 70 |
+
|
| 71 |
+
echo $NODE_IP_LIST > env.txt 2>&1
|
| 72 |
+
sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
|
| 73 |
+
sed "s/:.//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
|
| 74 |
+
export CHIEF_IP=$LOCAL_IP
|
| 75 |
+
|
| 76 |
+
if [ ${NODES} -gt 1 ]; then
|
| 77 |
+
HOST_PATH=hostfile
|
| 78 |
+
DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
|
| 79 |
+
else
|
| 80 |
+
DS_ARGS=""
|
| 81 |
+
fi
|
| 82 |
+
|
| 83 |
+
echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
|
| 84 |
+
|
| 85 |
+
deepspeed ${DS_ARGS} \
|
| 86 |
+
train.py \
|
| 87 |
+
--do_train \
|
| 88 |
+
--model_name_or_path ${model_path} \
|
| 89 |
+
--tokenizer_name_or_path ${tokenizer_path} \
|
| 90 |
+
--train_data_file ${train_data_file} \
|
| 91 |
+
--deepspeed ${ds_config_file} \
|
| 92 |
+
--output_dir ${output_path} \
|
| 93 |
+
--per_device_train_batch_size 1 \
|
| 94 |
+
--gradient_accumulation_steps 1 \
|
| 95 |
+
--gradient_checkpointing \
|
| 96 |
+
--lr_scheduler_type cosine_with_min_lr \
|
| 97 |
+
--logging_steps 1 \
|
| 98 |
+
--max_steps 200 \
|
| 99 |
+
--save_steps 100 \
|
| 100 |
+
--learning_rate 1e-5 \
|
| 101 |
+
--min_lr 1e-6 \
|
| 102 |
+
--warmup_ratio 0.01 \
|
| 103 |
+
--save_strategy steps \
|
| 104 |
+
--bf16 \
|
| 105 |
+
--hidden_size 4096 \
|
| 106 |
+
--intermediate_size 13312 \
|
| 107 |
+
--model_max_length 4096 \
|
| 108 |
+
--max_seq_length 4096 \
|
| 109 |
+
--moe_topk 8 \
|
| 110 |
+
--num_experts 192 \
|
| 111 |
+
--moe_intermediate_size 1536 \
|
| 112 |
+
--moe_layer_num_skipped 1 \
|
| 113 |
+
--num_attention_heads 64 \
|
| 114 |
+
--num_key_value_heads 8 \
|
| 115 |
+
--num_layers 80 \
|
| 116 |
+
--use_mixed_mlp_moe \
|
| 117 |
+
--num_shared_expert 1 \
|
| 118 |
+
--use_qk_norm | tee ${log_file}
|
train/train_lora.sh
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
NET_TYPE="high"
|
| 4 |
+
export NCCL_DEBUG=WARN
|
| 5 |
+
export NCCL_P2P_LEVEL=NVL
|
| 6 |
+
export NCCL_IB_TIMEOUT=24
|
| 7 |
+
export NCCL_NVLS_ENABLE=0
|
| 8 |
+
export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
|
| 9 |
+
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
| 10 |
+
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
|
| 11 |
+
if [[ "${NET_TYPE}" = "low" ]]; then
|
| 12 |
+
export NCCL_SOCKET_IFNAME=eth1
|
| 13 |
+
export NCCL_IB_GID_INDEX=3
|
| 14 |
+
export NCCL_IB_HCA=mlx5_2:1
|
| 15 |
+
export NCCL_IB_SL=3
|
| 16 |
+
export NCCL_CHECK_DISABLE=1
|
| 17 |
+
export NCCL_P2P_DISABLE=0
|
| 18 |
+
export NCCL_LL_THRESHOLD=16384
|
| 19 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 20 |
+
else
|
| 21 |
+
export NCCL_IB_GID_INDEX=3
|
| 22 |
+
export NCCL_IB_SL=3
|
| 23 |
+
export NCCL_CHECK_DISABLE=1
|
| 24 |
+
export NCCL_P2P_DISABLE=0
|
| 25 |
+
export NCCL_IB_DISABLE=0
|
| 26 |
+
export NCCL_LL_THRESHOLD=16384
|
| 27 |
+
export NCCL_IB_CUDA_SUPPORT=1
|
| 28 |
+
export NCCL_SOCKET_IFNAME=bond1
|
| 29 |
+
export UCX_NET_DEVICES=bond1
|
| 30 |
+
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
|
| 31 |
+
export NCCL_COLLNET_ENABLE=0
|
| 32 |
+
export SHARP_COLL_ENABLE_SAT=0
|
| 33 |
+
export NCCL_NET_GDR_LEVEL=2
|
| 34 |
+
export NCCL_IB_QPS_PER_CONNECTION=4
|
| 35 |
+
export NCCL_IB_TC=160
|
| 36 |
+
export NCCL_PXN_DISABLE=1
|
| 37 |
+
fi
|
| 38 |
+
|
| 39 |
+
export HOST_GPU_NUM=8
|
| 40 |
+
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
|
| 41 |
+
IP_LIST=${IP_LIST:-"127.0.0.1"}
|
| 42 |
+
|
| 43 |
+
IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
|
| 44 |
+
export NODES=${#IP_ARRAY[@]}
|
| 45 |
+
export LOCAL_IP=${IP_ARRAY[0]}
|
| 46 |
+
NODE_IP_LIST=""
|
| 47 |
+
for ip in "${IP_ARRAY[@]}"; do
|
| 48 |
+
if [ -n "$NODE_IP_LIST" ]; then
|
| 49 |
+
NODE_IP_LIST="${NODE_IP_LIST},"
|
| 50 |
+
fi
|
| 51 |
+
NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
|
| 52 |
+
done
|
| 53 |
+
export NODE_IP_LIST
|
| 54 |
+
export NODE_NUM=$((${NODES} * ${HOST_GPU_NUM}))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
model_path=path_to_model_weight
|
| 58 |
+
tokenizer_path=../models
|
| 59 |
+
train_data_file=example_data.jsonl
|
| 60 |
+
|
| 61 |
+
# ds_config_file=ds_zero2_no_offload.json
|
| 62 |
+
# ds_config_file=ds_zero3_no_offload.json
|
| 63 |
+
ds_config_file=ds_zero3_offload.json
|
| 64 |
+
|
| 65 |
+
output_path=/root/hf_train_output
|
| 66 |
+
|
| 67 |
+
mkdir -p ${output_path}
|
| 68 |
+
|
| 69 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
| 70 |
+
log_file=${output_path}/"log_${current_time}.txt"
|
| 71 |
+
|
| 72 |
+
echo $NODE_IP_LIST > env.txt 2>&1
|
| 73 |
+
sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
|
| 74 |
+
sed "s/:.//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
|
| 75 |
+
export CHIEF_IP=$LOCAL_IP
|
| 76 |
+
|
| 77 |
+
if [ ${NODES} -gt 1 ]; then
|
| 78 |
+
HOST_PATH=hostfile
|
| 79 |
+
DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
|
| 80 |
+
else
|
| 81 |
+
DS_ARGS=""
|
| 82 |
+
fi
|
| 83 |
+
|
| 84 |
+
echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
|
| 85 |
+
|
| 86 |
+
deepspeed ${DS_ARGS} \
|
| 87 |
+
train.py \
|
| 88 |
+
--do_train \
|
| 89 |
+
--model_name_or_path ${model_path} \
|
| 90 |
+
--tokenizer_name_or_path ${tokenizer_path} \
|
| 91 |
+
--train_data_file ${train_data_file} \
|
| 92 |
+
--deepspeed ${ds_config_file} \
|
| 93 |
+
--output_dir ${output_path} \
|
| 94 |
+
--per_device_train_batch_size 1 \
|
| 95 |
+
--gradient_accumulation_steps 1 \
|
| 96 |
+
--gradient_checkpointing \
|
| 97 |
+
--lr_scheduler_type cosine_with_min_lr \
|
| 98 |
+
--logging_steps 1 \
|
| 99 |
+
--max_steps 200 \
|
| 100 |
+
--save_steps 100 \
|
| 101 |
+
--learning_rate 1e-5 \
|
| 102 |
+
--min_lr 1e-6 \
|
| 103 |
+
--warmup_ratio 0.01 \
|
| 104 |
+
--save_strategy steps \
|
| 105 |
+
--bf16 \
|
| 106 |
+
--use_lora \
|
| 107 |
+
--lora_rank 64 \
|
| 108 |
+
--lora_alpha 128 \
|
| 109 |
+
--lora_dropout 0.1 \
|
| 110 |
+
--hidden_size 4096 \
|
| 111 |
+
--intermediate_size 13312 \
|
| 112 |
+
--model_max_length 4096 \
|
| 113 |
+
--max_seq_length 4096 \
|
| 114 |
+
--moe_topk 8 \
|
| 115 |
+
--num_experts 192 \
|
| 116 |
+
--moe_intermediate_size 1536 \
|
| 117 |
+
--moe_layer_num_skipped 1 \
|
| 118 |
+
--num_attention_heads 64 \
|
| 119 |
+
--num_key_value_heads 8 \
|
| 120 |
+
--num_layers 80 \
|
| 121 |
+
--use_mixed_mlp_moe \
|
| 122 |
+
--num_shared_expert 1 \
|
| 123 |
+
--use_qk_norm | tee ${log_file}
|