yiqichen01 committed · Commit cf1003d · verified · Parent(s): f332661

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/bench_agent_overview_v3.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/bench_claw_agent2.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/bench_context.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/bench_stem.jpg filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,78 @@
TENCENT HY COMMUNITY LICENSE AGREEMENT
Tencent Hy3 preview Release Date: April 23, 2026
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hy Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
1. DEFINITIONS.
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hy Works or any portion or element thereof set forth herein.
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hy made publicly available by Tencent.
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hy Works for any purpose and in any field of use.
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hy and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hy or any Model Derivative of Tencent Hy; (ii) works based on Tencent Hy or any Model Derivative of Tencent Hy; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hy or any Model Derivative of Tencent Hy, to that model in order to cause that model to perform similarly to Tencent Hy or a Model Derivative of Tencent Hy, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hy or a Model Derivative of Tencent Hy for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
h. “Output” shall mean the information and/or content output of Tencent Hy or a Model Derivative that results from operating or otherwise using Tencent Hy or a Model Derivative, including via a Hosted Service.
i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
j. “Tencent Hy” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hy3 preview released at [https://huggingface.co/tencent/Hy3-preview; https://github.com/Tencent-Hunyuan/Hy3-preview].
k. “Tencent Hy Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
n. “including” shall mean including but not limited to.
2. GRANT OF RIGHTS.
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
3. DISTRIBUTION.
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hy Works, exclusively in the Territory, provided that You meet all of the following conditions:
a. You must provide all such Third Party recipients of the Tencent Hy Works or products or services using them a copy of this Agreement;
b. You must cause any modified files to carry prominent notices stating that You changed the files;
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hy Works; and (ii) mark the products or services developed by using the Tencent Hy Works to indicate that the product/service is “Powered by Tencent Hy”; and
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hy is licensed under the Tencent Hy Community License Agreement, Copyright © 2026 Tencent. All Rights Reserved. The trademark rights of “Tencent Hy” are owned by Tencent or its affiliate.”
e. In the event that You use, integrate, implement, or otherwise deploy the Tencent Hy Works, in whole or in part, to provide, enable, or support any service, product, or functionality to third parties, You shall clearly, accurately, and prominently disclose to all end users the full legal name and entity of the actual provider of such service, product, or functionality. You shall expressly and conspicuously state that Tencent is not affiliated with, associated with, sponsoring, or endorsing any such service, product, or functionality. You shall not use or display any name, logo, trademark, trade name, or other indicia of Tencent in any manner that could be construed as, or be likely to create, confusion, deception, or a false impression regarding any relationship, affiliation, sponsorship, or endorsement by Tencent.
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hy Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
4. ADDITIONAL COMMERCIAL TERMS.
If, on the Tencent Hy version release date, the monthly active users of all products or services made available by or for Licensee exceeded 100 million in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
5. RULES OF USE.
a. Your use of the Tencent Hy Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hy Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hy Works and You must provide notice to subsequent users to whom You distribute that Tencent Hy Works are subject to the use restrictions in these Sections 5(a) and 5(b).
b. You must not use the Tencent Hy Works or any Output or results of the Tencent Hy Works to improve any other AI model (other than Tencent Hy or Model Derivatives thereof).
c. You must not use, reproduce, modify, distribute, or display the Tencent Hy Works, Output or results of the Tencent Hy Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
6. INTELLECTUAL PROPERTY.
a. Subject to Tencent’s ownership of Tencent Hy Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hy Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hy Works. Tencent hereby grants You a license to use “Tencent Hy” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hy Works.
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hy Works or to grant any license thereto.
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE Tencent Hy WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE Tencent Hy WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE Tencent Hy WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE Tencent Hy WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
8. SURVIVAL AND TERMINATION.
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hy Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
9. GOVERNING LAW AND JURISDICTION.
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.

EXHIBIT A
ACCEPTABLE USE POLICY

Tencent reserves the right to update this Acceptable Use Policy from time to time.
Last modified: December 30, 2025

Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hy. You agree not to use Tencent Hy or Model Derivatives:
1. Outside the Territory;
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
3. To harm Yourself or others;
4. To repurpose or distribute output from Tencent Hy or any Model Derivatives to harm Yourself or others;
5. To override or circumvent the safety guardrails and safeguards We have put in place;
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
9. To intentionally defame, disparage or otherwise harass others;
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
11. To generate or disseminate personal identifiable information with the purpose of harming others;
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including through the use of bot-generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
13. To impersonate another individual without consent, authorization, or legal right;
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
19. For military purposes;
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
README.md ADDED
@@ -0,0 +1,258 @@
<p align="left">
<a href="README_CN.md">中文</a>&nbsp;|&nbsp;English
</p>
<br>

<p align="center">
 <img src="assets/logo-en.png" width="400"/> <br>
</p>

<div align="center" style="line-height: 1;">

[![License](https://img.shields.io/badge/License-Tencent%20Hy%20Community-blue)](#license)
&nbsp;&nbsp;
[![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Tencent%20Hy-ffc107?color=ffc107&logoColor=white)](https://huggingface.co/tencent/Hy3-preview)
&nbsp;&nbsp;
[![ModelScope](https://img.shields.io/badge/ModelScope-Tencent%20Hy-624aff)](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview)
&nbsp;&nbsp;
[![cnb.cool](https://img.shields.io/badge/cnb.cool-Tencent%20Hy-blue?logoColor=white)](https://cnb.cool/ai-models/tencent/Hy3-preview)
&nbsp;&nbsp;
[![GitCode](https://img.shields.io/badge/GitCode-Tencent%20Hy-red?logoColor=white)](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview)

</div>

<p align="center">
 🖥️&nbsp;<a href="https://aistudio.tencent.com/"><b>Official Website</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
 💬&nbsp;<a href="https://github.com/Tencent-Hunyuan/Hy3-preview"><b>GitHub</b></a></p>

---

## Table of Contents

- [Model Introduction](#model-introduction)
- [Highlights](#highlights)
- [Benchmark Results](#benchmark-results)
- [STEM & Reasoning](#stem--reasoning)
- [Context Learning & Instruction Following](#context-learning--instruction-following)
- [Code & Agent](#code--agent)
- [News](#news)
- [Model Links](#model-links)
- [Quickstart](#quickstart)
- [Deployment](#deployment)
- [vLLM](#vllm)
- [SGLang](#sglang)
- [Training](#training)
- [Quantization](#quantization)
- [License](#license)
- [Contact Us](#contact-us)

---

## Model Introduction

**Hy3 preview** is a 295B-parameter Mixture-of-Experts (MoE) model with 21B active parameters and 3.8B MTP layer parameters, developed by the Tencent Hy Team. Hy3 preview is the first model trained on our rebuilt infrastructure, and the strongest we've shipped so far. It improves significantly on complex reasoning, instruction following, context learning, coding, and agent tasks.

| Property | Value |
|:---|:---|
| Architecture | Mixture-of-Experts (MoE) |
| Total Parameters | 295B |
| Activated Parameters | 21B |
| MTP Layer Parameters | 3.8B |
| Number of Layers (excluding MTP layer) | 80 |
| Number of MTP Layers | 1 |
| Attention Heads | 64 (GQA, 8 KV heads, head dim 128) |
| Hidden Size | 4096 |
| Intermediate Size | 13312 |
| Context Length | 256K |
| Vocabulary Size | 120832 |
| Number of Experts | 192 experts, top-8 activated |
| Supported Precisions | BF16 |

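As a rough illustration of what these numbers imply for serving, here is a back-of-the-envelope sketch (our own estimate, not an official figure) of the active-parameter share and the BF16 KV-cache footprint implied by the GQA configuration above:

```python
# Back-of-the-envelope estimates from the table above (BF16 = 2 bytes/element).
# Illustrative approximations only, not official numbers.
layers, kv_heads, head_dim = 80, 8, 128
bytes_per_elem = 2  # BF16

active_share = 21 / 295                                           # weights used per token
kv_per_token = 2 * layers * kv_heads * head_dim * bytes_per_elem  # K and V caches
kv_256k_gib = kv_per_token * 256 * 1024 / 2**30                   # full 256K context

print(f"active parameter share: {active_share:.1%}")      # ~7.1%
print(f"KV cache per token: {kv_per_token / 1024:.0f} KiB")  # ~320 KiB
print(f"KV cache at 256K context: {kv_256k_gib:.0f} GiB")    # ~80 GiB
```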
## Highlights

- **STEM & Reasoning** — Complex reasoning underpins everything else. Hy3 preview performs well on challenging STEM benchmarks like FrontierScience-Olympiad and IMOAnswerBench, and achieved excellent results in the Tsinghua Qiuzhen College Math PhD qualifying exam (Spring '26) and the China High School Biology Olympiad (CHSBO 2025), demonstrating generalizable reasoning capacity.

- **Context Learning & Instruction Following** — Real-world tasks require the ability to parse messy, lengthy contexts and follow complex rules. We built CL-bench and CL-bench-Life from our own business scenarios as a new way to measure context-learning ability. Hy3 preview exhibits solid gains in both context learning and instruction following.

- **Code & Agent** — Coding and agents saw the biggest gains. With a rebuilt RL infrastructure and larger-scale training tasks, we posted competitive scores across mainstream coding agent benchmarks (SWE-bench Verified, Terminal-Bench 2.0) and search agent benchmarks (BrowseComp, WideSearch).

## Benchmark Results

### Pre-trained Model Performance

| Category | Benchmark (Metric) | # Shots | Kimi-K2 BASE | DeepSeek-V3 BASE | GLM-4.5 BASE | Hy3 preview-Base |
|---|---|---|---|---|---|---|
| | #ActivatedParams | - | 32B | 37B | 32B | 21B |
| | #TotalParams | - | 1043B | 671B | 355B | 295B |
| **English** | MMLU | 5-shot | **88.24** | 87.68 | 87.73 | 87.42 |
| | MMLU-Pro | 5-shot | **65.98** | 63.98 | 63.67 | 65.76 |
| | MMLU-Redux | 5-shot | **87.18** | 86.81 | 86.56 | 86.86 |
| | ARC-Challenge | 0-shot | **96.66** | 94.65 | 96.32 | 95.99 |
| | DROP | 5-shot | 86.40 | **86.50** | 82.90 | 85.50 |
| | PIQA | 4-shot | **84.93** | 84.22 | 84.71 | 84.39 |
| | SuperGPQA | 5-shot | 51.10 | 46.17 | 49.64 | **51.60** |
| | SimpleQA | 5-shot | **34.37** | 26.15 | 29.26 | 26.47 |
| **Code** | MBPP-plus | 3-shot | **81.35** | 75.47 | 78.05 | 78.71 |
| | CRUXEval-I | 3-shot | 68.01 | 67.79 | 68.51 | **71.19** |
| | CRUXEval-O | 3-shot | 69.62 | **71.00** | 67.75 | 68.38 |
| | LiveCodeBench-v6 | 1-shot | 30.86 | 29.31 | 27.43 | **34.86** |
| **Math** | GSM8K | 4-shot | 93.46 | 88.15 | 90.06 | **95.37** |
| | MATH | 4-shot | 71.20 | 59.37 | 61.00 | **76.28** |
| | CMath | 4-shot | 90.83 | 85.50 | 89.33 | **91.17** |
| **Chinese** | C-Eval | 5-shot | **91.51** | 90.35 | 85.84 | 89.80 |
| | CMMLU | 5-shot | **90.72** | 87.90 | 86.46 | 89.61 |
| | Chinese-simpleQA | 5-shot | **74.58** | 68.72 | 68.49 | 69.73 |
| **Multilingual** | MMMLU | 5-shot | 77.63 | 79.54 | 79.26 | **80.15** |
| | INCLUDE | 5-shot | 75.66 | 77.86 | 76.27 | **78.64** |

### Instruct Model Performance

#### STEM & Reasoning

Complex reasoning underpins everything else. Hy3 preview performs well on challenging STEM benchmarks like FrontierScience-Olympiad and IMOAnswerBench. It also achieved excellent results in the Tsinghua Qiuzhen College Math PhD qualifying exam (Spring '26) and the China High School Biology Olympiad (CHSBO 2025), demonstrating a high degree of generalizable reasoning capacity.

<p align="center"><img src="assets/bench_stem.jpg" width="800" alt="STEM & Reasoning benchmarks"/></p>

#### Context Learning & Instruction Following

Real-world tasks require the ability to parse messy, lengthy contexts and follow complex rules. We built CL-bench and CL-bench-Life from our own business scenarios as a new way to measure context-learning ability. Hy3 preview exhibits solid gains in both context learning and instruction following.

<p align="center"><img src="assets/bench_context.jpg" width="800" alt="Context Learning & Instruction Following benchmarks"/></p>

#### Code & Agent

Coding and agents saw the biggest gains. With a rebuilt RL infrastructure and larger-scale training tasks, we posted competitive scores across mainstream coding agent benchmarks (SWE-bench Verified, Terminal-Bench 2.0) and search agent benchmarks (BrowseComp, WideSearch).

<p align="center"><img src="assets/bench_agent_overview_v3.jpg" width="800" alt="Agent benchmarks overview"/></p>

Coding is about whether a model can execute in a development environment. Search is about whether it can find and combine information from the open web. Both matter for complex agent scenarios like OpenClaw. Hy3 preview scores well on ClawEval and WildClawBench, a sign that its agent capabilities are becoming practical.

<p align="center"><img src="assets/bench_claw_agent.png" width="800" alt="Claw Agent benchmarks"/></p>

Beyond public benchmarks, we built internal evaluation sets to test the model in real development scenarios. On Hy-Backend (backend-focused tasks), Hy-Vibe Bench (real-user dev workflows), and Hy-SWE Max, Hy3 preview scores competitively against other open-source models.

<p align="center"><img src="assets/bench_claw_agent2.jpg" width="800" alt="Internal benchmarks"/></p>

## News

* **[2026-04-23]** 🔥 We open-sourced the **Hy3 preview** model weights on [Hugging Face](https://huggingface.co/tencent/Hy3-preview), [ModelScope](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview), and [GitCode](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview).

## Model Links

| Model Name | Description | Hugging Face | ModelScope | GitCode |
|:---|:---|:---:|:---:|:---:|
| Hy3 preview | Instruct model | 🤗 [Model](https://huggingface.co/tencent/Hy3-preview) | [Model](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview) | [Model](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview) |
| Hy3 preview-Base | Pre-trained base model | 🤗 [Model](https://huggingface.co/tencent/Hy3-preview-Base) | [Model](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview-Base) | [Model](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview-Base) |

## Quickstart

Deploy Hy3 preview with [vLLM](#vllm) or [SGLang](#sglang) first, then call the OpenAI-compatible API:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="tencent/Hy3-preview",
    messages=[
        {"role": "user", "content": "Hello! Can you briefly introduce yourself?"},
    ],
    temperature=0.9,
    top_p=1.0,
    # reasoning_effort: "no_think" (default, direct response), "low", "high" (deep chain-of-thought)
    extra_body={"chat_template_kwargs": {"reasoning_effort": "no_think"}},
)
print(response.choices[0].message.content)
```

> **Recommended parameters**: `temperature=0.9`, `top_p=1.0`.
>
> **Reasoning mode**: Set `reasoning_effort` to `"high"` for complex tasks (math, coding, reasoning) or `"no_think"` for direct responses.

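For clients that do not use the `openai` SDK, the same request can be issued over plain HTTP. A minimal sketch with `requests`, assuming the vLLM or SGLang server from the [Deployment](#deployment) section is listening on port 8000 and accepts `chat_template_kwargs` in the request body:

```python
# Raw-HTTP equivalent of the Quickstart call (sketch; assumes the
# OpenAI-compatible server accepts chat_template_kwargs in the body).
import requests

payload = {
    "model": "tencent/Hy3-preview",
    "messages": [{"role": "user", "content": "Prove that sqrt(2) is irrational."}],
    "temperature": 0.9,
    "top_p": 1.0,
    "chat_template_kwargs": {"reasoning_effort": "high"},  # deep chain-of-thought
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
resp.raise_for_status()
message = resp.json()["choices"][0]["message"]
print(message.get("reasoning_content"))  # present when a reasoning parser is enabled
print(message["content"])
```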
See the [Deployment](#deployment) section below for how to start the API server.

## Deployment

Hy3-preview has 295B total parameters. To serve it on 8 GPUs, we recommend H20-3e or other GPUs with comparable or larger memory.

### vLLM

Build vLLM from source:

```bash
uv venv --python 3.12 --seed --managed-python
source .venv/bin/activate
git clone https://github.com/vllm-project/vllm.git
cd vllm
uv pip install --editable . --torch-backend=auto
```

Start the vLLM server with MTP enabled:

```bash
vllm serve tencent/Hy3-preview \
    --tensor-parallel-size 8 \
    --speculative-config.method mtp \
    --speculative-config.num_speculative_tokens 1 \
    --tool-call-parser hy_v3 \
    --reasoning-parser hy_v3 \
    --enable-auto-tool-choice \
    --served-model-name hy3-preview
```

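Because the server above is launched with `--enable-auto-tool-choice` and the `hy_v3` tool-call parser, OpenAI-style function calling works against it. A short sketch (the `get_weather` tool and its schema are made up for illustration):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool, for illustration only
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="hy3-preview",  # matches --served-model-name above
    messages=[{"role": "user", "content": "What's the weather in Shenzhen?"}],
    tools=tools,
)
# With auto tool choice enabled, the parsed call appears in tool_calls.
print(response.choices[0].message.tool_calls)
```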
### SGLang

Build SGLang from source:

```bash
git clone https://github.com/sgl-project/sglang
cd sglang
pip3 install pip --upgrade
pip3 install "transformers>=5.6.0"
pip3 install -e "python"
```

Launch the SGLang server with MTP enabled:

```bash
python3 -m sglang.launch_server \
    --model tencent/Hy3-preview \
    --tp 8 \
    --tool-call-parser hunyuan \
    --reasoning-parser hunyuan \
    --speculative-num-steps 1 \
    --speculative-eagle-topk 1 \
    --speculative-num-draft-tokens 2 \
    --speculative-algorithm EAGLE \
    --served-model-name hy3-preview
```

## Training

Hy3 preview provides a complete model training pipeline, supporting both full fine-tuning and LoRA fine-tuning, with DeepSpeed ZeRO configurations and LLaMA-Factory integration.

For detailed training documentation, please refer to: [Training Guide](./train/README.md)

## Quantization

We provide [AngelSlim](https://github.com/tencent/AngelSlim), an accessible, comprehensive, and efficient toolkit for large-model compression. AngelSlim supports a full suite of compression tools for large-scale multimodal models, including common quantization algorithms, low-bit quantization, and speculative sampling.

## License

Hy3 preview is released under the **Tencent Hy Community License Agreement**. See [LICENSE](./LICENSE) for details.

## Contact Us

If you would like to leave a message for our R&D and product teams, feel free to contact us; you can also reach us by email:

📧 **hunyuan_opensource@tencent.com**

---

<p align="center">
<i>Hy3 preview is developed by the Tencent Hy Team.</i>
</p>
README_CN.md ADDED
@@ -0,0 +1,259 @@
<p align="left">
<a href="README.md">English</a>&nbsp;|&nbsp;中文
</p>
<br>

<p align="center">
 <img src="assets/logo-zh.png" width="400"/> <br>
</p>

<div align="center" style="line-height: 1;">

[![License](https://img.shields.io/badge/License-Tencent%20Hy%20Community-blue)](#license)
&nbsp;&nbsp;
[![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Tencent%20Hy-ffc107?color=ffc107&logoColor=white)](https://huggingface.co/tencent/Hy3-preview)
&nbsp;&nbsp;
[![ModelScope](https://img.shields.io/badge/ModelScope-Tencent%20Hy-624aff)](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview)
&nbsp;&nbsp;
[![cnb.cool](https://img.shields.io/badge/cnb.cool-Tencent%20Hy-blue?logoColor=white)](https://cnb.cool/ai-models/tencent/Hy3-preview)
&nbsp;&nbsp;
[![GitCode](https://img.shields.io/badge/GitCode-Tencent%20Hy-red?logoColor=white)](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview)

</div>

<p align="center">
 🖥️&nbsp;<a href="https://aistudio.tencent.com/"><b>Official Website</b></a>&nbsp;&nbsp;|&nbsp;&nbsp;
 💬&nbsp;<a href="https://github.com/Tencent-Hunyuan/Hy3-preview"><b>GitHub</b></a></p>

---

## Table of Contents

- [Model Introduction](#model-introduction)
- [Highlights](#highlights)
- [Benchmark Results](#benchmark-results)
- [STEM & Reasoning](#stem--reasoning)
- [Context Learning & Instruction Following](#context-learning--instruction-following)
- [Code & Agent](#code--agent)
- [News](#news)
- [Model Links](#model-links)
- [Quickstart](#quickstart)
- [Inference and Deployment](#inference-and-deployment)
- [vLLM](#vllm)
- [SGLang](#sglang)
- [Training](#training)
- [Quantization](#quantization)
- [License](#license)
- [Contact Us](#contact-us)

---

## Model Introduction

**Hy3 preview**, developed by the Tencent Hy Team, is a Mixture-of-Experts (MoE) model that unifies fast and slow thinking, with 295B total parameters, 21B activated parameters, and 3.8B MTP-layer parameters. Hy3 preview is the first model trained after our infrastructure rebuild and the most intelligent Hunyuan model to date, with substantial gains in complex reasoning, instruction following, context learning, coding, and agent capabilities, as well as in inference performance.

| Property | Value |
|:---|:---|
| Architecture | Mixture-of-Experts (MoE) |
| Total Parameters | 295B |
| Activated Parameters | 21B |
| MTP Layer Parameters | 3.8B |
| Number of Layers (excluding MTP layer) | 80 |
| Number of MTP Layers | 1 |
| Attention Heads | 64 (GQA, 8 KV heads, head dim 128) |
| Hidden Size | 4096 |
| FFN Intermediate Size | 13312 |
| Context Length | 256K |
| Vocabulary Size | 120832 |
| Number of Experts | 192 experts, top-8 activated |
| Supported Precisions | BF16 |

## Highlights

- **STEM & Reasoning** — Reasoning is the foundation on which a model solves all kinds of problems. Hy3 preview stands out on challenging STEM reasoning tasks such as FrontierScience-Olympiad and IMOAnswerBench, and achieved excellent results in the latest Tsinghua Qiuzhen College Math PhD qualifying exam (Spring '26) and the China High School Biology Olympiad (CHSBO 2025), demonstrating strong, generalizable reasoning.

- **Context Learning & Instruction Following** — In real production and everyday scenarios, understanding messy, lengthy contexts and following complex, shifting rules is the first challenge a model faces. Drawing on our many business scenarios, we built CL-bench and CL-bench-Life as a new way to evaluate context-learning ability, and Hy3 preview markedly improves both context learning and instruction following.

- **Code & Agent** — The direction in which Hy3 preview improves most. Thanks to the rebuilt pre-training and reinforcement-learning frameworks and the larger scale of RL tasks, we rapidly reached highly competitive results on mainstream coding-agent benchmarks such as SWE-bench Verified and Terminal-Bench 2.0, and on mainstream search-agent benchmarks such as BrowseComp and WideSearch.

## Benchmark Results

### Pre-trained Model Performance

| Category | Benchmark (Metric) | # Shots | Kimi-K2 BASE | DeepSeek-V3 BASE | GLM-4.5 BASE | Hy3 preview-Base |
|---|---|---|---|---|---|---|
| | #ActivatedParams | - | 32B | 37B | 32B | 21B |
| | #TotalParams | - | 1043B | 671B | 355B | 295B |
| **English** | MMLU | 5-shot | **88.24** | 87.68 | 87.73 | 87.42 |
| | MMLU-Pro | 5-shot | **65.98** | 63.98 | 63.67 | 65.76 |
| | MMLU-Redux | 5-shot | **87.18** | 86.81 | 86.56 | 86.86 |
| | ARC-Challenge | 0-shot | **96.66** | 94.65 | 96.32 | 95.99 |
| | DROP | 5-shot | 86.40 | **86.50** | 82.90 | 85.50 |
| | PIQA | 4-shot | **84.93** | 84.22 | 84.71 | 84.39 |
| | SuperGPQA | 5-shot | 51.10 | 46.17 | 49.64 | **51.60** |
| | SimpleQA | 5-shot | **34.37** | 26.15 | 29.26 | 26.47 |
| **Code** | MBPP-plus | 3-shot | **81.35** | 75.47 | 78.05 | 78.71 |
| | CRUXEval-I | 3-shot | 68.01 | 67.79 | 68.51 | **71.19** |
| | CRUXEval-O | 3-shot | 69.62 | **71.00** | 67.75 | 68.38 |
| | LiveCodeBench-v6 | 1-shot | 30.86 | 29.31 | 27.43 | **34.86** |
| **Math** | GSM8K | 4-shot | 93.46 | 88.15 | 90.06 | **95.37** |
| | MATH | 4-shot | 71.20 | 59.37 | 61.00 | **76.28** |
| | CMath | 4-shot | 90.83 | 85.50 | 89.33 | **91.17** |
| **Chinese** | C-Eval | 5-shot | **91.51** | 90.35 | 85.84 | 89.80 |
| | CMMLU | 5-shot | **90.72** | 87.90 | 86.46 | 89.61 |
| | Chinese-simpleQA | 5-shot | **74.58** | 68.72 | 68.49 | 69.73 |
| **Multilingual** | MMMLU | 5-shot | 77.63 | 79.54 | 79.26 | **80.15** |
| | INCLUDE | 5-shot | 75.66 | 77.86 | 76.27 | **78.64** |

### Instruct Model Performance

#### STEM & Reasoning

Reasoning is the foundation on which a model solves all kinds of problems. Hy3 preview stands out on challenging STEM reasoning tasks such as FrontierScience-Olympiad and IMOAnswerBench, and achieved excellent results in the latest Tsinghua Qiuzhen College Math PhD qualifying exam (Spring '26) and the China High School Biology Olympiad (CHSBO 2025), demonstrating strong, generalizable reasoning.

<p align="center"><img src="assets/bench_stem.jpg" width="800" alt="STEM & Reasoning benchmarks"/></p>

#### Context Learning & Instruction Following

In real production and everyday scenarios, understanding messy, lengthy contexts and following complex, shifting rules is the first challenge a model faces. Drawing on our many business scenarios, we built CL-bench and CL-bench-Life as a new way to evaluate context-learning ability, and Hy3 preview markedly improves both context learning and instruction following.

<p align="center"><img src="assets/bench_context.jpg" width="800" alt="Context Learning & Instruction Following benchmarks"/></p>

#### Code & Agent

Code and agents are the direction in which Hy3 preview improves most. Thanks to the rebuilt pre-training and reinforcement-learning frameworks and the larger scale of RL tasks, we rapidly reached highly competitive results on mainstream coding-agent benchmarks such as SWE-bench Verified and Terminal-Bench 2.0, and on mainstream search-agent benchmarks such as BrowseComp and WideSearch.

<p align="center"><img src="assets/bench_agent_overview_v3.jpg" width="800" alt="Agent benchmarks overview"/></p>

In the digital world, coding concerns a model's ability to execute in a development environment, while search focuses on retrieving, filtering, and synthesizing information from the open information space; together they determine whether a model is genuinely usable in complex agent scenarios such as OpenClaw. Hy3 preview stands out on evaluations such as ClawEval and WildClawBench, further showing that its agent capabilities are both comprehensive and practical.

<p align="center"><img src="assets/bench_claw_agent.png" width="800" alt="Claw Agent benchmarks"/></p>

Beyond public leaderboards, we built several internal evaluation sets to assess the model in real development scenarios. The results show that Hy3 preview is highly competitive on Hy-Backend (a backend engineering task set), Hy-Vibe Bench (close to real-user development interaction), and Hy-SWE Max (a set of difficult software-engineering tasks).

<p align="center"><img src="assets/bench_claw_agent2.jpg" width="800" alt="Internal benchmarks"/></p>

## News

* **[2026-04-23]** 🔥 We open-sourced the **Hy3 preview** model weights on [Hugging Face](https://huggingface.co/tencent/Hy3-preview), [ModelScope](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview), and [GitCode](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview).

## Model Links

| Model Name | Description | Hugging Face | ModelScope | GitCode |
|:---|:---|:---:|:---:|:---:|
| Hy3 preview | Instruct model | 🤗 [Model](https://huggingface.co/tencent/Hy3-preview) | [Model](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview) | [Model](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview) |
| Hy3 preview-Base | Pre-trained base model | 🤗 [Model](https://huggingface.co/tencent/Hy3-preview-Base) | [Model](https://modelscope.cn/models/Tencent-Hunyuan/Hy3-preview-Base) | [Model](https://ai.gitcode.com/tencent_hunyuan/Hy3-preview-Base) |

## Quickstart

We recommend deploying the model with [vLLM](#vllm) or [SGLang](#sglang) first, then calling it through the OpenAI-compatible API:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="tencent/Hy3-preview",
    messages=[
        {"role": "user", "content": "Hello! Please briefly introduce yourself."},
    ],
    temperature=0.9,
    top_p=1.0,
    # reasoning_effort: "no_think" (default, direct response), "low", "high" (deep chain-of-thought)
    extra_body={"chat_template_kwargs": {"reasoning_effort": "no_think"}},
)
print(response.choices[0].message.content)
```

> **Recommended parameters**: `temperature=0.9`, `top_p=1.0`.
>
> **Reasoning mode**: for complex tasks (math, coding, reasoning), set `reasoning_effort="high"`; for everyday conversation, the default `"no_think"` gives direct responses.

For deployment details, see the [Inference and Deployment](#inference-and-deployment) section below.

## Inference and Deployment

Hy3-preview has 295B total parameters. When serving on 8 GPUs, we recommend H20-3e or other GPUs with comparable or larger memory.

### vLLM

Build vLLM from source:

```bash
uv venv --python 3.12 --seed --managed-python
source .venv/bin/activate
git clone https://github.com/vllm-project/vllm.git
cd vllm
uv pip install --editable . --torch-backend=auto
```

Start the vLLM server with MTP enabled:

```bash
vllm serve tencent/Hy3-preview \
    --tensor-parallel-size 8 \
    --speculative-config.method mtp \
    --speculative-config.num_speculative_tokens 1 \
    --tool-call-parser hy_v3 \
    --reasoning-parser hy_v3 \
    --enable-auto-tool-choice \
    --served-model-name hy3-preview
```

### SGLang

Build SGLang from source:

```bash
git clone https://github.com/sgl-project/sglang
cd sglang
pip3 install pip --upgrade
pip3 install "transformers>=5.6.0"
pip3 install -e "python"
```

Launch the SGLang server with MTP enabled:

```bash
python3 -m sglang.launch_server \
    --model tencent/Hy3-preview \
    --tp 8 \
    --tool-call-parser hunyuan \
    --reasoning-parser hunyuan \
    --speculative-num-steps 1 \
    --speculative-eagle-topk 1 \
    --speculative-num-draft-tokens 2 \
    --speculative-algorithm EAGLE \
    --served-model-name hy3-preview
```

## Training

Hy3 preview provides a complete model-training pipeline, supporting both full fine-tuning and LoRA fine-tuning, with multiple DeepSpeed ZeRO configurations and LLaMA-Factory integration.

For detailed training documentation, see: [Training Guide](./train/README_CN.md)

## Quantization

We provide [AngelSlim](https://github.com/tencent/AngelSlim), an accessible, comprehensive, and efficient toolkit for large-model compression, covering common quantization algorithms, low-bit quantization, and speculative sampling.

## License

Hy3 preview is released under the **Tencent Hy Community License Agreement**. See [LICENSE](./LICENSE) for details.

## Contact Us

For questions or suggestions, feel free to reach us by email:

📧 **hunyuan_opensource@tencent.com**

---

<p align="center">
<i>Hy3 preview is developed by the Tencent Hy Team.</i>
</p>
assets/bench_agent_overview_v3.jpg ADDED

Git LFS Details

  • SHA256: 0a358cf8370dca51d962c4cf905049cc968f2db5c2898833c6a7111ec2468f28
  • Pointer size: 132 Bytes
  • Size of remote file: 1.38 MB
assets/bench_claw_agent.png ADDED
assets/bench_claw_agent2.jpg ADDED

Git LFS Details

  • SHA256: d8d9e3e6bd57f2cc9e5841c9dff682833339decdfb871b6ef7e310d118641f89
  • Pointer size: 131 Bytes
  • Size of remote file: 519 kB
assets/bench_context.jpg ADDED

Git LFS Details

  • SHA256: 624087fbed489270cbd0751053c2df75957d4bdc6d0646a03fe0496235a78253
  • Pointer size: 131 Bytes
  • Size of remote file: 977 kB
assets/bench_stem.jpg ADDED

Git LFS Details

  • SHA256: c7f89bcb6910b7389b7feb180192a810cc5818057fc44b1b30093f276fe73a11
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
assets/logo-en.png ADDED
assets/logo-zh.png ADDED
train/README.md ADDED
@@ -0,0 +1,298 @@
<p align="left">
<a href="README_CN.md">中文</a> | English
</p>

# Model Training

Hy3 preview ships with a complete model-training pipeline. This section describes how to format and process training data for it.

## Training Data Format and Processing

**Hy3 preview supports both "slow thinking" and "fast thinking" modes. By default, the model outputs in slow thinking mode. To control this, use the `reasoning_effort` parameter (options: `high`, `low`, `no_think`).**

The training data should be formatted as a list of messages. By default, the system prompt for both training and inference is empty, but you may customize it as needed.

```python
# Each training sample is one JSON object (one line of the jsonl training file).
# Fast thinking pattern (no_think)
sample = {"reasoning_effort": "no_think", "messages": [{"content": "You are a helpful assistant.\nThe current time is 2026-01-01 13:26:12 Thursday", "role": "system"}, {"content": "1+1=?", "role": "user"}, {"role": "assistant", "content": "1+1=2"}]}

# Slow thinking pattern (high): the assistant turn also carries reasoning_content
sample = {"reasoning_effort": "high", "messages": [{"content": "You are a helpful assistant.\nThe current time is 2026-01-01 13:26:12 Thursday", "role": "system"}, {"content": "1+1=?", "role": "user"}, {"role": "assistant", "content": "1+1=2", "reasoning_content": "The user is asking for the result of 1 + 1. In basic decimal arithmetic, 1 + 1 equals 2."}]}

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./models", use_fast=False, trust_remote_code=True)
ids = tokenizer.apply_chat_template(sample["messages"], is_training=True)
```

## Checkpoint Format Conversion

The original Hy3 preview checkpoint stores each expert's weights independently. Before training, it must be converted to the HuggingFace-compatible format (fusing the experts of each layer into 3D tensors with unified key naming) to improve loading and training speed. We provide a conversion script `convert_ckpt_to_outer.py` and a validation script `check_converted.py`, located in the `train/tools` directory.

### Conversion

```sh
python convert_ckpt_to_outer.py \
    --input_dir <original_checkpoint_dir> \
    --output_dir <output_dir> \
    --workers 8
```

**Parameters:**

- `--input_dir`: Path to the original checkpoint directory (required)
- `--output_dir`: Path to the converted checkpoint output directory (required)
- `--workers`: Number of parallel worker processes, default is 8 (optional)

The conversion script performs the following steps (a minimal sketch of the expert-fusion idea follows the list):
1. Pre-scans `model.safetensors.index.json` to detect cross-shard expert groups
2. Converts weights shard-by-shard in parallel (key renaming + expert fusion)
3. Post-processes cross-shard expert groups (merges data from multiple shards)
4. Copies `config.json`, tokenizer, and other files
5. Rebuilds `model.safetensors.index.json`

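To make step 2's "expert fusion" concrete, here is a toy sketch of stacking per-expert 2D weights into one 3D tensor. The key names and sizes are hypothetical; sharding, renaming, and index rebuilding are handled by `convert_ckpt_to_outer.py`:

```python
# Illustrative only: fuse per-expert 2D weights into one 3D tensor.
# Key names are made up; the real layout is handled by convert_ckpt_to_outer.py.
import torch

num_experts, in_dim, out_dim = 4, 8, 16  # toy sizes (the model uses 192 experts)
per_expert = {
    f"layers.0.mlp.experts.{e}.gate_proj.weight": torch.randn(out_dim, in_dim)
    for e in range(num_experts)
}

# Stack experts 0..N-1 along a new leading dimension -> (num_experts, out, in)
fused = torch.stack(
    [per_expert[f"layers.0.mlp.experts.{e}.gate_proj.weight"] for e in range(num_experts)],
    dim=0,
)
print(fused.shape)  # torch.Size([4, 16, 8])
```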
### Validation

After conversion, it is recommended to validate the result using the validation script:

```sh
python check_converted.py <converted_checkpoint_dir> --spot-check 3
```

**Parameters:**

- First argument: Path to the converted checkpoint directory (required)
- `--spot-check`: Number of shard files to spot-check by loading tensors and verifying shape, dtype, NaN/Inf, etc. Default is 3 (optional)

The validation script checks the following:
1. Completeness of `config.json`
2. Whether all expected keys in `model.safetensors.index.json` are present (including regular layers and MTP layers)
3. Whether all referenced shard files exist and are non-empty
4. Spot-checks tensor shape, dtype, and NaN/Inf in selected shard files
5. Detects orphan empty shard files (cross-shard merge residue, safe to delete)

## Quick Start

You can get started quickly by following the Quickstart section of the top-level README.

## Model Training

### Hardware Requirements

Based on our testing, with `make_moe_param_leaf_module` and `zero3+offload` disabled and `max_seq_length` set to 4096, LoRA fine-tuning requires at least a single machine with 8 GPUs (each with at least 80 GB of memory).

Full fine-tuning (without LoRA) requires at least 4 machines with 32 GPUs in total (each with at least 80 GB of memory).

### Launch Methods

Reference: [HuggingFace Transformers Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer)

#### Single-Machine Training

In the `train` directory, execute:

```sh
pip install -r requirements.txt
bash train.sh
```

#### Multi-Machine Training

To launch training across multiple machines, follow the steps below and ensure all machines are within the same cluster.

##### Configure Passwordless SSH Login Between Machines

The following instructions use two machines as an example, with their IPs denoted as `${ip_1}` and `${ip_2}`. All steps should be performed inside the Docker container.

First, configure passwordless SSH for each container on every machine:

```sh
ssh-keygen            # Generate id_rsa and id_rsa.pub for passwordless login
ssh-keygen -t rsa -A  # Generate /etc/ssh/ssh_host_rsa_key and ssh_host_ecdsa_key for SSH listening
/usr/sbin/sshd -p 36005 -o ListenAddress=0.0.0.0  # Start SSH listening
echo "Port 36005" > ~/.ssh/config  # Set SSH connection port to 36005
passwd root           # Set the root password to avoid monitoring platform alerts
```

Note: `36005` is an example port. You may use any available port, as long as it is **open** and **not occupied by other processes**.

Next, in each machine's container, execute:

```sh
cat ~/.ssh/id_rsa.pub
```

**Copy the output SSH public key and paste it into the `~/.ssh/authorized_keys` file, one key per line. This must be done on every machine.** In the end, the `~/.ssh/authorized_keys` file on each machine should be identical and contain the public keys of all machines.

Note that for multi-node training, the code executed on each node must be identical. It is recommended to mount a shared network drive; if that is not possible, manually copy the dataset, scripts, and code to the same directory on each machine.

##### Launching Multi-Machine Training

Once the above preparations are complete and all dependencies are installed (if not, run `pip install -r requirements.txt`), add the following configuration at the beginning of `train.sh`:

```shell
export HOST_GPU_NUM=8
# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
export IP_LIST=${IP_LIST:-"127.0.0.1"}
```

Note: if the `IP_LIST` environment variable is not already set, replace the default value with your actual IP list, in the following format:

```
For a single IP:
IP_LIST=${ip_1}

For multiple IPs:
IP_LIST=${ip_1},${ip_2}
```

Replace `${ip_1}` and `${ip_2}` with the actual IP addresses.

Then, on the machine with `${ip_1}`, execute `bash train.sh` in the `train/` directory. On first launch, you may see the following output:

```
The authenticity of host '[ip]:36005 ([ip]:36005)' can't be established.
ECDSA key fingerprint is xxxxxx.
ECDSA key fingerprint is MD5:xxxxxx.
Are you sure you want to continue connecting (yes/no)?
```

Type `yes` to continue.

##### Key Parameters

The key parameters in the script are as follows:

- `--deepspeed`: Path to the DeepSpeed configuration file. Three default DeepSpeed configuration files are provided in the `train` folder: `ds_zero2_no_offload.json`, `ds_zero3_no_offload.json`, and `ds_zero3_offload.json`, with decreasing memory requirements in that order.
- `--model_name_or_path`: Path to the Hy3 preview pre-trained weights in HF format; loading fails if this does not point to valid HF-format weights.
- `--tokenizer_name_or_path`: Path to the tokenizer folder; loading fails if this does not point to a valid tokenizer.
- `--train_data_file`: Path to the training file, which should be a jsonl file.
- `--output_dir`: Output directory where logs, TensorBoard files, and model weights will be stored.
- `--per_device_train_batch_size`: Batch size per GPU.
- `--gradient_accumulation_steps`: Number of gradient accumulation steps. The global batch size is `per_device_train_batch_size * gradient_accumulation_steps * dp_size` (see the sketch after this list).
- `--max_steps`: Total number of training steps.
- `--save_steps`: Number of steps between saving checkpoints.
- `--use_lora`: Whether to use LoRA training. Also accepts `--lora_rank`, `--lora_alpha`, and `--lora_dropout` parameters. By default, LoRA is applied to "q_proj", "k_proj", "v_proj", and "o_proj"; to change this, modify the code. **Note: when using LoRA training, only the LoRA weights are saved, not the base model weights.** To merge LoRA weights, see the "LoRA Weight Merging" section below.
- `--make_moe_param_leaf_module`: When using zero3 and MoE training, treat the MoE module as a leaf module, i.e., its parameters are not partitioned by zero3. This option is expected to significantly increase memory usage.
- `--gradient_checkpointing`: Enable gradient checkpointing.
- `--train_attention_params_only`: Whether to train only attention parameters.
- `--learning_rate`: Maximum learning rate during training.
- `--min_lr`: Minimum learning rate during training.
- `--use_flash_attn`: Enable flash-attention for accelerated training.

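For concreteness, a quick sketch of how the global batch size falls out of these knobs (example values assumed, not recommendations):

```python
# Example values only; dp_size is the data-parallel world size.
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
dp_size = 32  # e.g. 4 machines x 8 GPUs with pure data parallelism

global_batch_size = per_device_train_batch_size * gradient_accumulation_steps * dp_size
print(global_batch_size)  # 128
```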
**Notes:**

- To resume training from a previously saved checkpoint rather than loading pre-trained weights, specify `--resume_from_checkpoint` with the path to the checkpoint, and do not specify `--model_name_or_path`; the latter loads only the weights, not the training state.
- When resuming from a checkpoint, there may be minor differences in loss due to the randomness of some non-deterministic algorithms. This is normal. See: [HuggingFace Transformers Trainer Randomness](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#randomness)
- When `--model_name_or_path` is specified, all model-related parameters will be ignored.
- Samples within a batch are padded to the length of the longest sample in the batch, with each sample capped at `max_seq_length`; any excess is truncated.
- If you see a warning about bias weights not being loaded, you can ignore it; the model does not use bias terms.

#### What if GPU Memory is Insufficient?

Reference: [DeepSpeed Configuration](https://www.deepspeed.ai/docs/config-json/)

You can try modifying the DeepSpeed configuration by removing the `auto` attribute from the following parameters and reducing their values:

- `stage3_param_persistence_threshold`
- `stage3_prefetch_bucket_size`
- `stage3_max_reuse_distance`

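A small sketch of what such an override could look like, loading one of the provided ZeRO-3 configs and pinning the three knobs to fixed values (the concrete numbers below are placeholders to tune downward, not recommendations):

```python
# Pin the auto-tuned ZeRO-3 size knobs to smaller fixed values.
# The values here are placeholders; reduce them until the model fits.
import json

with open("ds_zero3_offload.json") as f:
    cfg = json.load(f)

cfg["zero_optimization"].update({
    "stage3_param_persistence_threshold": 10_000,
    "stage3_prefetch_bucket_size": 10_000_000,
    "stage3_max_reuse_distance": 100_000_000,
})

with open("ds_zero3_offload_small.json", "w") as f:
    json.dump(cfg, f, indent=2)
```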
#### LoRA Weight Merging

LoRA weights saved during training cannot be merged into the model at runtime under zero3, as zero3 partitions model weights across data-parallel ranks. To merge LoRA weights into the base model, do it offline to obtain a merged weight file: run `merge_lora_weight.sh` to merge the LoRA and base model weights (a library-level sketch follows the parameter list). The parameters are:

- `--base_model_path`: Directory of the base model weights
- `--adapter_model_path`: Directory of the LoRA weights
- `--output_path`: Directory to save the merged weights
- `--save_dtype`: Data type for saving the merged weights; options are: fp16, bf16, fp32

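For reference, the same merge can be done with the PEFT library; a minimal sketch, assuming the adapter directory was saved in PEFT format (the repo's `merge_lora_weight.sh` remains the supported path):

```python
# Offline LoRA merge sketch using PEFT (assumes a PEFT-format adapter).
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(
    "<base_model_dir>", torch_dtype=torch.bfloat16, trust_remote_code=True
)
model = PeftModel.from_pretrained(base, "<adapter_model_dir>")
merged = model.merge_and_unload()  # folds the LoRA deltas into the base weights
merged.save_pretrained("<output_dir>", safe_serialization=True)
```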
208
+ #### LLaMA-Factory Support
209
+
210
+ If you are familiar with LLaMA-Factory, you may use it for fine-tuning. All scripts, code, and configuration files are archived in the `./train/llama_factory_support` directory. Unless otherwise specified, all files mentioned below are located in this directory.
211
+
212
+ ##### Installation
213
+
214
+ You can install LLaMA-Factory by downloading the source code from https://github.com/hiyouga/LLaMA-Factory/tree/main and following the instructions on the website.
215
+
216
+ ##### Configuration Files
217
+
218
+ We provide sample LLaMA-Factory training configuration files: `hy_v3_lora_sft.yaml` and `hy_v3_full_sft.yaml`, corresponding to LoRA training and full fine-tuning respectively.
219
+
220
+ Key parameters in the configuration files are as follows:
221
+
222
+ **Model:**
223
+
224
+ - `model_name_or_path`: Path to the Hy3 preview HF format pre-trained model weights
225
+ - `trust_remote_code`: Whether to trust remote code; Hy3 preview requires this to be set to `true`
226
+
227
+ **Training Method:**
228
+
229
+ - `stage`: Training stage, currently `sft` (supervised fine-tuning)
230
+ - `finetuning_type`: Fine-tuning type, either `full` (full fine-tuning) or `lora` (LoRA fine-tuning)
231
+ - `deepspeed`: DeepSpeed configuration file path; `ds_zero3_offload_hy.json` is recommended for full fine-tuning, `ds_zero2_offload_lora.json` for LoRA fine-tuning
232
+
233
+ **LoRA Parameters (only effective during LoRA fine-tuning):**
234
+
235
+ - `lora_rank`: LoRA rank, default `64`
236
+ - `lora_alpha`: LoRA alpha coefficient, default `128`
237
+ - `lora_dropout`: LoRA dropout ratio, default `0.05`
238
+ - `lora_target`: Target modules for LoRA, default `q_proj,k_proj,v_proj,o_proj`
239
+
240
+ **Dataset:**
241
+
242
+ - `dataset_dir`: Dataset directory path
243
+ - `dataset`: Dataset name, must be registered in `dataset_info.json` under `dataset_dir`
244
+ - `template`: Chat template; Hy3 preview uses `hy_v3`
245
+ - `cutoff_len`: Maximum sequence length; sequences exceeding this will be truncated. For full fine-tuning, can be set to `262144` (262K); for LoRA fine-tuning, `8192` is recommended to save memory
246
+ - `max_samples`: Maximum number of samples per dataset
247
+ - `overwrite_cache`: Whether to overwrite cached preprocessed datasets
248
+
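+ For reference, the entry shipped in `llama_factory_support/dataset_info.json` registers the demo dataset as follows, written out here as a Python sketch that regenerates the file:
+
+ ```python
+ import json
+
+ dataset_info = {
+     "hy_v3_demo": {  # the name referenced by `dataset`
+         "file_name": "../example_data.jsonl",
+         "formatting": "sharegpt",
+         "columns": {"messages": "messages"},
+         "tags": {
+             "role_tag": "role",
+             "content_tag": "content",
+             "user_tag": "user",
+             "assistant_tag": "assistant",
+             "system_tag": "system",
+         },
+     }
+ }
+ with open("dataset_info.json", "w") as f:
+     json.dump(dataset_info, f, indent=2)
+ ```
+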
249
+ **Output:**
250
+
251
+ - `output_dir`: Output directory where logs, TensorBoard files, and weights will be stored
252
+ - `logging_steps`: Number of steps between logging
253
+ - `save_steps`: Number of steps between saving checkpoints
254
+ - `plot_loss`: Whether to plot the training loss curve
255
+ - `overwrite_output_dir`: Whether to overwrite the existing output directory
256
+ - `save_only_model`: Whether to save only model weights (excluding optimizer states, etc.)
257
+ - `report_to`: Logging tool, options: `none`, `wandb`, `tensorboard`, `swanlab`, `mlflow`
258
+
259
+ **Training Hyperparameters:**
260
+
261
+ - `per_device_train_batch_size`: Batch size per GPU
262
+ - `gradient_accumulation_steps`: Gradient accumulation steps; `per_device_train_batch_size * gradient_accumulation_steps * dp_size` equals the global batch size (worked example after this list)
263
+ - `learning_rate`: Maximum learning rate; `1.0e-5` recommended for full fine-tuning, `2.0e-4` for LoRA fine-tuning
264
+ - `num_train_epochs`: Number of training epochs
265
+ - `lr_scheduler_type`: Learning rate scheduler type; `cosine_with_min_lr` is recommended
266
+ - `lr_scheduler_kwargs.min_lr_rate`: Ratio of minimum to maximum learning rate; e.g., `0.1` means the minimum learning rate is 10% of the maximum
267
+ - `warmup_ratio`: Proportion of total training steps used for warmup
268
+ - `bf16`: Whether to use BFloat16 mixed precision training
269
+ - `gradient_checkpointing`: Whether to enable gradient checkpointing to save memory
270
+ - `ddp_timeout`: Distributed training timeout (seconds)
271
+ - `flash_attn`: Attention implementation; `fa2` (FlashAttention-2) is recommended, `sdpa` is also available; using `fa2` requires the flash-attn package
272
+ - `resume_from_checkpoint`: Resume training from a specified checkpoint path; set to `null` to start from scratch
273
+
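+ As a worked example of the global batch size formula above, assume two nodes with 8 GPUs each and no model parallelism (so `dp_size` is 16):
+
+ ```python
+ per_device_train_batch_size = 1
+ gradient_accumulation_steps = 4
+ dp_size = 16  # hypothetical: 2 nodes x 8 GPUs, pure data parallelism
+
+ global_batch_size = per_device_train_batch_size * gradient_accumulation_steps * dp_size
+ print(global_batch_size)  # 64
+ ```
+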
274
+ ##### Launch Training
275
+
276
+ Please first configure passwordless SSH login between machines following the [Configure Passwordless SSH Login Between Machines](#configure-passwordless-ssh-login-between-machines) section above.
277
+
278
+ Modify the following configuration at the beginning of `train_lf.sh`:
279
+
280
+ ```shell
281
+ export HOST_GPU_NUM=8
282
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
283
+ export IP_LIST=${IP_LIST:-"127.0.0.1"}
284
+ ```
285
+
286
+ Note: If the `IP_LIST` environment variable is not set, replace `IP_LIST` with the actual IP list. The format is:
287
+ ```
288
+ For a single IP:
289
+ IP_LIST=${ip_1}
290
+
291
+ For multiple IPs:
292
+ IP_LIST=${ip_1},${ip_2}
293
+
294
+ ```
295
+
296
+ Replace `${ip_1}` and `${ip_2}` with the actual IP addresses.
297
+
298
+ Then, on each machine, run `bash train_lf.sh` in the `train/llama_factory_support/` directory.
train/README_CN.md ADDED
@@ -0,0 +1,298 @@
1
+ <p align="left">
2
+ <a href="README.md">English</a> | Chinese
3
+ </p>
4
+
5
+ # Model Training
6
+
7
+ Hy3 preview provides the full model training workflow. This chapter explains how to process training data into the format required for training.
8
+
9
+ ## Training Data Format and Processing
10
+
11
+ **Hy3 preview supports both slow-thinking and fast-thinking modes. The model's default output is slow thinking; to make the model think fast, control it via the `reasoning_effort` parameter (valid values: `high`, `low`, `no_think`).**
12
+
13
+ Training data is processed into the messages format shown below. The default system prompt for both training and inference is empty and can be set according to your needs.
14
+
15
+ ```python
16
+ # Fast thinking pattern (no_think)
17
+ {"reasoning_effort": "no_think", "messages": [{"content": "你是一个有用的人工智能助手。\n现在的时间是2026-01-01 13:26:12 周四", "role": "system"}, {"content": "1+1=?", "role": "user"}, {"role": "assistant", "content": "1+1=2"}]}
18
+
19
+ # Slow thinking pattern (high)
20
+ {"reasoning_effort": "high", "messages": [{"content": "你是一个有用的人工智能助手。\n现在的时间是2026-01-01 13:26:12 周四", "role": "system"}, {"content": "1+1=?", "role": "user"}, {"role": "assistant", "content": "1+1=2", "reasoning_content": "1+1=2"}]}
21
+
22
+ from transformers import AutoTokenizer
23
+ tokenizer = AutoTokenizer.from_pretrained("./models", use_fast=False, trust_remote_code=True)
24
+ ids = tokenizer.apply_chat_template(messages, is_training=True)  # `messages` is the "messages" field of one sample
25
+ ```
26
+
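+ To eyeball a rendered sample, the token IDs can be decoded back to text. A minimal sketch, assuming the `tokenizer` and `ids` from the snippet above:
+
+ ```python
+ # Decoding exposes the special tokens (<|hy_User|>, <|hy_Assistant|>, <|hy_eos|>)
+ # so you can check that roles and think tags are rendered as expected.
+ print(tokenizer.decode(ids))
+ ```
+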
27
+ ## Weight Format Conversion
28
+
29
+ The original Hy3 preview checkpoint stores each expert separately. Before training, it must be converted to the fused-expert HuggingFace standard format (the expert weights within a layer are fused into a 3D tensor and the keys are renamed consistently), which speeds up loading and training. We provide the conversion script `convert_ckpt_to_outer.py` and the validation script `check_converted.py` in the `train/tools` directory.
30
+
31
+ ### Conversion
32
+
33
+ ```sh
34
+ python convert_ckpt_to_outer.py \
35
+ --input_dir <original_checkpoint_dir> \
36
+ --output_dir <output_dir> \
37
+ --workers 8
38
+ ```
39
+
40
+ **Parameters:**
41
+
42
+ - `--input_dir`: Path to the original checkpoint directory (required)
43
+ - `--output_dir`: Path for the converted checkpoint output directory (required)
44
+ - `--workers`: Number of parallel conversion processes, default 8 (optional)
45
+
46
+ The conversion script performs the following steps (a quick sanity-check sketch follows the list):
47
+ 1. Pre-scan `model.safetensors.index.json` to detect expert groups that span multiple shards
48
+ 2. Convert the weights shard by shard in parallel (key renaming + expert fusing)
49
+ 3. Post-process cross-shard expert groups (merging data from multiple shards)
50
+ 4. Copy `config.json`, the tokenizer, and other files
51
+ 5. Rebuild `model.safetensors.index.json`
52
+
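+ Before running the full validation script below, the rebuilt index can be given a quick structural sanity check in a few lines of Python. A minimal sketch (the directory is a placeholder for your `--output_dir`):
+
+ ```python
+ import json
+ import os
+
+ out_dir = "<converted_checkpoint_dir>"  # placeholder
+ with open(os.path.join(out_dir, "model.safetensors.index.json")) as f:
+     weight_map = json.load(f)["weight_map"]
+
+ shards = set(weight_map.values())
+ # After conversion, MoE layers expose fused 3D expert keys rather than per-expert keys.
+ fused = [k for k in weight_map if k.endswith(("experts.gate_up_proj", "experts.down_proj"))]
+ print(f"{len(weight_map)} keys across {len(shards)} shards; {len(fused)} fused expert tensors")
+ assert all(os.path.getsize(os.path.join(out_dir, s)) > 0 for s in shards), "empty shard file"
+ ```
+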
53
+ ### Validation
54
+
55
+ After conversion, it is recommended to verify the integrity of the result with the validation script:
56
+
57
+ ```sh
58
+ python check_converted.py <converted_checkpoint_dir> --spot-check 3
59
+ ```
60
+
61
+ **Parameters:**
62
+
63
+ - First positional argument: path to the converted checkpoint directory (required)
64
+ - `--spot-check`: Number of shard files to spot-check at random; tensors are loaded and checked for shape, dtype, NaN/Inf, etc. Default 3 (optional)
65
+
66
+ The validation script checks the following:
67
+ 1. Integrity of `config.json`
68
+ 2. That all expected keys are present in `model.safetensors.index.json` (including regular layers and MTP layers)
69
+ 3. That all referenced shard files exist and are non-empty
70
+ 4. Spot-checked shards for correct tensor shapes and dtypes and for NaN/Inf values
71
+ 5. Orphaned empty shard files (residue of cross-shard merging; safe to delete)
72
+
73
+ ## Quick Start
74
+
75
+ You can refer to the Quick Start documentation to get up and running quickly.
76
+
77
+ ## Model Training
78
+
79
+ ### Hardware Requirements
80
+
81
+ As tested, with `make_moe_param_leaf_module` disabled and ZeRO-3 with offload enabled, at a `max_seq_length` of 4096, fine-tuning with LoRA requires at least a single node with 8 GPUs (at least 80 GB of memory each).
82
+
83
+ Without LoRA (full fine-tuning), at least 4 machines with 32 GPUs in total (at least 80 GB of memory each) are required.
84
+
85
+ ### How to Launch
86
+
87
+ Reference: [HuggingFace Transformers Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer)
88
+
89
+ #### Single-Node Training
90
+ In the `train` directory, run:
91
+
92
+ ```sh
93
+ pip install -r requirements.txt
94
+ bash train.sh
95
+ ```
96
+
97
+ #### Multi-Node Training
98
+
99
+ To launch training with multiple machines, follow the steps below and make sure all machines are within the same cluster.
100
+
101
+ ##### Configure Passwordless SSH Login Between Machines
102
+
103
+ The following takes two machines as an example, identified by `${ip1}` and `${ip2}` respectively; all commands are executed inside the Docker container.
104
+
105
+ First, configure passwordless access between the containers; run this on each machine:
106
+
107
+ ```sh
108
+ ssh-keygen  # generate id_rsa and id_rsa.pub for passwordless login
109
+ ssh-keygen -t rsa -A  # generate /etc/ssh/ssh_host_rsa_key and ssh_host_ecdsa_key for the SSH listener below
110
+ /usr/sbin/sshd -p 36005 -o ListenAddress=0.0.0.0  # start the SSH listener
111
+ echo "Port 36005" > ~/.ssh/config  # use port 36005 for outgoing SSH connections
112
+ passwd root  # set a root password, otherwise the monitoring platform raises alerts
113
+ ```
114
+
115
+ Note: `36005` here is an example port; any port can be used, as long as it is **open** and **not occupied by another process**.
116
+
117
+ Next, inside the container on each machine, run:
118
+
119
+ ```sh
120
+ cat ~/.ssh/id_rsa.pub
121
+ ```
122
+
123
+ **Copy the printed SSH public key into the `~/.ssh/authorized_keys` file, one key per line, on every machine.** In the end, `~/.ssh/authorized_keys` should have identical content on every machine and contain the public keys of all machines.
124
+
125
+ Note that in multi-node training, every node must run identical code. Mounting a shared network disk is recommended; if that is not possible, manually copy the dataset, scripts, and code to the same directory on all machines.
126
+
127
+ ##### Launch Multi-Node Training
128
+
129
+ Once the preparation above is complete and the dependencies are installed (if not, run `pip install -r requirements.txt`), add the following configuration at the beginning of `train.sh`:
130
+
131
+ ```shell
132
+ export HOST_GPU_NUM=8
133
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
134
+ IP_LIST=${IP_LIST:-"127.0.0.1"}
135
+ ```
136
+
137
+ Note: If the `IP_LIST` environment variable is not set, replace `IP_LIST` with the actual IP list. The format is:
138
+ ```
139
+ For a single IP:
140
+ IP_LIST=${ip_1}
141
+
142
+ For multiple IPs:
143
+ IP_LIST=${ip_1},${ip_2}
144
+
145
+ ```
146
+
147
+ Replace `${ip_1}` and `${ip_2}` with the actual IP addresses.
148
+
149
+ Then, on the machine at `${ip1}`, run `bash train.sh` in the `train/` directory. Note that on the first launch you may see the following output:
150
+
151
+ ```ssh
152
+ The authenticity of host '[ip]:36005 ([ip]:36005)' can't be established.
153
+ ECDSA key fingerprint is xxxxxx.
154
+ ECDSA key fingerprint is MD5:xxxxxx.
155
+ Are you sure you want to continue connecting (yes/no)?
156
+ ```
157
+
158
+ Type `yes` to continue.
159
+
160
+ ##### Key Parameters
161
+
162
+ The key parameters in the script are as follows:
163
+
164
+ - `--deepspeed`: Path to a DeepSpeed configuration file. The `train` folder provides three default DeepSpeed configurations, `ds_zero2_no_offload.json`, `ds_zero3_no_offload.json`, and `ds_zero3_offload.json`, which require progressively less GPU memory
165
+ - `--model_name_or_path`: Path to the Hy3 preview HF pre-trained model weights; required for the model to load
166
+ - `--tokenizer_name_or_path`: Path to the tokenizer folder; required for the tokenizer to load
167
+ - `--train_data_file`: Path to the training file, which should be a jsonl file
168
+ - `--output_dir`: Output directory where logs, TensorBoard files, and weights are stored
169
+ - `--per_device_train_batch_size`: Batch size per GPU
170
+ - `--gradient_accumulation_steps`: Number of gradient accumulation steps; `per_device_train_batch_size * gradient_accumulation_steps * dp_size` equals the global batch size
171
+ - `--max_steps`: Total number of training steps
172
+ - `--save_steps`: Number of steps between checkpoint saves
173
+ - `--use_lora`: Whether to train with LoRA; also accepts the `--lora_rank`, `--lora_alpha`, and `--lora_dropout` parameters. By default LoRA is applied to the "q_proj", "k_proj", "v_proj", and "o_proj" modules; modify the code if you need to change this. Note: **when training with LoRA, only the LoRA weights are saved, not the base model weights**; to merge LoRA weights, see the "LoRA Weight Merging" section below
174
+ - `--make_moe_param_leaf_module`: When training an MoE model with ZeRO-3, treat the MoE module as a leaf module, i.e. its parameters are not partitioned by ZeRO-3. This option is expected to significantly increase GPU memory usage
175
+ - `--gradient_checkpointing`: Enable gradient checkpointing
176
+ - `--train_attention_params_only`: Whether to train only the attention parameters
177
+ - `--learning_rate`: Maximum learning rate during training
178
+ - `--min_lr`: Minimum learning rate during training
179
+ - `--use_flash_attn`: Enable FlashAttention to accelerate training
180
+
181
+ **Notes:**
182
+
183
+ - To resume training from a checkpoint saved mid-run rather than load pre-trained weights, set `--resume_from_checkpoint` to the path of the previously saved checkpoint and do not set `--model_name_or_path`; the latter loads only the weights, not the training state
184
+ - When resuming from a checkpoint, the loss may deviate slightly; this is normal randomness introduced by non-deterministic algorithms. Reference: [HuggingFace Transformers Trainer Randomness](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#randomness)
185
+ - When `--model_name_or_path` points to a valid model, all model configuration parameters are ignored
186
+ - Samples within a batch are padded to the length of the longest sample in the batch, while each individual sample is capped at `max_seq_length`; any excess is truncated
187
+ - If you see a warning about bias weights not being loaded, you can ignore it; Hunyuan-Large does not use bias
188
+
189
+ #### What if GPU Memory Is Insufficient?
190
+
191
+ Reference: [DeepSpeed Configuration](https://www.deepspeed.ai/docs/config-json/)
192
+
193
+ You can try modifying the DeepSpeed configuration by removing the `auto` value from the following parameters and reducing them:
194
+
195
+ - `stage3_param_persistence_threshold`
196
+ - `stage3_prefetch_bucket_size`
197
+ - `stage3_max_reuse_distance`
198
+
199
+ #### LoRA Weight Merging
200
+
201
+ LoRA weights saved during training cannot be merged into the ZeRO-3 model at runtime, because ZeRO-3 partitions model weights across data-parallel ranks. To merge LoRA weights into the base model, do it offline to obtain a merged weight file: run `merge_lora_weight.sh` to merge the LoRA and base model weights. Its parameters are:
202
+
203
+ - `--base_model_path`: Directory of the base model weights
204
+ - `--adapter_model_path`: Directory of the LoRA weights
205
+ - `--output_path`: Directory to save the merged weights
206
+ - `--save_dtype`: Data type for saving the merged weights; options are: fp16, bf16, fp32
207
+
208
+ #### LLaMA-Factory Support
209
+
210
+ If you are familiar with LLaMA-Factory, you can use it for fine-tuning. The scripts, code, and configuration files are archived in the `./train/llama_factory_support` directory. Unless otherwise noted, all files mentioned below are located in this directory.
211
+
212
+ ##### Installation
213
+
214
+ You can install LLaMA-Factory by downloading the source code from https://github.com/hiyouga/LLaMA-Factory/tree/main and following the instructions on the website.
215
+
216
+ ##### Configuration Files
217
+
218
+ We provide sample LLaMA-Factory training configuration files, `hy_v3_lora_sft.yaml` and `hy_v3_full_sft.yaml`, corresponding to LoRA training and full fine-tuning respectively.
219
+
220
+ Key parameters in the configuration files are as follows:
221
+
222
+ **Model:**
223
+
224
+ - `model_name_or_path`: Path to the Hy3 preview HF format pre-trained model weights
225
+ - `trust_remote_code`: Whether to trust remote code; Hy3 preview requires this to be set to `true`
226
+
227
+ **Training Method:**
228
+
229
+ - `stage`: Training stage, currently `sft` (supervised fine-tuning)
230
+ - `finetuning_type`: Fine-tuning type, either `full` (full fine-tuning) or `lora` (LoRA fine-tuning)
231
+ - `deepspeed`: DeepSpeed configuration file path; `ds_zero3_offload.json` is recommended for full fine-tuning and `ds_zero2_offload_lora.json` for LoRA fine-tuning
232
+
233
+ **LoRA Parameters (only effective during LoRA fine-tuning):**
234
+
235
+ - `lora_rank`: LoRA rank, default `64`
236
+ - `lora_alpha`: LoRA alpha coefficient, default `128`
237
+ - `lora_dropout`: LoRA dropout ratio, default `0.05`
238
+ - `lora_target`: Target modules for LoRA, default `q_proj,k_proj,v_proj,o_proj`
239
+
240
+ **Dataset:**
241
+
242
+ - `dataset_dir`: Dataset directory path
243
+ - `dataset`: Dataset name, must be registered in `dataset_info.json` under `dataset_dir`
244
+ - `template`: Chat template; Hy3 preview uses `hy_v3`
245
+ - `cutoff_len`: Maximum sequence length; sequences exceeding this are truncated. For full fine-tuning it can be set to `262144` (262K); for LoRA fine-tuning `8192` is recommended to save memory
246
+ - `max_samples`: Maximum number of samples per dataset
247
+ - `overwrite_cache`: Whether to overwrite cached preprocessed datasets
248
+
249
+ **Output:**
250
+
251
+ - `output_dir`: Output directory where logs, TensorBoard files, and weights are stored
252
+ - `logging_steps`: Number of steps between logging
253
+ - `save_steps`: Number of steps between saving checkpoints
254
+ - `plot_loss`: Whether to plot the training loss curve
255
+ - `overwrite_output_dir`: Whether to overwrite the existing output directory
256
+ - `save_only_model`: Whether to save only model weights (excluding optimizer states, etc.)
257
+ - `report_to`: Logging tool, options: `none`, `wandb`, `tensorboard`, `swanlab`, `mlflow`
258
+
259
+ **Training Hyperparameters:**
260
+
261
+ - `per_device_train_batch_size`: Batch size per GPU
262
+ - `gradient_accumulation_steps`: Gradient accumulation steps; `per_device_train_batch_size * gradient_accumulation_steps * dp_size` equals the global batch size
263
+ - `learning_rate`: Maximum learning rate; `1.0e-5` recommended for full fine-tuning, `2.0e-4` for LoRA fine-tuning
264
+ - `num_train_epochs`: Number of training epochs
265
+ - `lr_scheduler_type`: Learning rate scheduler type; `cosine_with_min_lr` is recommended
266
+ - `lr_scheduler_kwargs.min_lr_rate`: Ratio of minimum to maximum learning rate; e.g., `0.1` means the minimum learning rate is 10% of the maximum
267
+ - `warmup_ratio`: Proportion of total training steps used for warmup
268
+ - `bf16`: Whether to use BFloat16 mixed precision training
269
+ - `gradient_checkpointing`: Whether to enable gradient checkpointing to save memory
270
+ - `ddp_timeout`: Distributed training timeout (seconds)
271
+ - `flash_attn`: Attention implementation; `fa2` (FlashAttention-2) is recommended, `sdpa` is also available; using `fa2` requires the flash-attn package
272
+ - `resume_from_checkpoint`: Resume training from a specified checkpoint path; set to `null` to start from scratch
273
+
274
+ ##### Launch Training
275
+
276
+ First configure passwordless SSH login between machines following the [Configure Passwordless SSH Login Between Machines](#configure-passwordless-ssh-login-between-machines) section above.
277
+
278
+ Modify the following configuration at the beginning of `train_lf.sh`:
279
+
280
+ ```shell
281
+ export HOST_GPU_NUM=8
282
+ # IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
283
+ export IP_LIST=${IP_LIST:-"127.0.0.1"}
284
+ ```
285
+
286
+ Note: If the `IP_LIST` environment variable is not set, replace `IP_LIST` with the actual IP list. The format is:
287
+ ```
288
+ For a single IP:
289
+ IP_LIST=${ip_1}
290
+
291
+ For multiple IPs:
292
+ IP_LIST=${ip_1},${ip_2}
293
+
294
+ ```
295
+
296
+ Replace `${ip_1}` and `${ip_2}` with the actual IP addresses.
297
+
298
+ Then, on each machine, run `bash train_lf.sh` in the `train/llama_factory_support/` directory.
train/ds_zero2_no_offload.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 100,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1e-10
9
+ },
10
+ "zero_optimization": {
11
+ "stage": 2,
12
+ "allgather_partitions": true,
13
+ "allgather_bucket_size": 1e8,
14
+ "overlap_comm": true,
15
+ "reduce_scatter": true,
16
+ "reduce_bucket_size": 1e8,
17
+ "contiguous_gradients": true
18
+ },
19
+ "gradient_accumulation_steps": "auto",
20
+ "gradient_clipping": "auto",
21
+ "steps_per_print": 10,
22
+ "train_batch_size": "auto",
23
+ "train_micro_batch_size_per_gpu": "auto",
24
+ "wall_clock_breakdown": false
25
+ }
train/ds_zero3_no_offload.json ADDED
@@ -0,0 +1,41 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 3,
16
+ "offload_optimizer": {
17
+ "device": "none",
18
+ "pin_memory": true
19
+ },
20
+ "offload_param": {
21
+ "device": "none",
22
+ "pin_memory": true
23
+ },
24
+ "overlap_comm": true,
25
+ "contiguous_gradients": true,
26
+ "sub_group_size": 1e9,
27
+ "reduce_bucket_size": "auto",
28
+ "stage3_prefetch_bucket_size": "auto",
29
+ "stage3_param_persistence_threshold": "auto",
30
+ "stage3_max_live_parameters": 1e9,
31
+ "stage3_max_reuse_distance": 1e9,
32
+ "stage3_gather_16bit_weights_on_model_save": true
33
+ },
34
+
35
+ "gradient_accumulation_steps": "auto",
36
+ "gradient_clipping": "auto",
37
+ "steps_per_print": 10,
38
+ "train_batch_size": "auto",
39
+ "train_micro_batch_size_per_gpu": "auto",
40
+ "wall_clock_breakdown": false
41
+ }
train/ds_zero3_offload.json ADDED
@@ -0,0 +1,41 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 3,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "offload_param": {
21
+ "device": "cpu",
22
+ "pin_memory": true
23
+ },
24
+ "overlap_comm": true,
25
+ "contiguous_gradients": true,
26
+ "sub_group_size": 1e9,
27
+ "reduce_bucket_size": "auto",
28
+ "stage3_prefetch_bucket_size": "auto",
29
+ "stage3_param_persistence_threshold": "auto",
30
+ "stage3_max_live_parameters": 1e9,
31
+ "stage3_max_reuse_distance": 1e9,
32
+ "stage3_gather_16bit_weights_on_model_save": false
33
+ },
34
+
35
+ "gradient_accumulation_steps": "auto",
36
+ "gradient_clipping": "auto",
37
+ "steps_per_print": 10,
38
+ "train_batch_size": "auto",
39
+ "train_micro_batch_size_per_gpu": "auto",
40
+ "wall_clock_breakdown": false
41
+ }
train/ds_zero3_offload_no_auto.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": false,
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": true
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 3,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "overlap_comm": true,
21
+ "contiguous_gradients": true,
22
+ "sub_group_size": 1e9,
23
+ "reduce_bucket_size": 1e8,
24
+ "stage3_prefetch_bucket_size": 1e8,
25
+ "stage3_param_persistence_threshold": 0,
26
+ "stage3_max_live_parameters": 1e8,
27
+ "stage3_max_reuse_distance": 1e8,
28
+ "stage3_gather_16bit_weights_on_model_save": true
29
+ },
30
+
31
+ "gradient_accumulation_steps": 1,
32
+ "gradient_clipping": "auto",
33
+ "steps_per_print": 10,
34
+ "train_micro_batch_size_per_gpu": 1,
35
+ "wall_clock_breakdown": false
36
+ }
train/example_data.jsonl ADDED
@@ -0,0 +1,8 @@
1
+ {"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "能否帮我撰写一个关于环保的议论文开头部分?\n主题:环保\n立场:重要性和我们的责任"}, {"role": "assistant", "content": "在我们的生活中,环保已经成为一个不可忽视的议题。随着科技的发展和人口的增长,我们的环境正在遭受前所未有的压力。森林被砍伐,河流被污染,动植物种类正在消失,全球气候正在发生变化。这些都是我们无法忽视的事实。然而,环保并不仅仅是一个环境问题,它更是一个关乎我们生存和发展的问题。环保的重要性已经超越了任何一个单一的领域,它关乎我们的健康、经济、社会和文化。因此,我们每一个人都有责任和义务去保护我们的环境,这不仅是为了我们自己,也是为了我们的后代和整个地球。"}]}
2
+ {"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "请描述小红书APP的用户画像,并用表格的形式展示结果。"}, {"role": "assistant", "content": "| 用户特征 | 描述 |\n| -------- | ------------------------------ |\n| 年龄段 | 18-35岁 |\n| 性别 | 女性用户较多(约占70%) |\n| 地域 | 一线城市和二线城市用户占比较高 |\n| 职业 | 白领、创业者、自由职业者等 |\n| 爱好 | 美容、时尚、旅行、美食、健康、生活方式等 |\n| 消费能力 | 中高消费能力用户 |\n| 使用目的 | 阅读内容,分享购物心得,结识相似兴趣的朋友 |\n\n注:以上用户画像仅供参考,实际情况可能会有所不同。"}]}
3
+ {"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "请根据以下两个句子的描述,判断它们是否对同一情况进行了描述。句子1:“外面下着大雨,人们纷纷打起了伞。” 句子2:“由于雨量较大,人们都撑起了雨伞。”"}, {"role": "assistant", "content": "是的,这两个句子描述的是同一情况。"}]}
4
+ {"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "你认为游戏开发者应该如何处理游戏中的社交网络集成?"}, {"role": "assistant", "content": "游戏开发者在处理游戏中的社交网络集成时可以考虑以下几点:\n\n1. 以用户为中心:确保集成社交网络功能是为了提高用户体验,让玩家更容易地与朋友互动、分享成就和游戏进度。\n\n2. 保护隐私:在集成社交网络时,务必保护用户的隐私。在收集和使用用户数据时,需要遵循相关法规和用户协议,同时提供明确的隐私政策。\n\n3. 适度推送:在推送游戏动态、邀请和通知时要适度,避免过度打扰用户或侵犯用户隐私。可提供自定义设置,让玩家选择性地接收通知。\n\n4. 社交功能设计:在游戏中加入社交功能,如聊天、组队、比赛等,以增进玩家之间的互动。同时,确保这些功能易用且不影响游戏体验。\n\n5. 社区建设:鼓励玩家在社交平台上建立游戏相关的社区,分享游戏心得、攻略和经验。这将有助于增加游戏的知名度和吸引力。\n\n6. 奖励机制:设定奖励机制,鼓励玩家在社交网络上分享游戏内容、邀请好友等。但要确保奖励机制不会破坏游戏平衡。\n\n7. 持续更新和优化:根据玩家的反馈,不断优化社交网络集成功能,确保其与游戏的其他方面保持一致。\n\n总之,游戏开发者应该在保护用户隐私的前提下,提供便捷的社交网络集成功能,以提高游戏的互动性和吸引力。同时,要关注玩家的反馈,不断优化功能。"}]}
5
+ {"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "将这段过去进行时的文字改为现在完成进行时:“她们在公园里跳舞。”"}, {"role": "assistant", "content": "她们已经在公园里跳舞了。"}]}
6
+ {"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "在godaddy上买了一台服务器怎么更改root账户"}, {"role": "assistant", "content": "在GoDaddy上购买的服务器可能是Linux系统的VPS或者是专用服务器。要更改root账户,你可以通过SSH连接到服务器并修改密码。以下是具体步骤:\n\n1. 打开终端(Terminal)或使用SSH客户端(如PuTTY)。\n2. 使用SSH连接到服务器。在终端中输入以下命令:\n\n ```\n ssh root@your_server_ip_address\n ```\n\n 将`your_server_ip_address`替换为你的服务器IP地址。如果需要,你可能还需要输入服务器的SSH端口号。\n \n3. 当系统提示输入密码时,输入你的当前root密码。如果连接成功,你将看到服务器的命令提示符。\n\n4. 要更改root密码,请输入以下命令:\n\n ```\n passwd\n ```\n\n5. 系统将提示你输入新的UNIX密码。输入你想要设置的新密码并按回车键。\n\n6. 系统将要求你重新输入新密码以进行确认。再次输入新密码并按回车键。\n\n7. 如果密码更改成功,系统将显示“密码已成功更新”的消息。\n\n8. 使用`exit`命令断开与服务器的连接。\n\n现在你已经成功更改了root账户的密码。下次使用SSH连接到服务器时,请使用新的密码进行身份验证。"}]}
7
+ {"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "PHP算法逻辑:如何计算年龄?"}, {"role": "assistant", "content": "要计算年龄,您可以使用以下PHP算法。此算法考虑了生日、当前日期以及闰年的计算。\n\n```php\nfunction calculateAge($birthdate) {\n // 格式为YYYY-MM-DD的生日\n $birthdate = strtotime($birthdate);\n \n // 获取当前日期\n $now = time();\n\n // 计算年龄\n $age = date('Y', $now) - date('Y', $birthdate);\n \n // 如果还没有过生日,减去一年\n if (date('m-d', $now) < date('m-d', $birthdate)) {\n $age--;\n }\n\n return $age;\n}\n\n// 使用\n$birthdate = \"2000-01-01\";\necho \"年龄是: \" . calculateAge($birthdate) . \" 岁\";\n```\n\n这个函数将生日作为输入,输出用户的年龄。首先,它将生日转换为时间戳,然后获取当前的时间戳。接下来,它计算年份之间的差异。如果当前日期仍未超过这一年的生日,这个函数将年龄减少1。\n\n在使用函数时,只需传入出生日期即可,函数会返回相应的年龄。"}]}
8
+ {"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "给我解释一下什么是胰岛素抵抗? 请详细描述胰岛素抵抗的症状、原因与治疗方法。 胰岛素抵抗会影响一个人的日常生活吗? 有哪些生活习惯改变可以帮助改善胰岛素抵抗?"}, {"role": "assistant", "content": "胰岛素抵抗是指身体对胰岛素的敏感性降低,胰岛素不能有效地发挥作用,导致血糖无法正常进入身体细胞。胰岛素是一种激素,由胰腺产生,其主要作用是帮助身体细胞吸收血糖并将其转化为能量。当胰岛素抵抗发生时,胰腺会分泌更多的胰岛素以维持正常的血糖水平,但长期的胰岛素抵抗可能导致2型糖尿病和其他健康问题的发生。\n\n胰岛素抵抗的症状可能包括:\n\n1. 高血糖:由于胰岛素不能正常工作,血糖无法进入细胞,导致血糖升高。\n2. 疲劳:由于身体无法有效利用血糖产生能量,患者可能会感到疲劳无力。\n3. 肥胖:胰岛素抵抗可能导致过量的胰岛素分泌,使得脂肪存储增加,进而导致体重增加。\n4. 高血压:胰岛素抵抗与高血压之间存在关联,可能导致血压升高。\n5. 黑棘皮病:皮肤出现黑色、粗糙的皮疹,通常出现在颈部、腋窝等皮肤皱褶处。\n\n胰岛素抵抗的原因可能包括:\n\n1. 遗传因素:遗传易感基因可能会增加一个人患胰岛素抵抗的风险。\n2. 肥胖:过重或肥胖会增加患胰岛素抵抗的风险。\n3. 缺乏运动:缺乏运动可能会导致肌肉对胰岛素的敏感性降低。\n4. 高糖饮食:长期摄入高糖食品可能导致胰岛素抵抗。\n5. 睡眠不足:长期睡眠不足可能会影响胰岛素的正常功能。\n\n治疗胰岛素抵抗的方法主要包括生活方式的改变和药物治疗:\n\n1. 减轻体重:通过改善饮食和增加运动来减轻体重,有助于提高身体对胰岛素的敏感性。\n2. 增加运动:增加有氧运动和力量训练,有助于提高肌肉对胰岛素的敏感性。\n3. 改善饮食:摄入低糖、高纤维、富含全谷物和蔬菜水果的食物,有助于控制血糖和减轻体重。\n4. 药物治疗:在必要时,医生可能会开处降糖药物,如二甲双胍,以帮助提高胰岛素敏感性。\n\n胰岛素抵抗会影响一个人的日常生活,包括疲劳、体重增加、血糖波动等问题。通过改变生活习惯,如减轻体重、增加运动、改善饮食和保证充足睡眠,可以帮助改善胰岛素抵抗,降低患2型糖尿病的风险。"}]}
train/llama_factory_support/dataset_info.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "hy_v3_demo": {
3
+ "file_name": "../example_data.jsonl",
4
+ "formatting": "sharegpt",
5
+ "columns": {
6
+ "messages": "messages"
7
+ },
8
+ "tags": {
9
+ "role_tag": "role",
10
+ "content_tag": "content",
11
+ "user_tag": "user",
12
+ "assistant_tag": "assistant",
13
+ "system_tag": "system"
14
+ }
15
+ }
16
+ }
train/llama_factory_support/ds_zero2_offload_lora.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": false,
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": true
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 2,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "allgather_partitions": true,
21
+ "allgather_bucket_size": 5e8,
22
+ "overlap_comm": true,
23
+ "reduce_scatter": true,
24
+ "reduce_bucket_size": 5e8,
25
+ "contiguous_gradients": true
26
+ },
27
+
28
+ "gradient_accumulation_steps": "auto",
29
+ "gradient_clipping": "auto",
30
+ "steps_per_print": 10,
31
+ "train_micro_batch_size_per_gpu": "auto",
32
+ "wall_clock_breakdown": false
33
+ }
train/llama_factory_support/ds_zero3_offload.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": false,
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": true
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 3,
16
+ "offload_optimizer": {
17
+ "device": "cpu",
18
+ "pin_memory": true
19
+ },
20
+ "offload_param": {
21
+ "device": "cpu",
22
+ "pin_memory": true
23
+ },
24
+ "overlap_comm": true,
25
+ "contiguous_gradients": true,
26
+ "sub_group_size": 1e9,
27
+ "reduce_bucket_size": 1e8,
28
+ "stage3_prefetch_bucket_size": 5e8,
29
+ "stage3_param_persistence_threshold": 1e6,
30
+ "stage3_max_live_parameters": 5e9,
31
+ "stage3_max_reuse_distance": 5e9,
32
+ "stage3_gather_16bit_weights_on_model_save": false
33
+ },
34
+
35
+ "gradient_accumulation_steps": "auto",
36
+ "gradient_clipping": "auto",
37
+ "steps_per_print": 10,
38
+ "train_micro_batch_size_per_gpu": "auto",
39
+ "wall_clock_breakdown": false
40
+ }
train/llama_factory_support/hy_v3_full_sft.yaml ADDED
@@ -0,0 +1,57 @@
1
+ ### model
2
+ model_name_or_path: ../hf
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: full
9
+ deepspeed: ds_zero3_offload.json
10
+
11
+ ### dataset
12
+ dataset_dir: .
13
+ dataset: hy_v3_demo
14
+ template: hy_v3
15
+ cutoff_len: 4096 # HYV3 supports 262k context length
16
+ max_samples: 1000
17
+ overwrite_cache: true
18
+
19
+ ### output
20
+ output_dir: saves/hy_v3/full/sft
21
+ logging_steps: 1
22
+ save_steps: 500
23
+ plot_loss: true
24
+ overwrite_output_dir: true
25
+ save_only_model: false
26
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
27
+
28
+ ### train
29
+ per_device_train_batch_size: 1
30
+ gradient_accumulation_steps: 1
31
+ learning_rate: 1.0e-5
32
+ num_train_epochs: 100.0
33
+ lr_scheduler_type: cosine_with_min_lr
34
+ lr_scheduler_kwargs:
35
+ min_lr_rate: 0.1 # min_lr / learning_rate = 1e-6 / 1e-5 = 0.1
36
+ warmup_ratio: 0.1
37
+ bf16: true
38
+ gradient_checkpointing: true
39
+ gradient_checkpointing_kwargs:
40
+ use_reentrant: true
41
+ ddp_timeout: 180000000
42
+ flash_attn: fa2
43
+ resume_from_checkpoint: null
44
+
45
+ ### eval
46
+ # eval_dataset: alpaca_en_demo
47
+ # val_size: 0.1
48
+ # per_device_eval_batch_size: 1
49
+ # eval_strategy: steps
50
+ # eval_steps: 500
51
+
52
+ ### HYV3 specific parameters (MoE architecture)
53
+ # moe_num_experts: 192
54
+ # moe_topk: 8
55
+ # moe_intermediate_size: 1536
56
+ # moe_layer_num_skipped: 1
57
+ # output_router_logits: true
train/llama_factory_support/hy_v3_lora_sft.yaml ADDED
@@ -0,0 +1,56 @@
1
+ ### model
2
+ model_name_or_path: ../hf
3
+ trust_remote_code: true
4
+
5
+ ### method
6
+ stage: sft
7
+ do_train: true
8
+ finetuning_type: lora
9
+ deepspeed: ds_zero2_offload_lora.json
10
+
11
+ ### LoRA parameters
12
+ lora_rank: 64
13
+ lora_alpha: 128
14
+ lora_dropout: 0.05
15
+ lora_target: q_proj,k_proj,v_proj,o_proj
16
+
17
+ ### dataset
18
+ dataset_dir: .
19
+ dataset: hy_v3_demo
20
+ template: hy_v3
21
+ cutoff_len: 4096 # Use shorter context for LoRA to save memory; increase if needed
22
+ max_samples: 1000
23
+ overwrite_cache: true
24
+
25
+ ### output
26
+ output_dir: saves/hy_v3/lora/sft
27
+ logging_steps: 10
28
+ save_steps: 500
29
+ plot_loss: true
30
+ overwrite_output_dir: true
31
+ save_only_model: false
32
+ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
33
+
34
+ ### train
35
+ per_device_train_batch_size: 1
36
+ gradient_accumulation_steps: 4
37
+ learning_rate: 2.0e-4
38
+ num_train_epochs: 3.0
39
+ lr_scheduler_type: cosine_with_min_lr
40
+ lr_scheduler_kwargs:
41
+ min_lr_rate: 0.1 # min_lr / learning_rate = 2e-5 / 2e-4 = 0.1
42
+ warmup_ratio: 0.1
43
+ bf16: true
44
+ gradient_checkpointing: true
45
+ gradient_checkpointing_kwargs:
46
+ use_reentrant: true
47
+ ddp_timeout: 180000000
48
+ flash_attn: fa2
49
+ resume_from_checkpoint: null
50
+
51
+ ### eval
52
+ # eval_dataset: alpaca_en_demo
53
+ # val_size: 0.1
54
+ # per_device_eval_batch_size: 1
55
+ # eval_strategy: steps
56
+ # eval_steps: 500
train/llama_factory_support/hy_v3_patches.py ADDED
@@ -0,0 +1,384 @@
1
+ """
2
+ HYV3 monkey-patches for LLaMA Factory + DeepSpeed ZeRO-3 training.
3
+
4
+ This module applies all necessary runtime patches so that HYV3 (MoE)
5
+ can be trained correctly under LLaMA Factory with DeepSpeed ZeRO-3.
6
+
7
+ Usage:
8
+ Import this module **before** calling `llamafactory-cli train`:
9
+
10
+ import hy_v3_patches # applies patches on import
11
+ # ... then start training
12
+
13
+ Or add to the LLaMA Factory YAML via a custom entry-point wrapper.
14
+
15
+ Patches applied:
16
+ 1. ZeRO-3 buffer loading (e_score_correction_bias etc.)
17
+ Key renaming + expert fusing is now done offline by convert_ckpt_to_outer.py
18
+ 2. Router forward dtype fix (MoE router gate dtype alignment for ZeRO-3)
19
+ 3. gradient_checkpointing (use_reentrant=True for ZeRO-3)
20
+ 4. Tokenizer file copy (CustomSaveCallback)
21
+ 5. (Removed) -- was per-expert ModuleList, now using native 3D Parameters
22
+ 6. Save-time reverse key rename + 3D -> per-expert unfuse
23
+ """
24
+
25
+ import os
26
+ import re
27
+ import logging
28
+ import shutil
29
+ from typing import Optional
30
+
31
+ import torch
32
+ import torch.nn as _nn
33
+ import torch.nn.functional as _F
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ # ============================================================================
38
+ # Patch 1: Buffer loading for ZeRO-3
39
+ #
40
+ # The checkpoint is expected to be in outer format (pre-converted by
41
+ # convert_ckpt_to_outer.py). Key renaming and expert fusing are no longer
42
+ # needed at load time.
43
+ #
44
+ # However, ZeRO-3's _load_state_dict_into_zero3_model only handles
45
+ # named_parameters, not named_buffers (e.g. e_score_correction_bias).
46
+ # We still need to manually load buffers from the state_dict.
47
+ # ============================================================================
48
+
49
+
50
+ def _apply_buffer_loading_patch():
51
+ """Patch the DeepSpeed ZeRO-3 state_dict loader to manually load buffers.
52
+
53
+ ZeRO-3's _load_state_dict_into_zero3_model only handles named_parameters.
54
+ Buffers like e_score_correction_bias must be loaded manually.
55
+ """
56
+ try:
57
+ from transformers.integrations.deepspeed import (
58
+ _load_state_dict_into_zero3_model as _orig_load_zero3,
59
+ )
60
+ import transformers.integrations.deepspeed as _ds_mod
61
+ import transformers.modeling_utils as _mu_mod
62
+ except ImportError:
63
+ logger.warning(
64
+ "Could not import transformers.integrations.deepspeed; "
65
+ "buffer loading patch NOT applied."
66
+ )
67
+ return
68
+
69
+ def _patched_load_zero3(model_to_load, state_dict, *args, **kwargs):
70
+ # Load parameters via original ZeRO-3 loader
71
+ result = _orig_load_zero3(model_to_load, state_dict, *args, **kwargs)
72
+
73
+ # Manually load buffers (e.g. e_score_correction_bias) from state_dict.
74
+ # ZeRO-3's loader only handles named_parameters, not named_buffers.
75
+ buffers_loaded = 0
76
+ for name, buf in model_to_load.named_buffers():
77
+ if name in state_dict:
78
+ src_tensor = state_dict[name]
79
+ if isinstance(src_tensor, torch.Tensor):
80
+ buf.data.copy_(src_tensor.to(buf.dtype))
81
+ buffers_loaded += 1
82
+ if isinstance(result, tuple) and len(result) >= 2:
83
+ if isinstance(result[1], set):
84
+ result[1].discard(name)
85
+ if buffers_loaded > 0:
86
+ logger.info(
87
+ "HYV3 Patch 1: Manually loaded %d buffers into model.",
88
+ buffers_loaded
89
+ )
90
+
91
+ return result
92
+
93
+ _ds_mod._load_state_dict_into_zero3_model = _patched_load_zero3
94
+ _mu_mod._load_state_dict_into_zero3_model = _patched_load_zero3
95
+ logger.info(
96
+ "HYV3 patch applied: ZeRO-3 buffer loading for e_score_correction_bias."
97
+ )
98
+
99
+ # ============================================================================
100
+ # Patch 2: Router forward dtype alignment for ZeRO-3
101
+ #
102
+ # The HYV3 MoE HYV3TopKRouter.forward() calls F.linear with .float().
103
+ # Under DeepSpeed ZeRO-3, F.linear is replaced by zero3_linear_wrap which
104
+ # internally does input.matmul(weight.t()) WITHOUT aligning dtypes.
105
+ # When ZeRO-3 stores the gate weight in bf16, the fp32 input causes a
106
+ # dtype mismatch RuntimeError.
107
+ #
108
+ # Fix: monkey-patch HYV3TopKRouter.forward to cast input to
109
+ # self.weight.dtype before F.linear, then cast the output back to float32.
110
+ # ============================================================================
111
+
112
+ _router_patch_applied = False
113
+
114
+ def _apply_router_dtype_patch():
115
+ """Monkey-patch HYV3TopKRouter.forward to align gate input dtype with weight dtype."""
116
+ global _router_patch_applied
117
+ if _router_patch_applied:
118
+ return
119
+
120
+ try:
121
+ from transformers.models.hy_v3.modeling_hy_v3 import HYV3TopKRouter
122
+ except ImportError:
123
+ try:
124
+ from transformers.hy_v3.modeling_hy_v3 import HYV3TopKRouter
125
+ except ImportError:
126
+ logger.warning(
127
+ "Could not import HYV3TopKRouter; "
128
+ "router dtype patch NOT applied."
129
+ )
130
+ return
131
+
132
+ def _patched_router_forward(
133
+ self,
134
+ hidden_states: torch.Tensor,
135
+ e_score_correction_bias: torch.Tensor,
136
+ ) -> tuple:
137
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
138
+ # Cast input to match weight dtype (bf16 under ZeRO-3)
139
+ # instead of hard-coding float32, to avoid matmul dtype mismatch.
140
+ weight_dtype = self.weight.dtype
141
+ router_logits = _F.linear(hidden_states.to(weight_dtype), self.weight.to(weight_dtype))
142
+ # Cast back to float32 for numerically stable sigmoid
143
+ router_logits = router_logits.to(torch.float32)
144
+ routing_weights = torch.sigmoid(router_logits)
145
+
146
+ scores_for_choice = routing_weights + e_score_correction_bias
147
+ _, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False)
148
+ top_k_weights = routing_weights.gather(1, top_k_index)
149
+
150
+ top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-20)
151
+ top_k_weights = top_k_weights * self.router_scaling_factor
152
+
153
+ return router_logits, top_k_weights, top_k_index
154
+
155
+ HYV3TopKRouter.forward = _patched_router_forward
156
+ _router_patch_applied = True
157
+ logger.info("HYV3 patch applied: HYV3TopKRouter.forward dtype alignment for ZeRO-3.")
158
+
159
+ # ============================================================================
160
+ # Patch 3: gradient_checkpointing use_reentrant=True
161
+ #
162
+ # PyTorch's torch.utils.checkpoint with use_reentrant=False (the default
163
+ # in transformers) performs strict metadata checks on recomputed tensors.
164
+ # Under ZeRO-3, parameters are all-gathered during the first forward pass
165
+ # but may be partitioned back when the checkpoint recomputes, causing a
166
+ # CheckpointError. Setting use_reentrant=True avoids this.
167
+ #
168
+ # This is applied via a Trainer callback that modifies training_args
169
+ # before training starts.
170
+ # ============================================================================
171
+
172
+ # ============================================================================
173
+ # Patch 4: Tokenizer file copy callback
174
+ #
175
+ # Ensures each checkpoint directory is self-contained for inference by
176
+ # copying all tokenizer-related files from the original tokenizer path.
177
+ # ============================================================================
178
+
179
+ # Tokenizer files that should be copied to each checkpoint
180
+ _TOKENIZER_FILES = [
181
+ "generation_config.json",
182
+ "hy.tiktoken",
183
+ "tokenizer_config.json",
184
+ "tokenization_hy.py",
185
+ "tokenizer.json",
186
+ "special_tokens_map.json",
187
+ "chat_template.jinja",
188
+ ]
189
+
190
+ def _copy_tokenizer_to_checkpoint(tokenizer_dir: str, checkpoint_dir: str):
191
+ """Copy tokenizer files from tokenizer_dir to checkpoint_dir."""
192
+ for fname in _TOKENIZER_FILES:
193
+ src = os.path.join(tokenizer_dir, fname)
194
+ if os.path.isfile(src):
195
+ shutil.copy(src, os.path.join(checkpoint_dir, fname))
196
+
197
+ # ============================================================================
198
+ # Patch 6: Save-time reverse key rename + 3D -> per-expert unfuse
199
+ #
200
+ # When saving checkpoints, the model state_dict uses:
201
+ # - 3D fused experts (experts.gate_up_proj, experts.down_proj)
202
+ # - New naming (mlp.gate, mlp.e_score_correction_bias, mlp.shared_experts)
203
+ #
204
+ # We need to reverse both for old checkpoint compatibility:
205
+ # - mlp.gate. -> mlp.router.gate.
206
+ # - mlp.e_score_correction_bias -> mlp.expert_bias
207
+ # - mlp.shared_experts. -> mlp.shared_mlp.
208
+ # - experts.gate_up_proj -> experts.{N}.gate_proj.weight + experts.{N}.up_proj.weight
209
+ # - experts.down_proj -> experts.{N}.down_proj.weight
210
+ # ============================================================================
211
+
212
+ # Reverse mapping: new model name -> old checkpoint name
213
+ _SAVE_KEY_RENAMES = [
214
+ ("mlp.gate.", "mlp.router.gate."),
215
+ ("mlp.e_score_correction_bias", "mlp.expert_bias"),
216
+ ("mlp.shared_experts.", "mlp.shared_mlp."),
217
+ ]
218
+
219
+ # Regex to match fused 3D expert keys
220
+ _FUSED_EXPERT_KEY_RE = re.compile(
221
+ r"^(.*\.mlp\.experts\.)(gate_up_proj|down_proj)$"
222
+ )
223
+
224
+ _save_patch_applied = False
225
+
226
+ def _apply_save_reverse_rename_patch():
227
+ """Monkey-patch save_pretrained to reverse-rename keys and unfuse 3D experts."""
228
+ global _save_patch_applied
229
+ if _save_patch_applied:
230
+ return
231
+
232
+ try:
233
+ from transformers.models.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
234
+ except ImportError:
235
+ try:
236
+ from transformers.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
237
+ except ImportError:
238
+ logger.warning(
239
+ "Could not import HYV3ForCausalLM; "
240
+ "save reverse rename patch NOT applied."
241
+ )
242
+ return
243
+
244
+ _orig_save_pretrained = HYV3ForCausalLM.save_pretrained
245
+
246
+ def _patched_save_pretrained(self, *args, **kwargs):
247
+ state_dict = kwargs.get("state_dict", None)
248
+ if state_dict is not None:
249
+ reversed_sd = {}
250
+
251
+ for k, v in state_dict.items():
252
+ new_k = k
253
+ # Apply simple key renames
254
+ for new_sub, old_sub in _SAVE_KEY_RENAMES:
255
+ if new_sub in new_k:
256
+ new_k = new_k.replace(new_sub, old_sub)
257
+ break
258
+
259
+ # Check if this is a fused 3D expert key
260
+ m = _FUSED_EXPERT_KEY_RE.match(new_k)
261
+ if m:
262
+ prefix = m.group(1) # e.g. "model.layers.1.mlp.experts."
263
+ proj_type = m.group(2) # "gate_up_proj" or "down_proj"
264
+
265
+ if proj_type == "gate_up_proj":
266
+ # v shape: [num_experts, 2*intermediate, hidden]
267
+ num_experts = v.shape[0]
268
+ intermediate = v.shape[1] // 2
269
+ for i in range(num_experts):
270
+ gate = v[i, :intermediate, :]
271
+ up = v[i, intermediate:, :]
272
+ reversed_sd[f"{prefix}{i}.gate_proj.weight"] = gate
273
+ reversed_sd[f"{prefix}{i}.up_proj.weight"] = up
274
+ elif proj_type == "down_proj":
275
+ # v shape: [num_experts, hidden, intermediate]
276
+ num_experts = v.shape[0]
277
+ for i in range(num_experts):
278
+ reversed_sd[f"{prefix}{i}.down_proj.weight"] = v[i]
279
+ else:
280
+ reversed_sd[new_k] = v
281
+
282
+ kwargs["state_dict"] = reversed_sd
283
+ logger.info(
284
+ "HYV3 Patch 6: Reverse-renamed and unfused %d -> %d "
285
+ "state_dict keys for old checkpoint compatibility.",
286
+ len(state_dict), len(reversed_sd)
287
+ )
288
+ return _orig_save_pretrained(self, *args, **kwargs)
289
+
290
+ HYV3ForCausalLM.save_pretrained = _patched_save_pretrained
291
+
292
+ _save_patch_applied = True
293
+ logger.info(
294
+ "HYV3 patch applied: save-time reverse key rename + "
295
+ "3D -> per-expert unfuse for old ckpt compatibility."
296
+ )
297
+
298
+ # ============================================================================
299
+ # LLaMA Factory Callback: integrates patches 3, 4 into the training loop
300
+ # ============================================================================
301
+
302
+ try:
303
+ from transformers import TrainerCallback
304
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
305
+
306
+ class HYV3PatchCallback(TrainerCallback):
307
+ """
308
+ LLaMA Factory compatible callback that applies HYV3-specific patches.
309
+
310
+ Add to your YAML or pass to Trainer:
311
+ callbacks: [hy_v3_patches.HYV3PatchCallback]
312
+ """
313
+
314
+ def __init__(self, tokenizer_dir: Optional[str] = None):
315
+ """
316
+ Args:
317
+ tokenizer_dir: Path to the original tokenizer directory.
318
+ If None, will try to use model_name_or_path from training args.
319
+ """
320
+ self._tokenizer_dir = tokenizer_dir
321
+
322
+ def on_train_begin(self, args, state, control, **kwargs):
323
+ # --- Patch 3: gradient_checkpointing use_reentrant ---
324
+ if getattr(args, "gradient_checkpointing", False) and getattr(args, "deepspeed", None):
325
+ if not hasattr(args, "gradient_checkpointing_kwargs") or not args.gradient_checkpointing_kwargs:
326
+ args.gradient_checkpointing_kwargs = {"use_reentrant": True}
327
+ elif "use_reentrant" not in args.gradient_checkpointing_kwargs:
328
+ args.gradient_checkpointing_kwargs["use_reentrant"] = True
329
+ logger.info("HYV3 patch applied: gradient_checkpointing use_reentrant=True.")
330
+
331
+ return control
332
+
333
+ def on_save(self, args, state, control, **kwargs):
334
+ # --- Patch 4: Copy tokenizer files ---
335
+ if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
336
+ return control
337
+
338
+ checkpoint_dir = os.path.join(
339
+ args.output_dir,
340
+ f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
341
+ )
342
+
343
+ # Determine tokenizer directory
344
+ tokenizer_dir = self._tokenizer_dir
345
+ if tokenizer_dir is None:
346
+ # Try common locations
347
+ tokenizer_dir = getattr(args, "tokenizer_name_or_path", None)
348
+ if tokenizer_dir is None:
349
+ tokenizer_dir = getattr(args, "model_name_or_path", None)
350
+
351
+ if tokenizer_dir and os.path.isdir(tokenizer_dir):
352
+ _copy_tokenizer_to_checkpoint(tokenizer_dir, checkpoint_dir)
353
+ logger.info(
354
+ "HYV3: Copied tokenizer files from %s to %s",
355
+ tokenizer_dir, checkpoint_dir
356
+ )
357
+
358
+ return control
359
+
360
+ except ImportError:
361
+ logger.warning(
362
+ "transformers not available; HYV3PatchCallback not defined."
363
+ )
364
+
365
+ # ============================================================================
366
+ # Auto-apply patches on import
367
+ # ============================================================================
368
+
369
+ # Patch 1: ZeRO-3 buffer loading (key rename + fuse now done by preprocessing)
370
+ _apply_buffer_loading_patch()
371
+
372
+ # Patch 2: Router dtype fix
373
+ _apply_router_dtype_patch()
374
+
375
+ # Patch 6: Save-time reverse key rename + 3D -> per-expert unfuse
376
+ _apply_save_reverse_rename_patch()
377
+
378
+ # Patches 3, 4 are applied via HYV3PatchCallback during training.
379
+ # Users should add HYV3PatchCallback to their Trainer callbacks.
380
+
381
+ logger.info(
382
+ "HYV3 patches module loaded. Remember to add HYV3PatchCallback to "
383
+ "your Trainer callbacks for full compatibility."
384
+ )
train/llama_factory_support/hy_v3_template.py ADDED
@@ -0,0 +1,46 @@
1
+ """
2
+ HYV3 chat template registration for LLaMA Factory.
3
+
4
+ Usage:
5
+ 1. Copy this file's register_template block into LLaMA Factory's
6
+ src/llamafactory/data/template.py (for upstream MR).
7
+ 2. Or import this module before training to register at runtime:
8
+ import hy_v3_template
9
+ """
10
+
11
+ from llamafactory.data.template import ReasoningTemplate, register_template
12
+ from llamafactory.data.formatter import EmptyFormatter, StringFormatter
13
+
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # HYV3 (MoE, pure text) chat template
17
+ #
18
+ # Token format (from chat_template.jinja & tokenizer_config.json):
19
+ # BOS: <|hy_begin▁of▁sentence|>
20
+ # System: {system_content} (directly after BOS, no role tag)
21
+ # User: <|hy_User|>{user_content}
22
+ # Assistant: <|hy_Assistant|>{assistant_content}<|hy_eos|>
23
+ # EOS: <|hy_eos|>
24
+ #
25
+ # Loss mask: only compute loss on assistant content (including <|hy_eos|>).
26
+ #
27
+ # Note: The system message has NO explicit role token -- it is placed right
28
+ # after BOS. The eos_token is <|hy_eos|>.
29
+ #
30
+ # Reasoning: Supports think tags via ReasoningTemplate.
31
+ # - thought_words: ("<think>", "</think>") matching jinja template
32
+ # - enable_thinking: set globally via data_args.enable_thinking (default True)
33
+ # - Training data always includes think tags (empty or with content)
34
+ # ---------------------------------------------------------------------------
35
+
36
+ register_template(
37
+ name="hy_v3",
38
+ template_class=ReasoningTemplate,
39
+ format_user=StringFormatter(slots=["<|hy_User|>{{content}}"]),
40
+ format_assistant=StringFormatter(slots=["<|hy_Assistant|>{{content}}", {"eos_token"}]),
41
+ format_system=StringFormatter(slots=["{{content}}"]),
42
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
43
+ thought_words=("<think>", "</think>"),
44
+ stop_words=["<|hy_eos|>"],
45
+ efficient_eos=True,
46
+ )
train/llama_factory_support/train_hy_v3.py ADDED
@@ -0,0 +1,72 @@
1
+ """
2
+ LLaMA Factory training entry-point wrapper for HYV3.
3
+
4
+ This script:
5
+ 1. Registers the hy_v3 chat template
6
+ 2. Applies all HYV3 monkey-patches (checkpoint key rename, dtype fix, etc.)
7
+ 3. Injects HYV3PatchCallback into the training loop
8
+ 4. Calls run_exp() to start LLaMA Factory training
9
+
10
+ How it works:
11
+ - train_lf.sh launches this script via torchrun directly:
12
+ torchrun ... train_hy_v3.py hy_v3_full_sft.yaml
13
+ - Each torchrun worker executes this script, so all patches are applied
14
+ in every worker process before training begins.
15
+ - We call run_exp() directly (not the CLI launcher) to avoid the
16
+ launcher re-spawning workers and losing our patches.
17
+
18
+ Usage:
19
+ # Via launch script (recommended):
20
+ bash train_lf.sh
21
+
22
+ # Direct single-node (8 GPUs):
23
+ torchrun --nproc_per_node 8 train_hy_v3.py hy_v3_full_sft.yaml
24
+ """
25
+
26
+ import sys
27
+ import os
28
+
29
+ # Add current directory to path so patches can be imported
30
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
31
+
32
+ # Step 1: Register HYV3 template (must be before training starts)
33
+ import hy_v3_template # noqa: F401
34
+
35
+ # Step 2: Apply checkpoint key rename patch (must be before model loading)
36
+ import hy_v3_patches # noqa: F401
37
+
38
+ # Step 3: Inject HYV3PatchCallback into LLaMA Factory's training flow
39
+ from llamafactory.train.sft.workflow import run_sft as _orig_run_sft
40
+
41
+
42
+ def _patched_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=None):
43
+ """Wrap run_sft to inject HYV3PatchCallback."""
44
+ if callbacks is None:
45
+ callbacks = []
46
+
47
+ # Determine tokenizer directory for the save callback
48
+ tokenizer_dir = getattr(model_args, "model_name_or_path", None)
49
+ callbacks.append(hy_v3_patches.HYV3PatchCallback(tokenizer_dir=tokenizer_dir))
50
+
51
+ return _orig_run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=callbacks)
52
+
53
+
54
+ # Monkey-patch the SFT workflow
55
+ import llamafactory.train.sft.workflow as _sft_wf
56
+ _sft_wf.run_sft = _patched_run_sft
57
+
58
+
59
+ def main():
60
+ """Entry point: called by torchrun in each worker process.
61
+
62
+ Since train_lf.sh launches us via torchrun directly, all patches
63
+ (template registration, checkpoint key rename, SFT callback injection)
64
+ are already applied in this process. We just call run_exp() to start
65
+ training — no need to go through the CLI launcher.
66
+ """
67
+ from llamafactory.train.tuner import run_exp
68
+ run_exp()
69
+
70
+
71
+ if __name__ == "__main__":
72
+ main()
train/llama_factory_support/train_lf.sh ADDED
@@ -0,0 +1,114 @@
1
+ #!/bin/bash
2
+ # ============================================================================
3
+ # LLaMA Factory training launch script for HYV3
4
+ #
5
+ # This script sets up the environment and launches training via torchrun.
6
+ #
7
+ # We use train_hy_v3.py as the entry point (not llamafactory-cli)
8
+ # because we need to inject HYV3-specific monkey-patches and register
9
+ # the hy_v3 chat template BEFORE LLaMA Factory starts.
10
+ # train_hy_v3.py directly calls run_exp() in each torchrun worker,
11
+ # ensuring all patches are active.
12
+ #
13
+ # Usage:
14
+ # Single node: bash train_lf.sh
15
+ # Multi-node: Run this script on EACH node with the same IP_LIST.
16
+ # IP_LIST="10.0.0.1,10.0.0.2" bash train_lf.sh
17
+ # ============================================================================
18
+
19
+ set -euo pipefail
20
+
21
+ # -------------------- Network Configuration --------------------
22
+ NET_TYPE="high"
23
+ export NCCL_DEBUG=WARN
24
+ export NCCL_P2P_LEVEL=NVL
25
+ export NCCL_IB_TIMEOUT=24
26
+ export NCCL_NVLS_ENABLE=0
27
+ export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
28
+ export CUDA_DEVICE_MAX_CONNECTIONS=1
29
+ export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
30
+ if [[ "${NET_TYPE}" = "low" ]]; then
31
+ export NCCL_SOCKET_IFNAME=eth1
32
+ export NCCL_IB_GID_INDEX=3
33
+ export NCCL_IB_HCA=mlx5_2:1
34
+ export NCCL_IB_SL=3
35
+ export NCCL_CHECK_DISABLE=1
36
+ export NCCL_P2P_DISABLE=0
37
+ export NCCL_LL_THRESHOLD=16384
38
+ export NCCL_IB_CUDA_SUPPORT=1
39
+ else
40
+ export NCCL_IB_GID_INDEX=3
41
+ export NCCL_IB_SL=3
42
+ export NCCL_CHECK_DISABLE=1
43
+ export NCCL_P2P_DISABLE=0
44
+ export NCCL_IB_DISABLE=0
45
+ export NCCL_LL_THRESHOLD=16384
46
+ export NCCL_IB_CUDA_SUPPORT=1
47
+ export NCCL_SOCKET_IFNAME=bond1
48
+ export UCX_NET_DEVICES=bond1
49
+ export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
50
+ export NCCL_COLLNET_ENABLE=0
51
+ export SHARP_COLL_ENABLE_SAT=0
52
+ export NCCL_NET_GDR_LEVEL=2
53
+ export NCCL_IB_QPS_PER_CONNECTION=4
54
+ export NCCL_IB_TC=160
55
+ export NCCL_PXN_DISABLE=1
56
+ fi
57
+
58
+ # Skip LLaMA Factory version check (we use a newer transformers branch)
59
+ export DISABLE_VERSION_CHECK=1
60
+
61
+ # -------------------- Node Configuration --------------------
62
+ export HOST_GPU_NUM=8
63
+ # IP list, comma separated. e.g. "10.0.0.1,10.0.0.2" or single node "127.0.0.1"
64
+ export IP_LIST=${IP_LIST:-"127.0.0.1"}
65
+
66
+ MASTER_PORT=${MASTER_PORT:-29500}
67
+
68
+ IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
69
+ NODES=${#IP_ARRAY[@]}
70
+ MASTER_ADDR=${IP_ARRAY[0]}
71
+
72
+ # -------------------- Paths --------------------
73
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
74
+ YAML_FILE="${SCRIPT_DIR}/hy_v3_full_sft.yaml"
75
+ ENTRY_SCRIPT="${SCRIPT_DIR}/train_hy_v3.py"
76
+
77
+ # -------------------- Distributed Environment --------------------
78
+ export MASTER_ADDR="${MASTER_ADDR}"
79
+ export MASTER_PORT="${MASTER_PORT}"
80
+ export NNODES="${NODES}"
81
+
82
+ if [ ${NODES} -gt 1 ]; then
83
+ # Determine local node rank by matching local IP against IP_LIST
84
+ LOCAL_IP=$(hostname -i | awk '{print $1}')
85
+ NODE_RANK=0
86
+ for i in "${!IP_ARRAY[@]}"; do
87
+ if [[ "${IP_ARRAY[$i]}" == "${LOCAL_IP}" ]]; then
88
+ NODE_RANK=$i
89
+ break
90
+ fi
91
+ done
92
+ export RANK="${NODE_RANK}"
93
+ else
94
+ export RANK=0
95
+ fi
96
+
97
+ echo "============================================"
98
+ echo " HYV3 LLaMA Factory Training"
99
+ echo " Nodes: ${NNODES}, Rank: ${RANK}"
100
+ echo " Master: ${MASTER_ADDR}:${MASTER_PORT}"
101
+ echo " GPUs per node: ${HOST_GPU_NUM}"
102
+ echo " Total GPUs: $((NODES * HOST_GPU_NUM))"
103
+ echo "============================================"
104
+
105
+ # -------------------- Launch --------------------
106
+ # We launch torchrun directly (instead of FORCE_TORCHRUN) so that each
107
+ # worker process runs train_hy_v3.py with all HYV3 patches applied.
108
+ torchrun \
109
+ --nnodes "${NNODES}" \
110
+ --node_rank "${RANK}" \
111
+ --nproc_per_node "${HOST_GPU_NUM}" \
112
+ --master_addr "${MASTER_ADDR}" \
113
+ --master_port "${MASTER_PORT}" \
114
+ "${ENTRY_SCRIPT}" "${YAML_FILE}"
train/merge_lora_weight.py ADDED
@@ -0,0 +1,54 @@
1
+ # Import the required libraries
2
+ from transformers import AutoModelForCausalLM  # loads the pre-trained language model
3
+ from peft import PeftModel  # loads and merges the LoRA adapter weights
4
+ import argparse  # parses command-line arguments
5
+ import shutil  # file operations such as copying
6
+ import os  # file path operations
7
+ import torch  # dtype handling for the merged weights
8
+
9
+ def main():
10
+ # Create the argument parser
11
+ parser = argparse.ArgumentParser()
12
+ # Add command-line arguments
13
+ parser.add_argument("--base_model_path", type=str, required=True,
14
+ help="Path to pretrained model or model identifier from huggingface.co/models")
15
+ parser.add_argument("--adapter_model_path", type=str, required=True, help="Path to adapter model")
16
+ parser.add_argument("--output_path", type=str, required=True, help="Path to save the output model")
17
+ parser.add_argument("--save_dtype", type=str, choices=['bf16', 'fp32', 'fp16'],
18
+ default='fp32', help="In which dtype to save, fp32, bf16 or fp16.")
19
+ # Parse the command-line arguments
20
+ args = parser.parse_args()
21
+
22
+ name2dtype = {'bf16': torch.bfloat16, 'fp32': torch.float32, 'fp16': torch.float16}
23
+ # Load the base model
24
+ model = AutoModelForCausalLM.from_pretrained(
25
+ args.base_model_path, device_map='cpu',
26
+ trust_remote_code=True, torch_dtype=name2dtype[args.save_dtype]
27
+ )
28
+ # Load the adapter weights on top of the base model
29
+ model = PeftModel.from_pretrained(model, args.adapter_model_path, trust_remote_code=True)
30
+ # Merge the adapter into the base model
31
+ model = model.merge_and_unload()
32
+ # Save the merged model weights
33
+ model.save_pretrained(args.output_path, safe_serialization=False)
34
+
35
+ # Copy tokenizer, config and other non-weight files from base model
36
+ # Skip model weight files (.safetensors, .bin, .pt) and index files
37
+ _SKIP_SUFFIXES = ('.safetensors', '.bin', '.pt', '.pth')
38
+ _SKIP_NAMES = {'model.safetensors.index.json', 'pytorch_model.bin.index.json'}
39
+
40
+ for fname in os.listdir(args.base_model_path):
41
+ src = os.path.join(args.base_model_path, fname)
42
+ if not os.path.isfile(src):
43
+ continue
44
+ if fname in _SKIP_NAMES or fname.endswith(_SKIP_SUFFIXES):
45
+ continue
46
+ dst = os.path.join(args.output_path, fname)
47
+ if not os.path.exists(dst):
48
+ shutil.copy(src, dst)
49
+ print(f'Copied {fname}')
50
+
51
+ print(f'Merged model weight is saved to {args.output_path}')
52
+
53
+ if __name__ == "__main__":
54
+ main()
train/merge_lora_weight.sh ADDED
@@ -0,0 +1 @@
1
+ python3 merge_lora_weight.py --base_model_path /xxx/hy_l_train/checkpoint-200 --adapter_model_path /xxx/runs/hy_l_lora_train/checkpoint-200 --output_path /xxx/ckpts/merged_hy_lora_weight --save_dtype bf16
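After merging, a quick sanity check is to reload the merged checkpoint and run a short generation. A minimal sketch, assuming the merged model saved at --output_path loads through the Auto classes with trust_remote_code (the path below is the placeholder from the command above):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

merged_path = "/xxx/ckpts/merged_hy_lora_weight"  # placeholder path from the command above
tokenizer = AutoTokenizer.from_pretrained(merged_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    merged_path, dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))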
train/requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ transformers>=5.6.0
2
+ torch>=2.10.0
3
+ torchvision
4
+ torchaudio
5
+ accelerate>=1.11.0
6
+ peft>=0.18.1
7
+ deepspeed>=0.18.7
8
+ sentencepiece
9
+ protobuf
10
+ ninja
11
+ flash-attn
12
+ tensorboard
train/tools/check_converted.py ADDED
@@ -0,0 +1,455 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Quick validation script for converted HYV3 outer-format checkpoint.
4
+
5
+ Checks:
6
+ 1. model.safetensors.index.json structure and completeness
7
+ 2. All expected weight keys exist (dense layer 0, MoE layers 1-79)
8
+ 3. Expert tensor shapes (fused 3D format)
9
+ 4. All referenced shard files exist and are non-empty
10
+ 5. Spot-check: load a few shards and verify tensor shapes/dtypes
11
+ 6. No duplicate or orphan keys
12
+
13
+ Usage:
14
+ python check_converted.py <output_dir> [--spot-check N]
15
+
16
+ Example:
17
+ python check_converted.py pretrain_base/hf_outer
18
+ python check_converted.py pretrain_base/hf_outer --spot-check 5
19
+ """
20
+
21
+ import argparse
22
+ import json
23
+ import os
24
+ import sys
25
+ import time
26
+ from collections import defaultdict
27
+
28
+ # ============================================================================
29
+ # Expected key patterns for HYV3 outer format
30
+ # ============================================================================
31
+
32
+ # Dense layer (layer 0) expected suffixes
33
+ DENSE_SUFFIXES = [
34
+ "input_layernorm.weight",
35
+ "post_attention_layernorm.weight",
36
+ "self_attn.q_proj.weight",
37
+ "self_attn.k_proj.weight",
38
+ "self_attn.v_proj.weight",
39
+ "self_attn.o_proj.weight",
40
+ "self_attn.q_norm.weight",
41
+ "self_attn.k_norm.weight",
42
+ "mlp.gate_proj.weight",
43
+ "mlp.up_proj.weight",
44
+ "mlp.down_proj.weight",
45
+ ]
46
+
47
+ # MoE layer (layers 1-79) expected suffixes
48
+ MOE_SUFFIXES = [
49
+ "input_layernorm.weight",
50
+ "post_attention_layernorm.weight",
51
+ "self_attn.q_proj.weight",
52
+ "self_attn.k_proj.weight",
53
+ "self_attn.v_proj.weight",
54
+ "self_attn.o_proj.weight",
55
+ "self_attn.q_norm.weight",
56
+ "self_attn.k_norm.weight",
57
+ # MoE-specific
58
+ "mlp.gate.weight",
59
+ "mlp.e_score_correction_bias",
60
+ "mlp.experts.gate_up_proj",
61
+ "mlp.experts.down_proj",
62
+ "mlp.shared_experts.gate_proj.weight",
63
+ "mlp.shared_experts.up_proj.weight",
64
+ "mlp.shared_experts.down_proj.weight",
65
+ ]
66
+
67
+ # MTP (Multi-Token Prediction) layer expected suffixes
68
+ # MTP layers share MoE structure but have additional projection/norm keys
69
+ MTP_EXTRA_SUFFIXES = [
70
+ "eh_proj.weight",
71
+ "enorm.weight",
72
+ "final_layernorm.weight",
73
+ "hnorm.weight",
74
+ ]
75
+
76
+ # Global keys (not per-layer)
77
+ GLOBAL_KEYS = [
78
+ "model.embed_tokens.weight",
79
+ "model.norm.weight",
80
+ "lm_head.weight",
81
+ ]
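+ # Illustration (hypothetical layer indices): a dense-layer check expands to
+ # keys like "model.layers.0.mlp.gate_proj.weight", while a MoE layer is
+ # expected to carry fused expert tensors such as
+ # "model.layers.5.mlp.experts.gate_up_proj".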
82
+
83
+
84
+ def load_config(output_dir):
85
+ """Load config.json and extract model parameters."""
86
+ config_path = os.path.join(output_dir, "config.json")
87
+ if not os.path.exists(config_path):
88
+ print(f"[ERROR] config.json not found in {output_dir}")
89
+ return None
90
+ with open(config_path) as f:
91
+ return json.load(f)
92
+
93
+
94
+ def check_index_json(output_dir):
95
+ """Check model.safetensors.index.json for structure and completeness."""
96
+ index_path = os.path.join(output_dir, "model.safetensors.index.json")
97
+ if not os.path.exists(index_path):
98
+ print(f"[ERROR] model.safetensors.index.json not found")
99
+ return None, []
100
+
101
+ with open(index_path) as f:
102
+ index = json.load(f)
103
+
104
+ errors = []
105
+
106
+ # Check structure
107
+ if "metadata" not in index:
108
+ errors.append("Missing 'metadata' in index.json")
109
+ elif "total_size" not in index["metadata"]:
110
+ errors.append("Missing 'total_size' in metadata")
111
+
112
+ if "weight_map" not in index:
113
+ errors.append("Missing 'weight_map' in index.json")
114
+ return index, errors
115
+
116
+ weight_map = index["weight_map"]
117
+ total_size = index.get("metadata", {}).get("total_size", 0)
118
+
119
+ print(f" Index keys : {len(weight_map)}")
120
+ print(f" Total size : {total_size / 1e9:.2f} GB")
121
+
122
+ # Check for empty weight_map
123
+ if len(weight_map) == 0:
124
+ errors.append("weight_map is empty")
125
+
126
+ return index, errors
127
+
128
+
129
+ def check_expected_keys(weight_map, config):
130
+ """Check that all expected keys exist in the weight_map."""
131
+ errors = []
132
+ warnings = []
133
+
134
+ num_layers = config.get("num_hidden_layers", 80)
135
+ first_k_dense = config.get("first_k_dense_replace", 1)
136
+ num_experts = config.get("num_experts", 192)
137
+ num_mtp_layers = config.get("num_nextn_predict_layers", 0)
138
+
139
+ # Check global keys
140
+ for key in GLOBAL_KEYS:
141
+ if key not in weight_map:
142
+ errors.append(f"Missing global key: {key}")
143
+
144
+ # Check per-layer keys (regular layers)
145
+ missing_by_type = defaultdict(list)
146
+ for layer_idx in range(num_layers):
147
+ prefix = f"model.layers.{layer_idx}."
148
+ if layer_idx < first_k_dense:
149
+ # Dense layer
150
+ suffixes = DENSE_SUFFIXES
151
+ else:
152
+ # MoE layer
153
+ suffixes = MOE_SUFFIXES
154
+
155
+ for suffix in suffixes:
156
+ full_key = prefix + suffix
157
+ if full_key not in weight_map:
158
+ missing_by_type[suffix].append(layer_idx)
159
+
160
+ # Check MTP layers (layer num_layers .. num_layers + num_mtp_layers - 1)
161
+ mtp_missing_by_type = defaultdict(list)
162
+ for mtp_idx in range(num_mtp_layers):
163
+ layer_idx = num_layers + mtp_idx
164
+ prefix = f"model.layers.{layer_idx}."
165
+ # MTP layers use MoE structure + extra projection/norm keys
166
+ mtp_suffixes = MOE_SUFFIXES + MTP_EXTRA_SUFFIXES
167
+ for suffix in mtp_suffixes:
168
+ full_key = prefix + suffix
169
+ if full_key not in weight_map:
170
+ mtp_missing_by_type[suffix].append(layer_idx)
171
+
172
+ for suffix, layers in sorted(mtp_missing_by_type.items()):
173
+ layer_str = str(layers)
174
+ errors.append(f"Missing MTP key '{suffix}' in layers: {layer_str}")
175
+
176
+ for suffix, layers in sorted(missing_by_type.items()):
177
+ if len(layers) <= 5:
178
+ layer_str = str(layers)
179
+ else:
180
+ layer_str = f"{layers[:3]}...({len(layers)} total)"
181
+ errors.append(f"Missing '{suffix}' in layers: {layer_str}")
182
+
183
+ # Check for unexpected keys (not matching any known pattern)
184
+ known_prefixes = set()
185
+ # Regular layers + MTP layers
186
+ for layer_idx in range(num_layers + num_mtp_layers):
187
+ known_prefixes.add(f"model.layers.{layer_idx}.")
188
+ known_prefixes.add("model.embed_tokens.")
189
+ known_prefixes.add("model.norm.")
190
+ known_prefixes.add("lm_head.")
191
+ # Alternative MTP prefix (some models use this)
192
+ known_prefixes.add("model.mtp_layers.")
193
+
194
+ unexpected = []
195
+ for key in weight_map:
196
+ if not any(key.startswith(p) for p in known_prefixes):
197
+ unexpected.append(key)
198
+
199
+ if unexpected:
200
+ if len(unexpected) <= 5:
201
+ for k in unexpected:
202
+ warnings.append(f"Unexpected key: {k}")
203
+ else:
204
+ warnings.append(f"{len(unexpected)} unexpected keys found (first 3: {unexpected[:3]})")
205
+
206
+ return errors, warnings
207
+
208
+
209
+ def check_shard_files(output_dir, weight_map):
210
+ """Check that all referenced shard files exist and are non-empty."""
211
+ errors = []
212
+ warnings = []
213
+
214
+ # Get unique shard files
215
+ shard_files = sorted(set(weight_map.values()))
216
+ print(f" Shard files : {len(shard_files)}")
217
+
218
+ missing = []
219
+ empty = []
220
+ total_disk_size = 0
221
+
222
+ for sf in shard_files:
223
+ path = os.path.join(output_dir, sf)
224
+ if not os.path.exists(path):
225
+ missing.append(sf)
226
+ else:
227
+ size = os.path.getsize(path)
228
+ if size == 0:
229
+ empty.append(sf)
230
+ total_disk_size += size
231
+
232
+ print(f" Disk size : {total_disk_size / 1e9:.2f} GB")
233
+
234
+ if missing:
235
+ errors.append(f"Missing shard files ({len(missing)}): {missing[:5]}")
236
+ if empty:
237
+ errors.append(f"Empty shard files ({len(empty)}): {empty[:5]}")
238
+
239
+ # Check for orphan shard files (exist on disk but not in index)
240
+ all_safetensors = set(
241
+ f for f in os.listdir(output_dir)
242
+ if f.endswith(".safetensors")
243
+ )
244
+ referenced = set(shard_files)
245
+ orphans = all_safetensors - referenced
246
+ if orphans:
247
+ # Distinguish between empty residue files (cross-shard merge artifacts)
248
+ # and real orphan files with actual data
249
+ EMPTY_SHARD_THRESHOLD = 128 # bytes; empty safetensors header is ~16 bytes
250
+ residue_orphans = []
251
+ real_orphans = []
252
+ for o in sorted(orphans):
253
+ sz = os.path.getsize(os.path.join(output_dir, o))
254
+ if sz <= EMPTY_SHARD_THRESHOLD:
255
+ residue_orphans.append(o)
256
+ else:
257
+ real_orphans.append(o)
258
+
259
+ if residue_orphans:
260
+ warnings.append(
261
+ f"{len(residue_orphans)} empty residue shard(s) from cross-shard merge "
262
+ f"(<=128 bytes each, safe to delete)"
263
+ )
264
+ if real_orphans:
265
+ errors.append(
266
+ f"Orphan shard files with data (not in index): {real_orphans[:5]}"
267
+ )
268
+
269
+ return errors, warnings
270
+
271
+
272
+ def check_key_distribution(weight_map):
273
+ """Check the distribution of keys across shards."""
274
+ shard_key_count = defaultdict(int)
275
+ for key, shard in weight_map.items():
276
+ shard_key_count[shard] += 1
277
+
278
+ counts = sorted(shard_key_count.values())
279
+ print(f" Keys/shard : min={counts[0]}, max={counts[-1]}, "
280
+ f"median={counts[len(counts)//2]}")
281
+
282
+ # Check for shards with 0 keys (should not happen if they are in weight_map)
283
+ zero_shards = [s for s, c in shard_key_count.items() if c == 0]
284
+ if zero_shards:
285
+ return [f"Shards with 0 keys: {zero_shards}"]
286
+ return []
287
+
288
+
289
+ def spot_check_shards(output_dir, weight_map, config, num_checks=3):
290
+ """Spot-check a few shards by loading and verifying tensor shapes."""
291
+ errors = []
292
+
293
+ try:
294
+ from safetensors import safe_open
295
+ except ImportError:
296
+ print(" [SKIP] safetensors not installed, skipping spot-check")
297
+ return errors
298
+
299
+ num_experts = config.get("num_experts", 192)
300
+ expert_hidden = config.get("expert_hidden_dim", config.get("moe_intermediate_size", 1536))
301
+ hidden_size = config.get("hidden_size", 4096)
302
+
303
+ # Find shards that contain expert tensors (most interesting to check)
304
+ expert_shards = set()
305
+ for key, shard in weight_map.items():
306
+ if "experts.gate_up_proj" in key or "experts.down_proj" in key:
307
+ expert_shards.add(shard)
308
+
309
+ # Pick a few shards to check
310
+ check_shards = sorted(expert_shards)[:num_checks]
311
+ if not check_shards:
312
+ check_shards = sorted(set(weight_map.values()))[:num_checks]
313
+
314
+ print(f"\n Spot-checking {len(check_shards)} shard(s)...")
315
+
316
+ for shard_file in check_shards:
317
+ shard_path = os.path.join(output_dir, shard_file)
318
+ t0 = time.time()
319
+
320
+ try:
321
+ with safe_open(shard_path, framework="pt", device="cpu") as f:
322
+ keys_in_shard = list(f.keys())
323
+ for key in keys_in_shard:
324
+ tensor = f.get_tensor(key)
325
+
326
+ # Check expert shapes
327
+ if key.endswith("experts.gate_up_proj"):
328
+ expected_shape = (num_experts, expert_hidden * 2, hidden_size)
329
+ if tuple(tensor.shape) != expected_shape:
330
+ errors.append(
331
+ f"{shard_file}/{key}: shape {tuple(tensor.shape)} "
332
+ f"!= expected {expected_shape}"
333
+ )
334
+
335
+ elif key.endswith("experts.down_proj"):
336
+ expected_shape = (num_experts, hidden_size, expert_hidden)
337
+ if tuple(tensor.shape) != expected_shape:
338
+ errors.append(
339
+ f"{shard_file}/{key}: shape {tuple(tensor.shape)} "
340
+ f"!= expected {expected_shape}"
341
+ )
342
+
343
+ # Check for NaN/Inf
344
+ if tensor.is_floating_point():
345
+ if tensor.isnan().any():
346
+ errors.append(f"{shard_file}/{key}: contains NaN values")
347
+ if tensor.isinf().any():
348
+ errors.append(f"{shard_file}/{key}: contains Inf values")
349
+
350
+ elapsed = time.time() - t0
351
+ print(f" {shard_file}: {len(keys_in_shard)} keys, OK ({elapsed:.1f}s)")
352
+
353
+ except Exception as e:
354
+ errors.append(f"Failed to load {shard_file}: {e}")
355
+
356
+ return errors
357
+
358
+
359
+ def main():
360
+ parser = argparse.ArgumentParser(
361
+ description="Validate converted HYV3 outer-format checkpoint."
362
+ )
363
+ parser.add_argument(
364
+ "output_dir", type=str,
365
+ help="Path to the converted outer-format checkpoint directory.",
366
+ )
367
+ parser.add_argument(
368
+ "--spot-check", type=int, default=3, dest="spot_check",
369
+ help="Number of shards to spot-check by loading tensors (default: 3).",
370
+ )
371
+ args = parser.parse_args()
372
+
373
+ output_dir = os.path.abspath(args.output_dir)
374
+ print(f"Validating: {output_dir}\n")
375
+
376
+ if not os.path.isdir(output_dir):
377
+ print(f"[ERROR] Directory not found: {output_dir}")
378
+ sys.exit(1)
379
+
380
+ all_errors = []
381
+ all_warnings = []
382
+
383
+ # 1. Load config
384
+ print("[1/5] Loading config.json...")
385
+ config = load_config(output_dir)
386
+ if config is None:
387
+ print("[ERROR] Cannot proceed without config.json")
388
+ sys.exit(1)
389
+
390
+ num_layers = config.get("num_hidden_layers", 0)
391
+ num_experts = config.get("num_experts", 0)
392
+ first_k_dense = config.get("first_k_dense_replace", 0)
393
+ num_mtp = config.get("num_nextn_predict_layers", 0)
394
+ print(f" Layers : {num_layers} ({first_k_dense} dense, {num_layers - first_k_dense} MoE)")
395
+ print(f" MTP layers : {num_mtp}")
396
+ print(f" Experts/layer : {num_experts}")
397
+ print(f" Hidden size : {config.get('hidden_size', '?')}")
398
+ print(f" Expert hidden : {config.get('expert_hidden_dim', config.get('moe_intermediate_size', '?'))}")
399
+
400
+ # 2. Check index.json
401
+ print("\n[2/5] Checking model.safetensors.index.json...")
402
+ index, idx_errors = check_index_json(output_dir)
403
+ all_errors.extend(idx_errors)
404
+
405
+ if index is None or "weight_map" not in index:
406
+ print("[ERROR] Cannot proceed without valid index.json")
407
+ sys.exit(1)
408
+
409
+ weight_map = index["weight_map"]
410
+
411
+ # 3. Check expected keys
412
+ print("\n[3/5] Checking expected keys...")
413
+ key_errors, key_warnings = check_expected_keys(weight_map, config)
414
+ all_errors.extend(key_errors)
415
+ all_warnings.extend(key_warnings)
416
+
417
+ # Also check key distribution
418
+ dist_errors = check_key_distribution(weight_map)
419
+ all_errors.extend(dist_errors)
420
+
421
+ # 4. Check shard files
422
+ print("\n[4/5] Checking shard files on disk...")
423
+ shard_errors, shard_warnings = check_shard_files(output_dir, weight_map)
424
+ all_errors.extend(shard_errors)
425
+ all_warnings.extend(shard_warnings)
426
+
427
+ # 5. Spot-check
428
+ if args.spot_check > 0:
429
+ print(f"\n[5/5] Spot-checking tensors (loading {args.spot_check} shard(s))...")
430
+ spot_errors = spot_check_shards(output_dir, weight_map, config, args.spot_check)
431
+ all_errors.extend(spot_errors)
432
+ else:
433
+ print("\n[5/5] Spot-check skipped (--spot-check 0)")
434
+
435
+ # Summary
436
+ print(f"\n{'=' * 60}")
437
+ if all_warnings:
438
+ print(f"WARNINGS ({len(all_warnings)}):")
439
+ for w in all_warnings:
440
+ print(f" [WARN] {w}")
441
+
442
+ if all_errors:
443
+ print(f"ERRORS ({len(all_errors)}):")
444
+ for e in all_errors:
445
+ print(f" [ERROR] {e}")
446
+ print(f"\nResult: FAILED ({len(all_errors)} error(s), {len(all_warnings)} warning(s))")
447
+ sys.exit(1)
448
+ else:
449
+ print(f"Result: PASSED (0 errors, {len(all_warnings)} warning(s))")
450
+ print(f"{'=' * 60}")
451
+ sys.exit(0)
452
+
453
+
454
+ if __name__ == "__main__":
455
+ main()
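For orientation, the index file this script validates has roughly the following shape (tensor and shard names below are illustrative, not taken from a real checkpoint):

index_example = {
    "metadata": {"total_size": 123456789},  # sum of all tensor byte sizes
    "weight_map": {
        # tensor name -> shard file that stores it
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    },
}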
train/tools/convert_ckpt_to_outer.py ADDED
@@ -0,0 +1,641 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Memory-friendly checkpoint converter: inner -> outer format (v2).
4
+
5
+ Converts the HYV3 checkpoint from inner format (per-expert keys, old naming)
6
+ to outer format (fused 3D experts, new naming) shard by shard.
7
+
8
+ Handles the case where a single layer's experts may be split across
9
+ multiple shards (cross-shard experts) by deferring their fusion to a
10
+ post-processing step.
11
+
12
+ v2 improvements over v1:
13
+ - Post-processing is shard-centric (each shard read/written only once)
14
+ instead of prefix-centric (same shard read/written multiple times).
15
+ This fixes Bus error (core dump) when there are many cross-shard groups.
16
+ - Explicit memory management with gc.collect() to prevent memory bloat.
17
+ - Better progress reporting during post-processing.
18
+
19
+ Supports multi-process parallelism for faster conversion.
20
+
21
+ Usage:
22
+ # Default 8 workers
23
+ python convert_ckpt_to_outer.py \\
24
+ --input_dir pretrain_base/hf \\
25
+ --output_dir pretrain_base/hf_outer
26
+
27
+ # Custom worker count
28
+ python convert_ckpt_to_outer.py \\
29
+ --input_dir pretrain_base/hf \\
30
+ --output_dir pretrain_base/hf_outer \\
31
+ --workers 16
32
+
33
+ The script will:
34
+ 1. Pre-scan index.json to detect cross-shard expert groups
35
+ 2. Convert weights shard-by-shard in parallel (key rename + expert fuse)
36
+ 3. Post-process cross-shard expert groups (merge from multiple shards)
37
+ - v2: shard-centric approach, each shard read/written only once
38
+ 4. Copy config.json as-is (already in outer format)
39
+ 5. Copy all other files (tokenizer, etc.)
40
+ 6. Rebuild model.safetensors.index.json
41
+ """
42
+
43
+ import argparse
44
+ import gc
45
+ import json
46
+ import os
47
+ import re
48
+ import signal
49
+ import shutil
50
+ import sys
51
+ import time
52
+ import traceback
53
+ from collections import OrderedDict, defaultdict
54
+ from multiprocessing import Pool
55
+
56
+ import torch
57
+
58
+ try:
59
+ from safetensors import safe_open
60
+ from safetensors.torch import save_file
61
+ except ImportError:
62
+ raise ImportError("Please install safetensors: pip install safetensors")
63
+
64
+ # ============================================================================
65
+ # Signal handling for Bus error (SIGBUS) and other fatal signals
66
+ # ============================================================================
67
+
68
+ def _fatal_signal_handler(signum, frame):
69
+ """Handle fatal signals (SIGBUS, SIGSEGV) by logging before exit.
70
+
71
+ These signals cannot be caught by try/except. This handler ensures
72
+ the error message is written to stderr (captured by nohup redirection)
73
+ before the process terminates.
74
+ """
75
+ sig_name = signal.Signals(signum).name if hasattr(signal, 'Signals') else str(signum)
76
+ pid = os.getpid()
77
+ msg = (
78
+ f"\n[FATAL] Process {pid} received {sig_name} (signal {signum}).\n"
79
+ f"This typically indicates an out-of-memory condition during mmap I/O.\n"
80
+ f"Stack trace at time of signal:\n"
81
+ )
82
+ sys.stderr.write(msg)
83
+ traceback.print_stack(frame, file=sys.stderr)
84
+ sys.stderr.flush()
85
+ # Re-raise with default handler to get proper exit code
86
+ signal.signal(signum, signal.SIG_DFL)
87
+ os.kill(pid, signum)
88
+
89
+
90
+ def _install_signal_handlers():
91
+ """Install handlers for SIGBUS and SIGSEGV in the current process."""
92
+ for sig in (signal.SIGBUS, signal.SIGSEGV):
93
+ try:
94
+ signal.signal(sig, _fatal_signal_handler)
95
+ except (OSError, ValueError):
96
+ # Some signals may not be available on all platforms
97
+ pass
98
+
99
+
100
+ def _pool_worker_init():
101
+ """Initializer for multiprocessing pool workers.
102
+
103
+ Installs signal handlers so that Bus errors in worker processes
104
+ are also logged before the process dies.
105
+ """
106
+ _install_signal_handlers()
107
+
108
+
109
+ # ============================================================================
110
+ # Key rename mapping (inner -> outer)
111
+ # ============================================================================
112
+
113
+ _KEY_RENAMES = [
114
+ ("mlp.router.gate.", "mlp.gate."),
115
+ ("mlp.expert_bias", "mlp.e_score_correction_bias"),
116
+ ("mlp.shared_mlp.", "mlp.shared_experts."),
117
+ ]
118
+
119
+ # Regex to match per-expert keys
120
+ _EXPERT_KEY_RE = re.compile(
121
+ r"^(.*\.mlp\.experts\.)(\d+)\.(gate_proj|up_proj|down_proj)\.weight$"
122
+ )
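+ # Illustrative match: "model.layers.3.mlp.experts.17.gate_proj.weight" yields
+ # groups ("model.layers.3.mlp.experts.", "17", "gate_proj"); already-fused
+ # keys such as "...mlp.experts.gate_up_proj" deliberately do not match.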
123
+
124
+ def rename_key(key: str) -> str:
125
+ """Rename a single key from inner to outer format."""
126
+ for old_sub, new_sub in _KEY_RENAMES:
127
+ if old_sub in key:
128
+ key = key.replace(old_sub, new_sub)
129
+ break
130
+ return key
131
+
132
+ def scan_cross_shard_experts(index_path: str):
133
+ """Pre-scan index.json to find expert groups that span multiple shards.
134
+
135
+ Returns:
136
+ cross_shard_prefixes: set of expert prefixes that span multiple shards
137
+ e.g. {"model.layers.80.mlp.experts."}
138
+ """
139
+ with open(index_path) as f:
140
+ index = json.load(f)
141
+ wm = index["weight_map"]
142
+
143
+ # prefix -> set of shards
144
+ prefix_shards = defaultdict(set)
145
+ for key in wm:
146
+ m = _EXPERT_KEY_RE.match(key)
147
+ if m:
148
+ prefix = m.group(1)
149
+ prefix_shards[prefix].add(wm[key])
150
+
151
+ cross_shard_prefixes = set()
152
+ for prefix, shards in prefix_shards.items():
153
+ if len(shards) > 1:
154
+ cross_shard_prefixes.add(prefix)
155
+
156
+ return cross_shard_prefixes
157
+
158
+ def convert_shard(shard_path: str, cross_shard_prefixes: set = None):
159
+ """Load a single shard, rename keys, and fuse experts.
160
+
161
+ For expert groups in cross_shard_prefixes, the per-expert keys are
162
+ kept as-is (just renamed) and returned separately as deferred items,
163
+ to be merged later in a post-processing step.
164
+
165
+ Returns:
166
+ result: OrderedDict of converted tensors (ready to save)
167
+ deferred_expert_keys: list of original expert keys that were deferred
168
+ (these are kept in result with their original per-expert naming
169
+ but with the outer rename applied, to be post-processed later)
170
+ """
171
+ if cross_shard_prefixes is None:
172
+ cross_shard_prefixes = set()
173
+
174
+ tensors = OrderedDict()
175
+ with safe_open(shard_path, framework="pt", device="cpu") as f:
176
+ for key in f.keys():
177
+ tensors[key] = f.get_tensor(key)
178
+
179
+ # Separate expert keys from non-expert keys
180
+ expert_groups = {} # prefix -> {expert_idx -> {proj_name -> tensor}}
181
+ deferred_expert_keys = [] # keys that belong to cross-shard experts
182
+ result = OrderedDict()
183
+
184
+ for key, tensor in tensors.items():
185
+ m = _EXPERT_KEY_RE.match(key)
186
+ if m:
187
+ prefix = m.group(1)
188
+ expert_idx = int(m.group(2))
189
+ proj_name = m.group(3)
190
+
191
+ if prefix in cross_shard_prefixes:
192
+ # Defer: keep the key as-is (with rename) for post-processing
193
+ new_key = rename_key(key)
194
+ result[new_key] = tensor
195
+ deferred_expert_keys.append(new_key)
196
+ else:
197
+ # Normal: collect for fusion within this shard
198
+ if prefix not in expert_groups:
199
+ expert_groups[prefix] = {}
200
+ if expert_idx not in expert_groups[prefix]:
201
+ expert_groups[prefix][expert_idx] = {}
202
+ expert_groups[prefix][expert_idx][proj_name] = tensor
203
+ else:
204
+ # Non-expert key: just rename
205
+ new_key = rename_key(key)
206
+ result[new_key] = tensor
207
+
208
+ # Fuse expert weights for each non-cross-shard layer prefix
209
+ for prefix in sorted(expert_groups.keys()):
210
+ experts = expert_groups[prefix]
211
+ num_experts = max(experts.keys()) + 1
212
+
213
+ gate_up_list = []
214
+ down_list = []
215
+ for i in range(num_experts):
216
+ if i not in experts:
217
+ raise ValueError(
218
+ f"Missing expert {i} in {prefix}. "
219
+ f"Found: {sorted(experts.keys())}"
220
+ )
221
+ exp = experts[i]
222
+ gate_up = torch.cat([exp["gate_proj"], exp["up_proj"]], dim=0)
223
+ gate_up_list.append(gate_up)
224
+ down_list.append(exp["down_proj"])
225
+
226
+ fused_gate_up = torch.stack(gate_up_list, dim=0)
227
+ fused_down = torch.stack(down_list, dim=0)
228
+
229
+ for exp in experts.values():
230
+ exp.clear()
231
+ gate_up_list.clear()
232
+ down_list.clear()
233
+
234
+ result[f"{prefix}gate_up_proj"] = fused_gate_up
235
+ result[f"{prefix}down_proj"] = fused_down
236
+
237
+ return result, deferred_expert_keys
238
+
239
+ def _process_one_shard(args_tuple):
240
+ """Worker function: convert a single shard and save to output dir.
241
+
242
+ Args:
243
+ args_tuple: (idx, num_shards, shard_file, input_dir, output_dir, cross_shard_prefixes)
244
+
245
+ Returns:
246
+ (shard_file, key_list, shard_size, elapsed, deferred_keys)
247
+ """
248
+ idx, num_shards, shard_file, input_dir, output_dir, cross_shard_prefixes = args_tuple
249
+ shard_path = os.path.join(input_dir, shard_file)
250
+ t0 = time.time()
251
+
252
+ converted, deferred_keys = convert_shard(shard_path, cross_shard_prefixes)
253
+
254
+ shard_size = sum(t.numel() * t.element_size() for t in converted.values())
255
+
256
+ out_shard_path = os.path.join(output_dir, shard_file)
257
+ save_file(converted, out_shard_path)
258
+
259
+ elapsed = time.time() - t0
260
+ num_keys = len(converted)
261
+ key_list = list(converted.keys())
262
+
263
+ del converted
264
+
265
+ deferred_info = ""
266
+ if deferred_keys:
267
+ deferred_info = f", Deferred={len(deferred_keys)}"
268
+
269
+ print(
270
+ f" [{idx + 1}/{num_shards}] {shard_file}: "
271
+ f"Keys={num_keys}, Size={shard_size / 1e9:.2f} GB, "
272
+ f"Time={elapsed:.1f}s{deferred_info}",
273
+ flush=True,
274
+ )
275
+
276
+ return shard_file, key_list, shard_size, elapsed, deferred_keys
277
+
278
+
279
+ def post_process_cross_shard_experts(output_dir, cross_shard_prefixes, all_deferred):
280
+ """Merge cross-shard expert groups (v2: shard-centric approach).
281
+
282
+ Instead of iterating per-prefix (which causes the same shard to be
283
+ loaded/saved multiple times), this v2 approach:
284
+ 1. Builds a mapping of which prefixes each shard is involved in
285
+ 2. Collects all expert tensors from all involved shards in ONE pass
286
+ 3. Fuses all prefixes
287
+ 4. Writes each shard only ONCE with all its updates applied
288
+
289
+ This avoids the Bus error (core dump) caused by repeated mmap of
290
+ large files and memory bloat.
291
+
292
+ Args:
293
+ output_dir: path to output directory
294
+ cross_shard_prefixes: set of expert prefixes that span multiple shards
295
+ all_deferred: dict of {shard_file: [deferred_key, ...]}
296
+
297
+ Returns:
298
+ updated_shards: dict of {shard_file: (key_list, shard_size)} for updated shards
299
+ """
300
+ if not cross_shard_prefixes:
301
+ return {}
302
+
303
+ print(f"\n Post-processing {len(cross_shard_prefixes)} cross-shard expert group(s)...",
304
+ flush=True)
305
+
306
+ # ----------------------------------------------------------------
307
+ # Step 1: Build mappings
308
+ # ----------------------------------------------------------------
309
+ # prefix -> ordered list of shards that contain its experts
310
+ prefix_to_shards = defaultdict(set)
311
+ # shard -> set of prefixes it is involved in
312
+ shard_to_prefixes = defaultdict(set)
313
+
314
+ for shard_file, deferred_keys in all_deferred.items():
315
+ for key in deferred_keys:
316
+ m = _EXPERT_KEY_RE.match(key)
317
+ if m:
318
+ prefix = m.group(1)
319
+ if prefix in cross_shard_prefixes:
320
+ prefix_to_shards[prefix].add(shard_file)
321
+ shard_to_prefixes[shard_file].add(prefix)
322
+
323
+ # For each prefix, decide which shard will hold the fused result
324
+ # (use the first shard alphabetically)
325
+ prefix_to_target_shard = {}
326
+ for prefix in sorted(prefix_to_shards.keys()):
327
+ target = sorted(prefix_to_shards[prefix])[0]
328
+ prefix_to_target_shard[prefix] = target
329
+
330
+ # All shards that need to be updated
331
+ all_involved_shards = set()
332
+ for shards in prefix_to_shards.values():
333
+ all_involved_shards.update(shards)
334
+
335
+ print(f" Involved shards: {len(all_involved_shards)}", flush=True)
336
+ print(f" Expert groups: {len(prefix_to_shards)}", flush=True)
337
+
338
+ # ----------------------------------------------------------------
339
+ # Step 2: Collect all expert tensors from all involved shards
340
+ # (one pass per shard)
341
+ # ----------------------------------------------------------------
342
+ # prefix -> {expert_idx -> {proj_name -> tensor}}
343
+ all_expert_data = defaultdict(dict)
344
+ # shard -> OrderedDict of non-expert keys (to be re-saved)
345
+ shard_non_expert = {}
346
+
347
+ sorted_involved = sorted(all_involved_shards)
348
+ for si, shard_file in enumerate(sorted_involved):
349
+ shard_path = os.path.join(output_dir, shard_file)
350
+ prefixes_in_shard = shard_to_prefixes[shard_file]
351
+
352
+ print(f" [{si+1}/{len(sorted_involved)}] Reading {shard_file} "
353
+ f"({len(prefixes_in_shard)} prefix(es))...", flush=True)
354
+
355
+ non_expert = OrderedDict()
356
+ with safe_open(shard_path, framework="pt", device="cpu") as f:
357
+ for key in f.keys():
358
+ m = _EXPERT_KEY_RE.match(key)
359
+ if m and m.group(1) in prefixes_in_shard:
360
+ # This is a deferred expert key
361
+ prefix = m.group(1)
362
+ expert_idx = int(m.group(2))
363
+ proj_name = m.group(3)
364
+ if expert_idx not in all_expert_data[prefix]:
365
+ all_expert_data[prefix][expert_idx] = {}
366
+ all_expert_data[prefix][expert_idx][proj_name] = f.get_tensor(key)
367
+ else:
368
+ # Non-expert key: keep as-is
369
+ non_expert[key] = f.get_tensor(key)
370
+
371
+ shard_non_expert[shard_file] = non_expert
372
+ gc.collect()
373
+
374
+ # ----------------------------------------------------------------
375
+ # Step 3: Fuse all expert groups
376
+ # ----------------------------------------------------------------
377
+ # prefix -> {"gate_up_proj": tensor, "down_proj": tensor}
378
+ fused_results = {}
379
+
380
+ for pi, prefix in enumerate(sorted(all_expert_data.keys())):
381
+ expert_data = all_expert_data[prefix]
382
+ num_experts = max(expert_data.keys()) + 1
383
+
384
+ print(f" Fusing {prefix} ({num_experts} experts)...", flush=True)
385
+
386
+ gate_up_list = []
387
+ down_list = []
388
+ for i in range(num_experts):
389
+ if i not in expert_data:
390
+ raise ValueError(
391
+ f"Missing expert {i} in {prefix} after cross-shard merge. "
392
+ f"Found: {sorted(expert_data.keys())}"
393
+ )
394
+ exp = expert_data[i]
395
+ if "gate_proj" not in exp or "up_proj" not in exp:
396
+ raise ValueError(
397
+ f"Expert {i} in {prefix} missing gate_proj/up_proj. "
398
+ f"Has: {sorted(exp.keys())}"
399
+ )
400
+ if "down_proj" not in exp:
401
+ raise ValueError(
402
+ f"Expert {i} in {prefix} missing down_proj. "
403
+ f"Has: {sorted(exp.keys())}"
404
+ )
405
+ gate_up = torch.cat([exp["gate_proj"], exp["up_proj"]], dim=0)
406
+ gate_up_list.append(gate_up)
407
+ down_list.append(exp["down_proj"])
408
+
409
+ fused_gate_up = torch.stack(gate_up_list, dim=0)
410
+ fused_down = torch.stack(down_list, dim=0)
411
+
412
+ fused_results[prefix] = {
413
+ "gate_up_proj": fused_gate_up,
414
+ "down_proj": fused_down,
415
+ }
416
+
417
+ # Free per-expert data for this prefix
418
+ del gate_up_list, down_list
419
+ for exp in expert_data.values():
420
+ exp.clear()
421
+ del all_expert_data[prefix]
422
+ gc.collect()
423
+
424
+ del all_expert_data
425
+ gc.collect()
426
+
427
+ # ----------------------------------------------------------------
428
+ # Step 4: Write each involved shard ONCE with all updates applied
429
+ # ----------------------------------------------------------------
430
+ updated_shards = {}
431
+
432
+ for si, shard_file in enumerate(sorted_involved):
433
+ shard_path = os.path.join(output_dir, shard_file)
434
+ non_expert = shard_non_expert[shard_file]
435
+
436
+ # Add fused tensors for prefixes that target this shard
437
+ fused_added = []
438
+ for prefix, target_shard in prefix_to_target_shard.items():
439
+ if target_shard == shard_file and prefix in fused_results:
440
+ non_expert[f"{prefix}gate_up_proj"] = fused_results[prefix]["gate_up_proj"]
441
+ non_expert[f"{prefix}down_proj"] = fused_results[prefix]["down_proj"]
442
+ fused_added.append(prefix)
443
+
444
+ save_file(non_expert, shard_path)
445
+ shard_size = sum(t.numel() * t.element_size() for t in non_expert.values())
446
+ updated_shards[shard_file] = (list(non_expert.keys()), shard_size)
447
+
448
+ fused_info = ""
449
+ if fused_added:
450
+ fused_info = f", Fused {len(fused_added)} group(s)"
451
+
452
+ print(f" [{si+1}/{len(sorted_involved)}] Wrote {shard_file}: "
453
+ f"{len(non_expert)} keys, {shard_size / 1e9:.2f} GB{fused_info}",
454
+ flush=True)
455
+
456
+ # Free memory for this shard
457
+ del shard_non_expert[shard_file]
458
+ for prefix in fused_added:
459
+ del fused_results[prefix]
460
+ del non_expert
461
+ gc.collect()
462
+
463
+ return updated_shards
464
+
465
+
466
+ def main():
467
+ parser = argparse.ArgumentParser(
468
+ description="Convert HYV3 checkpoint from inner to outer format (v2, shard-centric post-processing)."
469
+ )
470
+ parser.add_argument(
471
+ "--input_dir", type=str, required=True,
472
+ help="Path to the inner-format checkpoint directory.",
473
+ )
474
+ parser.add_argument(
475
+ "--output_dir", type=str, required=True,
476
+ help="Path to the output outer-format checkpoint directory.",
477
+ )
478
+ parser.add_argument(
479
+ "--workers", type=int, default=8,
480
+ help="Number of parallel worker processes (default: 8).",
481
+ )
482
+ args = parser.parse_args()
483
+
484
+ input_dir = os.path.abspath(args.input_dir)
485
+ output_dir = os.path.abspath(args.output_dir)
486
+ num_workers = args.workers
487
+
488
+ if not os.path.isdir(input_dir):
489
+ raise FileNotFoundError(f"Input directory not found: {input_dir}")
490
+
491
+ os.makedirs(output_dir, exist_ok=True)
492
+
493
+ # Pre-scan for cross-shard expert groups
494
+ index_path = os.path.join(input_dir, "model.safetensors.index.json")
495
+ cross_shard_prefixes = set()
496
+ if os.path.exists(index_path):
497
+ cross_shard_prefixes = scan_cross_shard_experts(index_path)
498
+ if cross_shard_prefixes:
499
+ print(f"Detected {len(cross_shard_prefixes)} cross-shard expert group(s):")
500
+ for p in sorted(cross_shard_prefixes):
501
+ print(f" - {p}")
502
+ print()
503
+
504
+ # Get all safetensors files
505
+ shard_files = sorted(
506
+ f for f in os.listdir(input_dir) if f.endswith(".safetensors")
507
+ )
508
+ if not shard_files:
509
+ raise FileNotFoundError(f"No .safetensors files found in {input_dir}")
510
+
511
+ # Skip already-converted shards (for resumability)
512
+ # NOTE: if there are cross-shard experts, we cannot skip shards that
513
+ # contain deferred keys (they need post-processing). For simplicity,
514
+ # when cross-shard experts exist, we re-process all shards.
515
+ remaining = []
516
+ skipped = []
517
+ if cross_shard_prefixes:
518
+ # Re-process all shards when cross-shard experts exist
519
+ remaining = list(shard_files)
520
+ else:
521
+ for sf in shard_files:
522
+ out_path = os.path.join(output_dir, sf)
523
+ if os.path.exists(out_path) and os.path.getsize(out_path) > 0:
524
+ skipped.append(sf)
525
+ else:
526
+ remaining.append(sf)
527
+
528
+ num_shards = len(shard_files)
529
+ num_workers = min(num_workers, len(remaining)) if remaining else 1
530
+
531
+ print(f"=" * 60)
532
+ print(f"HYV3 Checkpoint Converter (inner -> outer, v2)")
533
+ print(f" Input : {input_dir}")
534
+ print(f" Output : {output_dir}")
535
+ print(f" Shards : {num_shards} total, {len(skipped)} already done, {len(remaining)} to process")
536
+ print(f" Workers: {num_workers}")
537
+ if cross_shard_prefixes:
538
+ print(f" Cross-shard experts: {len(cross_shard_prefixes)} group(s) (will post-process)")
539
+ print(f"=" * 60)
540
+
541
+ t_start = time.time()
542
+
543
+ # Build task list for remaining shards
544
+ tasks = [
545
+ (i, len(remaining), sf, input_dir, output_dir, cross_shard_prefixes)
546
+ for i, sf in enumerate(remaining)
547
+ ]
548
+
549
+ # Process in parallel
550
+ results = []
551
+ if tasks:
552
+ with Pool(processes=num_workers, initializer=_pool_worker_init) as pool:
553
+ results = pool.map(_process_one_shard, tasks)
554
+
555
+ # Collect deferred keys info
556
+ all_deferred = {} # shard_file -> [deferred_keys]
557
+ for shard_file, key_list, shard_size, elapsed, deferred_keys in results:
558
+ if deferred_keys:
559
+ all_deferred[shard_file] = deferred_keys
560
+
561
+ # Post-process cross-shard expert groups (v2: shard-centric)
562
+ updated_shards = {}
563
+ if cross_shard_prefixes and all_deferred:
564
+ updated_shards = post_process_cross_shard_experts(
565
+ output_dir, cross_shard_prefixes, all_deferred
566
+ )
567
+
568
+ # Build weight_map and total_size
569
+ weight_map = OrderedDict()
570
+ total_size = 0
571
+
572
+ # For skipped shards, read their keys from the output files
573
+ for sf in skipped:
574
+ out_path = os.path.join(output_dir, sf)
575
+ with safe_open(out_path, framework="pt", device="cpu") as f:
576
+ keys = list(f.keys())
577
+ for key in keys:
578
+ weight_map[key] = sf
579
+ t = f.get_tensor(key)
580
+ total_size += t.numel() * t.element_size()
581
+
582
+ # Collect results from newly converted shards
583
+ for shard_file, key_list, shard_size, elapsed, deferred_keys in results:
584
+ if shard_file in updated_shards:
585
+ # This shard was updated by post-processing
586
+ updated_key_list, updated_size = updated_shards[shard_file]
587
+ for key in updated_key_list:
588
+ weight_map[key] = shard_file
589
+ total_size += updated_size
590
+ else:
591
+ for key in key_list:
592
+ weight_map[key] = shard_file
593
+ total_size += shard_size
594
+
595
+ # Build and save index
596
+ sorted_weight_map = OrderedDict(sorted(weight_map.items()))
597
+ index = {
598
+ "metadata": {"total_size": total_size},
599
+ "weight_map": sorted_weight_map,
600
+ }
601
+ index_path_out = os.path.join(output_dir, "model.safetensors.index.json")
602
+ with open(index_path_out, "w") as f:
603
+ json.dump(index, f, indent=2)
604
+ f.write("\n")
605
+ print(f"\nSaved {index_path_out}")
606
+
607
+ # Copy non-safetensors files (config, tokenizer, etc.)
608
+ skip_suffixes = {".safetensors"}
609
+ skip_names = {"model.safetensors.index.json"}
610
+ copied = []
611
+ for fname in os.listdir(input_dir):
612
+ if fname in skip_names:
613
+ continue
614
+ if any(fname.endswith(s) for s in skip_suffixes):
615
+ continue
616
+ src = os.path.join(input_dir, fname)
617
+ dst = os.path.join(output_dir, fname)
618
+ if os.path.isfile(src):
619
+ shutil.copy2(src, dst)
620
+ copied.append(fname)
621
+ elif os.path.isdir(src):
622
+ if os.path.exists(dst):
623
+ shutil.rmtree(dst)
624
+ shutil.copytree(src, dst)
625
+ copied.append(fname + "/")
626
+
627
+ if copied:
628
+ print(f"\nCopied files: {', '.join(copied)}")
629
+
630
+ t_total = time.time() - t_start
631
+ print(f"\n{'=' * 60}")
632
+ print(f"Conversion complete!")
633
+ print(f" Total keys : {len(weight_map)}")
634
+ print(f" Total size : {total_size / 1e9:.2f} GB")
635
+ print(f" Total time : {t_total:.1f}s ({t_total / 60:.1f} min)")
636
+ print(f" Output dir : {output_dir}")
637
+ print(f"{'=' * 60}")
638
+
639
+ if __name__ == "__main__":
640
+ _install_signal_handlers()
641
+ main()
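To make the fused layout concrete, here is a minimal sketch of the per-layer expert fusion performed in convert_shard, using small illustrative dimensions (2 experts, hidden_size 4, expert_hidden 3; not taken from the real config):

import torch

num_experts, hidden, expert_hidden = 2, 4, 3
gate = [torch.randn(expert_hidden, hidden) for _ in range(num_experts)]
up = [torch.randn(expert_hidden, hidden) for _ in range(num_experts)]
down = [torch.randn(hidden, expert_hidden) for _ in range(num_experts)]

# Per expert: concatenate gate/up along dim 0, then stack all experts into 3D.
gate_up = torch.stack([torch.cat([g, u], dim=0) for g, u in zip(gate, up)], dim=0)
down_fused = torch.stack(down, dim=0)

assert gate_up.shape == (num_experts, expert_hidden * 2, hidden)   # (2, 6, 4)
assert down_fused.shape == (num_experts, hidden, expert_hidden)    # (2, 4, 3)

These are exactly the shapes that check_converted.py verifies for experts.gate_up_proj and experts.down_proj.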
train/train.py ADDED
@@ -0,0 +1,564 @@
1
+ # Copyright 2024 Tencent Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
16
+ #
17
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
18
+ # and OPT implementations in this library. It has been modified from its
19
+ # original forms to accommodate minor architectural differences compared
20
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
21
+ #
22
+ # Licensed under the Apache License, Version 2.0 (the "License");
23
+ # you may not use this file except in compliance with the License.
24
+ # You may obtain a copy of the License at
25
+ #
26
+ # http://www.apache.org/licenses/LICENSE-2.0
27
+ #
28
+ # Unless required by applicable law or agreed to in writing, software
29
+ # distributed under the License is distributed on an "AS IS" BASIS,
30
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ # See the License for the specific language governing permissions and
32
+ # limitations under the License.
33
+
34
+
35
+ import os
36
+ import re
37
+ import sys
38
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
39
+ import json
40
+ import torch
41
+ import shutil
42
+ import logging
43
+ from dataclasses import dataclass, field
44
+ import deepspeed
45
+ from typing import Optional, Dict
46
+
47
+ import transformers
48
+ from torch.utils.data import Dataset
49
+ from transformers import Trainer, TrainerCallback
50
+ from peft import LoraConfig, get_peft_model, PeftModel
51
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
52
+ from transformers.modeling_utils import unwrap_model
53
+
54
+
55
+ def print_args(args, name='arguments'):
56
+ """Print arguments."""
57
+ if torch.distributed.get_rank() == 0:
58
+ print(f'------------------------ {name} ------------------------', flush=True)
59
+ str_list = []
60
+ for arg in vars(args):
61
+ dots = '.' * (48 - len(arg))
62
+ str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
63
+ for arg in sorted(str_list, key=lambda x: x.lower()):
64
+ print(arg, flush=True)
65
+ print(f'-------------------- end of {name} ---------------------', flush=True)
66
+
67
+
68
+ @dataclass
69
+ class ModelArguments:
70
+ use_flash_attn: bool = field(
71
+ default=False,
72
+ metadata={"help": "Enable FlashAttention-2 for faster training."}
73
+ )
74
+ use_lora: bool = field(default=False, metadata={"help": "Enable Lora for faster training."})
75
+ hidden_size: int = field(default=2048, metadata={"help": "The hidden size of the model."})
76
+ num_layers: int = field(default=24, metadata={"help": "The number of layers of the model."})
77
+ num_attention_heads: int = field(default=16, metadata={"help": "The number of attention heads of the model."})
78
+ intermediate_size: int = field(default=8192, metadata={"help": "The intermediate size of the model."})
79
+ max_position_embeddings: int = field(
80
+ default=2048,
81
+ metadata={"help": "The maximum sequence length that this model might ever be used with."}
82
+ )
83
+ vocab_size: int = field(default=50257, metadata={"help": "The vocabulary size of the model."})
84
+ type_vocab_size: int = field(default=1, metadata={"help": "The type vocabulary size of the model."})
85
+ layer_norm_eps: float = field(
86
+ default=1e-5,
87
+ metadata={"help": "The epsilon used by the layer normalization layers of the model."}
88
+ )
89
+ moe_topk: int = field(default=4, metadata={"help": "The topk for MOE."})
90
+ num_experts: int = field(default=8, metadata={"help": "The number of experts for MOE."})
91
+ num_key_value_heads: int = field(default=16, metadata={"help": "The number of key-value heads in GQA."})
92
+ moe_intermediate_size: int = field(default=1536, metadata={"help": "The intermediate size of each MoE expert."})
93
+ use_mixed_mlp_moe: bool = field(
94
+ default=False,
95
+ metadata={"help": "Whether to use mixed MoE with shared expert."}
96
+ )
97
+ num_shared_expert: int = field(default=1, metadata={"help": "Number of shared experts."})
98
+ use_qk_norm: bool = field(default=False, metadata={"help": "Whether to use qk norm."})
99
+ moe_layer_num_skipped: int = field(default=1, metadata={"help": "Number of initial dense layers before MoE layers."})
100
+ tie_word_embeddings: bool = field(
101
+ default=True,
102
+ metadata={"help": "Whether to tie the word embeddings of the encoder and the decoder."}
103
+ )
104
+ lora_rank: int = field(default=64, metadata={"help": "The rank of lora."})
105
+ lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"})
106
+ lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"})
107
+ train_attention_params_only: bool = field(default=False, metadata={
108
+ "help": "Whether to train attention parameters only."}
109
+ )
110
+
111
+
112
+ @dataclass
113
+ class DataArguments:
114
+ train_data_file: str = field(default=None, metadata={"help": "Path to the training data."})
115
+ max_seq_length: int = field(
116
+ default=2048,
117
+ metadata={"help": "The max sequence length of the model inputs after tokenization."}
118
+ )
119
+ complex_data: Optional[str] = field(default=None)
120
+ use_dummy_data: bool = field(default=False, metadata={"help": "Use dummy data."})
121
+
122
+
123
+ @dataclass
124
+ class TrainingArguments(transformers.TrainingArguments):
125
+ cache_dir: Optional[str] = field(default=None)
126
+ optim: str = field(default="adamw_torch")
127
+ model_max_length: int = field(
128
+ default=2048,
129
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
130
+ )
131
+ tokenizer_name_or_path: Optional[str] = field(default=None)
132
+ model_name_or_path: Optional[str] = field(default=None)
133
+ min_lr: float = field(
134
+ default=0.01,
135
+ metadata={"help": "The final learning rate at the end of the decay will be learning_rate * min_lr"}
136
+ )
137
+
138
+
139
+ IGNORE_INDEX = -100
140
+
141
+
142
+ class DummyDataset(Dataset):
143
+ def __init__(self, tokenizer, max_seq_length=512, length=1000):
144
+ self.tokenizer = tokenizer
145
+ self.max_seq_length = max_seq_length
146
+ self.length = length
147
+
148
+ def __len__(self):
149
+ return self.length
150
+
151
+ def __getitem__(self, index):
152
+ tokens = torch.randint(0, self.tokenizer.vocab_size, (self.max_seq_length, ))
153
+ return {'input_ids': tokens, 'labels': tokens}
154
+
155
+
156
+ class SFTDataset(Dataset):
157
+ def __init__(self, data_file, tokenizer, max_seq_length=2048, prompt_format='mplus'):
158
+ self.tokenizer = tokenizer
159
+ self.prompt_format = prompt_format
160
+ self.max_seq_length = max_seq_length
161
+
162
+ self.data_list = self.load_data(data_file)
163
+
164
+ def __len__(self):
165
+ return len(self.data_list)
166
+
167
+ def load_data(self, data_file):
168
+ logging.info('Loading data: {}'.format(data_file))
169
+ with open(data_file, 'r', encoding='utf8') as f:
170
+ data_list = f.readlines()
171
+ logging.info("there are {} data in dataset".format(len(data_list)))
172
+ return data_list
173
+
174
+ def encode_data(self, data_dict):
175
+ model_inputs = {}
176
+ reasoning_effort = data_dict.get('reasoning_effort', None)
177
+ if reasoning_effort is None:
178
+ reasoning_effort = 'no_think'
179
+ template_output = self.tokenizer.apply_chat_template(data_dict['messages'], tokenize=True, return_dict=False, is_training=True, reasoning_effort=reasoning_effort)
180
+ if isinstance(template_output, list) and len(template_output) > 0 and isinstance(template_output[0], list):
181
+ template_output = template_output[0]
182
+ message_tokens = torch.tensor(template_output, dtype=torch.long)
183
+
184
+ # Use new HunYuan tokenizer special tokens
185
+ assistant_token_id = self.tokenizer.convert_tokens_to_ids('<|hy_Assistant|>')
186
+ eos_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.eos_token)
187
+ pad_token_id = self.tokenizer.pad_token_id
188
+
189
+ # Find assistant reply boundaries: starts at <|hy_Assistant|>, ends at eos_token
190
+ loss_token_begins = (message_tokens == assistant_token_id).nonzero(as_tuple=True)[0].tolist()
191
+ loss_token_ends = (message_tokens == eos_token_id).nonzero(as_tuple=True)[0].tolist()
192
+ message_labels = torch.tensor([IGNORE_INDEX] * message_tokens.shape[0])
193
+ for begin_idx, end_idx in zip(loss_token_begins, loss_token_ends):
194
+ # Compute loss from the token after <|hy_Assistant|> to eos_token (inclusive)
195
+ message_labels[begin_idx + 1:end_idx + 1] = message_tokens[begin_idx + 1:end_idx + 1]
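+ # Illustration with toy ids: for tokens
+ # [sys, user, <|hy_Assistant|>, a1, a2, eos] the labels become
+ # [-100, -100, -100, a1, a2, eos], i.e. loss covers the reply plus its eos.
+ # Note the zip pairing assumes one eos per assistant turn, in order.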
196
+ input_ids = message_tokens.to(torch.long)
197
+ labels = message_labels.to(torch.long)
198
+
199
+ input_ids = input_ids[:self.max_seq_length]
200
+ labels = labels[:self.max_seq_length]
201
+ attention_mask = [1 if val != pad_token_id else 0 for val in input_ids]
202
+ model_inputs["input_ids"] = input_ids
203
+ model_inputs["attention_mask"] = torch.tensor(attention_mask, dtype=torch.bool)
204
+ model_inputs["labels"] = labels
205
+
206
+ return model_inputs
207
+
208
+ def __getitem__(self, index):
209
+ data = self.data_list[index]
210
+ data = json.loads(data)
211
+ model_inputs = self.encode_data(data)
212
+
213
+ return model_inputs
214
+
215
+
216
+ @dataclass
217
+ class DataCollatorForSupervisedDataset(object):
218
+ """Collate examples for supervised fine-tuning."""
219
+
220
+ tokenizer: transformers.PreTrainedTokenizer
221
+
222
+ def __call__(self, instances):
223
+ input_ids = [instance['input_ids'] for instance in instances]
224
+ labels = [instance['labels'] for instance in instances]
225
+ pad_token_id = self.tokenizer.pad_token_id
226
+ input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
227
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
228
+ return dict(
229
+ input_ids=input_ids,
230
+ labels=labels,
231
+ attention_mask=input_ids.ne(pad_token_id),
232
+ )
233
+
234
+
235
+ def make_supervised_data_module(tokenizer, data_args) -> Dict:
236
+ """Make dataset and collator for supervised fine-tuning."""
237
+ if data_args.use_dummy_data:
238
+ train_dataset = DummyDataset(tokenizer, data_args.max_seq_length)
239
+ else:
240
+ train_dataset = SFTDataset(
241
+ tokenizer=tokenizer,
242
+ data_file=data_args.train_data_file,
243
+ max_seq_length=data_args.max_seq_length
244
+ )
245
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
246
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
247
+
248
+
249
+ # for full model training, change the config.json, copy the model and configuration to support Auto load
250
+ class CustomSaveCallback(TrainerCallback):
251
+ def on_save(self, args, state, control, **kwargs):
252
+ if torch.distributed.get_rank() == 0:
253
+ output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
254
+
255
+ # Copy tokenizer files to checkpoint directory
256
+ tokenizer_files = [
257
+ 'generation_config.json',
258
+ 'hy.tiktoken',
259
+ 'tokenizer_config.json',
260
+ 'tokenization_hy.py',
261
+ 'tokenizer.json',
262
+ 'special_tokens_map.json',
263
+ 'chat_template.jinja',
264
+ ]
265
+ for fname in tokenizer_files:
266
+ src = os.path.join(args.tokenizer_name_or_path, fname)
267
+ if os.path.isfile(src):
268
+ shutil.copy(src, os.path.join(output_dir, fname))
269
+
270
+ return control
271
+
272
+
273
+ def train():
274
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
275
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
276
+ print_args(model_args, 'model arguments')
277
+ print_args(data_args, 'data arguments')
278
+ print_args(training_args, 'training arguments')
279
+
280
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
281
+ training_args.tokenizer_name_or_path,
282
+ trust_remote_code = True
283
+ )
284
+
285
+ init_kwargs = {}
286
+ if model_args.use_flash_attn:
287
+ init_kwargs["attn_implementation"] = "flash_attention_2"
288
+ # Workaround: transformers >= 5.x uses importlib.metadata.packages_distributions()
289
+ # to verify flash-attn package name, which fails when the package is installed under
290
+ # a custom distribution name (e.g. ptm-flash-attn). Patch the check to skip it.
291
+ try:
292
+ from transformers.modeling_flash_attention_utils import FLASH_ATTENTION_COMPATIBILITY_MATRIX
293
+ _orig_pkg_check = FLASH_ATTENTION_COMPATIBILITY_MATRIX[2]["pkg_availability_check"]
294
+ FLASH_ATTENTION_COMPATIBILITY_MATRIX[2]["pkg_availability_check"] = lambda *a, **kw: True
295
+ print("[Patch] Bypassed flash_attn package distribution name check for FA2.")
296
+ except Exception as e:
297
+ print(f"[Patch] Could not patch FA2 pkg check (non-fatal): {e}")
298
+ if training_args.bf16:
299
+ init_kwargs["dtype"] = torch.bfloat16
300
+ elif training_args.fp16:
301
+ init_kwargs["dtype"] = torch.float16
302
+
303
+    # Check if model weights exist (not just the directory)
+    _has_weights = (
+        training_args.model_name_or_path is not None
+        and os.path.isdir(training_args.model_name_or_path)
+        and any(
+            os.path.isfile(os.path.join(training_args.model_name_or_path, f))
+            for f in ("model.safetensors", "pytorch_model.bin", "model.safetensors.index.json", "pytorch_model.bin.index.json")
+        )
+    )
+
+    # -----------------------------------------------------------------------
+    # Fix: Rename checkpoint keys so that old-style weight names (e.g.
+    # self_attn.q_norm) are mapped to the current model attribute names
+    # (e.g. self_attn.query_layernorm). The model's
+    # _fix_state_dict_key_on_load hook is NOT invoked on the DeepSpeed
+    # ZeRO-3 loading path, so we monkey-patch the ZeRO-3 loader instead.
+    # -----------------------------------------------------------------------
+    _CKPT_KEY_RENAMES = [
+        ("mlp.gate.wg.", "mlp.router.gate."),
+    ]
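+    # Example (hypothetical layer index): the checkpoint key
+    # "model.layers.3.mlp.gate.wg.weight" is loaded into the model attribute
+    # "model.layers.3.mlp.router.gate.weight".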
+
+    from transformers.integrations.deepspeed import (
+        _load_state_dict_into_zero3_model as _orig_load_zero3,
+    )
+    import transformers.integrations.deepspeed as _ds_mod
+    import transformers.modeling_utils as _mu_mod
+
+    def _patched_load_zero3(model_to_load, state_dict, load_config=None):
+        new_sd = {}
+        for k, v in state_dict.items():
+            new_k = k
+            for old_sub, new_sub in _CKPT_KEY_RENAMES:
+                if old_sub in new_k:
+                    new_k = new_k.replace(old_sub, new_sub)
+                    break
+            new_sd[new_k] = v
+
+        # Call original ZeRO-3 loader for parameters
+        result = _orig_load_zero3(model_to_load, new_sd, load_config)
+
+        # -------------------------------------------------------------------
+        # Patch: Manually load buffers (e.g. e_score_correction_bias).
+        # ZeRO-3's loader only handles named_parameters, not named_buffers.
+        # -------------------------------------------------------------------
+        buffers_loaded = 0
+        for name, buf in model_to_load.named_buffers():
+            if name in new_sd:
+                src_tensor = new_sd[name]
+                if isinstance(src_tensor, torch.Tensor):
+                    buf.data.copy_(src_tensor.to(buf.dtype))
+                    buffers_loaded += 1
+                    # Remove from unexpected keys if tracked
+                    if isinstance(result, tuple) and len(result) >= 2:
+                        if isinstance(result[1], set):
+                            result[1].discard(name)
+        if buffers_loaded > 0:
+            print(f"[HYV3 Patch] Manually loaded {buffers_loaded} buffers "
+                  f"(e.g. e_score_correction_bias) into model.")
+
+        return result
+
+    _ds_mod._load_state_dict_into_zero3_model = _patched_load_zero3
+    _mu_mod._load_state_dict_into_zero3_model = _patched_load_zero3
+    # -----------------------------------------------------------------------
+
+    # -------------------------------------------------------------------
+    # Patch: Save-time reverse key rename + 3D -> per-expert unfuse.
+    #
+    # When saving checkpoints, the model state_dict uses 3D fused experts
+    # and the new naming. We reverse both for old-checkpoint compatibility:
+    #   - mlp.gate.                   -> mlp.router.gate.
+    #   - mlp.e_score_correction_bias -> mlp.expert_bias
+    #   - mlp.shared_experts.         -> mlp.shared_mlp.
+    #   - experts.gate_up_proj        -> experts.{N}.gate_proj.weight + experts.{N}.up_proj.weight
+    #   - experts.down_proj           -> experts.{N}.down_proj.weight
+    # -------------------------------------------------------------------
+    _SAVE_KEY_RENAMES = [
+        ("mlp.gate.", "mlp.router.gate."),
+        ("mlp.e_score_correction_bias", "mlp.expert_bias"),
+        ("mlp.shared_experts.", "mlp.shared_mlp."),
+    ]
+    _FUSED_EXPERT_KEY_RE = re.compile(
+        r"^(.*\.mlp\.experts\.)(gate_up_proj|down_proj)$"
+    )
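+    # e.g. "model.layers.1.mlp.experts.gate_up_proj" matches with
+    # prefix = "model.layers.1.mlp.experts." and proj_type = "gate_up_proj".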
+
+    def _apply_save_reverse_rename_patch():
+        try:
+            from transformers.models.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
+        except ImportError:
+            try:
+                from transformers.hy_v3.modeling_hy_v3 import HYV3ForCausalLM
+            except ImportError:
+                print("[HYV3 Patch] Could not import HYV3ForCausalLM; "
+                      "save reverse rename patch NOT applied.")
+                return
+
+        _orig_save_pretrained = HYV3ForCausalLM.save_pretrained
+
+        def _patched_save_pretrained(self, *args, **kwargs):
+            state_dict = kwargs.get("state_dict", None)
+            if state_dict is not None:
+                reversed_sd = {}
+                for k, v in state_dict.items():
+                    new_k = k
+                    # Apply simple key renames
+                    for new_sub, old_sub in _SAVE_KEY_RENAMES:
+                        if new_sub in new_k:
+                            new_k = new_k.replace(new_sub, old_sub)
+                            break
+
+                    # Check if this is a fused 3D expert key
+                    m = _FUSED_EXPERT_KEY_RE.match(new_k)
+                    if m:
+                        prefix = m.group(1)     # e.g. "model.layers.1.mlp.experts."
+                        proj_type = m.group(2)  # "gate_up_proj" or "down_proj"
+
+                        if proj_type == "gate_up_proj":
+                            # v shape: [num_experts, 2*intermediate, hidden]
+                            num_experts = v.shape[0]
+                            intermediate = v.shape[1] // 2
+                            for i in range(num_experts):
+                                gate = v[i, :intermediate, :]
+                                up = v[i, intermediate:, :]
+                                reversed_sd[f"{prefix}{i}.gate_proj.weight"] = gate
+                                reversed_sd[f"{prefix}{i}.up_proj.weight"] = up
+                        elif proj_type == "down_proj":
+                            # v shape: [num_experts, hidden, intermediate]
+                            num_experts = v.shape[0]
+                            for i in range(num_experts):
+                                reversed_sd[f"{prefix}{i}.down_proj.weight"] = v[i]
+                    else:
+                        reversed_sd[new_k] = v
+
+                kwargs["state_dict"] = reversed_sd
+                print(f"[HYV3 Patch] Reverse-renamed and unfused "
+                      f"{len(state_dict)} -> {len(reversed_sd)} "
+                      f"state_dict keys for old checkpoint compatibility.")
+            return _orig_save_pretrained(self, *args, **kwargs)
+
+        HYV3ForCausalLM.save_pretrained = _patched_save_pretrained
+        print("[HYV3 Patch] Applied: save-time reverse key rename + "
+              "3D -> per-expert unfuse for old ckpt compatibility.")
+
+    _apply_save_reverse_rename_patch()
+    # -------------------------------------------------------------------
+
+    if _has_weights:
+        print(f"Initializing model from local file: {training_args.model_name_or_path}")
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            training_args.model_name_or_path,
+            trust_remote_code=True,
+            **init_kwargs
+        )
+    else:
+        from transformers import HYV3Config
+        from transformers import HYV3ForCausalLM
+        print(f"Model weights not found at: {training_args.model_name_or_path}, "
+              f"using a randomly initialized HYV3 model instead.")
+        # Use len(tokenizer) to include added special tokens; tokenizer.vocab_size
+        # may only return the base vocabulary size and miss special tokens whose
+        # IDs exceed that range, causing index-out-of-bounds in the embedding layer.
+        config = HYV3Config(
+            vocab_size=len(tokenizer),
+            hidden_size=model_args.hidden_size,
+            intermediate_size=model_args.intermediate_size,
+            max_position_embeddings=training_args.model_max_length,
+            moe_topk=model_args.moe_topk,
+            num_experts=model_args.num_experts,
+            num_attention_heads=model_args.num_attention_heads,
+            num_key_value_heads=model_args.num_key_value_heads,
+            num_hidden_layers=model_args.num_layers,
+            moe_intermediate_size=model_args.moe_intermediate_size,
+            use_mixed_mlp_moe=model_args.use_mixed_mlp_moe,
+            num_shared_expert=model_args.num_shared_expert,
+            use_qk_norm=model_args.use_qk_norm,
+            moe_layer_num_skipped=model_args.moe_layer_num_skipped,
+            tie_word_embeddings=model_args.tie_word_embeddings,
+        )
+        # init_kwargs stores the dtype under the "dtype" key (set above), so look
+        # it up under that key (not "torch_dtype") when initializing under ZeRO-3.
+        with deepspeed.zero.Init(dtype=init_kwargs.get("dtype", torch.bfloat16), config_dict_or_path=training_args.deepspeed):
+            model = HYV3ForCausalLM(config)
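+        # With the flags in train/train.sh this instantiates an 80-layer MoE
+        # model: 192 routed experts, top-8 routing, and 1 shared expert.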
+
+    if model_args.train_attention_params_only:
+        for name, param in model.named_parameters():
+            if 'self_attn' not in name:
+                param.requires_grad = False
+
+    if model_args.use_lora:
+        # Define the LoRA configuration
+        lora_config = LoraConfig(
+            r=model_args.lora_rank,
+            lora_alpha=model_args.lora_alpha,
+            lora_dropout=model_args.lora_dropout,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+        model = get_peft_model(model, lora_config)
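+        # get_peft_model freezes all base weights; with the target_modules above,
+        # trainable adapters attach only to the attention projections (q/k/v/o).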
+
+    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+    # Mark the model as model-parallel so the Trainer does not wrap it in DataParallel
+    model.is_parallelizable = True
+    model.model_parallel = True
+
+    # cosine_with_min_lr expects the floor as a fraction of the peak learning rate
+    training_args.lr_scheduler_kwargs = {
+        'min_lr_rate': training_args.min_lr / training_args.learning_rate,
+    }
+
+    # -----------------------------------------------------------------------
+    # Fix: DeepSpeed ZeRO-3 + gradient checkpointing compatibility.
+    #
+    # PyTorch's torch.utils.checkpoint with use_reentrant=False (the default
+    # in transformers) performs strict metadata checks on recomputed tensors
+    # during backward. Under ZeRO-3, parameters are all-gathered during the
+    # first forward pass (shape=[full_size]) but may be partitioned back
+    # (shape=[0]) when the checkpoint recomputes, causing a CheckpointError.
+    #
+    # Setting use_reentrant=True avoids this strict metadata check.
+    # -----------------------------------------------------------------------
+    if training_args.gradient_checkpointing and training_args.deepspeed:
+        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
+
+    trainer = Trainer(
+        model=model,
+        processing_class=tokenizer,
+        args=training_args,
+        callbacks=[CustomSaveCallback],
+        **data_module
+    )
+    # The KV cache is only useful for generation; disable it during training
+    model.config.use_cache = False
+
+    # -----------------------------------------------------------------------
+    # Monkey-patch: fix dtype mismatch in DeepSpeed ZeRO-3 linear wrapper.
+    #
+    # By this point the DeepSpeed engine has been initialised by the Trainer
+    # and torch.nn.functional.linear has been replaced with
+    # zero3_linear_wrap. That wrapper does NOT auto-align input/weight
+    # dtypes before the matmul, causing "expected mat1 and mat2 to have the
+    # same dtype" errors in mixed-precision paths (MoE router gate in fp32
+    # with bf16 weights, expert FFN receiving fp32 routing-weighted input
+    # with bf16 weights, etc.).
+    #
+    # We wrap F.linear HERE (after DeepSpeed init) so that:
+    #   1. We are sure to capture the already-replaced function.
+    #   2. The dtype cast happens *outside* the autograd.Function, so
+    #      gradient-checkpointing recompute sees identical tensor metadata.
+    # -----------------------------------------------------------------------
+    import torch.nn.functional as _F
+    _orig_F_linear = _F.linear
+
+    def _dtype_safe_linear(input, weight, bias=None):
+        if input.dtype != weight.dtype:
+            input = input.to(weight.dtype)
+        return _orig_F_linear(input, weight, bias)
+
+    _F.linear = _dtype_safe_linear
+    # -----------------------------------------------------------------------
+
+    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+
+
+if __name__ == "__main__":
+    train()
train/train.sh ADDED
@@ -0,0 +1,118 @@
+#!/bin/bash
+
+NET_TYPE="high"
+export NCCL_DEBUG=WARN
+export NCCL_P2P_LEVEL=NVL
+export NCCL_IB_TIMEOUT=24
+export NCCL_NVLS_ENABLE=0
+export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
+if [[ "${NET_TYPE}" = "low" ]]; then
+  export NCCL_SOCKET_IFNAME=eth1
+  export NCCL_IB_GID_INDEX=3
+  export NCCL_IB_HCA=mlx5_2:1
+  export NCCL_IB_SL=3
+  export NCCL_CHECK_DISABLE=1
+  export NCCL_P2P_DISABLE=0
+  export NCCL_LL_THRESHOLD=16384
+  export NCCL_IB_CUDA_SUPPORT=1
+else
+  export NCCL_IB_GID_INDEX=3
+  export NCCL_IB_SL=3
+  export NCCL_CHECK_DISABLE=1
+  export NCCL_P2P_DISABLE=0
+  export NCCL_IB_DISABLE=0
+  export NCCL_LL_THRESHOLD=16384
+  export NCCL_IB_CUDA_SUPPORT=1
+  export NCCL_SOCKET_IFNAME=bond1
+  export UCX_NET_DEVICES=bond1
+  export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
+  export NCCL_COLLNET_ENABLE=0
+  export SHARP_COLL_ENABLE_SAT=0
+  export NCCL_NET_GDR_LEVEL=2
+  export NCCL_IB_QPS_PER_CONNECTION=4
+  export NCCL_IB_TC=160
+  export NCCL_PXN_DISABLE=1
+fi
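+# NET_TYPE selects the NCCL setup: "low" pins a single IB HCA on eth1, while
+# the default "high" uses the bonded-NIC, multi-HCA configuration above.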
+
+export HOST_GPU_NUM=8
+# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
+IP_LIST=${IP_LIST:-"127.0.0.1"}
+
+IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
+export NODES=${#IP_ARRAY[@]}
+export LOCAL_IP=${IP_ARRAY[0]}
+NODE_IP_LIST=""
+for ip in "${IP_ARRAY[@]}"; do
+  if [ -n "$NODE_IP_LIST" ]; then
+    NODE_IP_LIST="${NODE_IP_LIST},"
+  fi
+  NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
+done
+export NODE_IP_LIST
+export NODE_NUM=$((${NODES} * ${HOST_GPU_NUM}))
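+# e.g. IP_LIST="192.168.1.1,192.168.1.2" yields
+# NODE_IP_LIST="192.168.1.1:8,192.168.1.2:8" and NODE_NUM=16 (total GPUs).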
+
+model_path=path_to_model_weight
+tokenizer_path=../models
+train_data_file=example_data.jsonl
+
+# ds_config_file=ds_zero2_no_offload.json
+# ds_config_file=ds_zero3_no_offload.json
+ds_config_file=ds_zero3_offload_no_auto.json
+
+output_path=/root/hf_train_output_full
+
+mkdir -p ${output_path}
+
+current_time=$(date "+%Y.%m.%d-%H.%M.%S")
+log_file=${output_path}/"log_${current_time}.txt"
+
+# Build the DeepSpeed hostfile (one "<ip> slots=<gpus>" line per node) and a
+# bare-IP list for pssh from NODE_IP_LIST.
+echo $NODE_IP_LIST > env.txt
+sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
+sed "s/:.//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
+export CHIEF_IP=$LOCAL_IP
+
+if [ ${NODES} -gt 1 ]; then
+  HOST_PATH=hostfile
+  DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
+else
+  DS_ARGS=""
+fi
+
+echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
+
+deepspeed ${DS_ARGS} \
+  train.py \
+  --do_train \
+  --model_name_or_path ${model_path} \
+  --tokenizer_name_or_path ${tokenizer_path} \
+  --train_data_file ${train_data_file} \
+  --deepspeed ${ds_config_file} \
+  --output_dir ${output_path} \
+  --per_device_train_batch_size 1 \
+  --gradient_accumulation_steps 1 \
+  --gradient_checkpointing \
+  --lr_scheduler_type cosine_with_min_lr \
+  --logging_steps 1 \
+  --max_steps 200 \
+  --save_steps 100 \
+  --learning_rate 1e-5 \
+  --min_lr 1e-6 \
+  --warmup_ratio 0.01 \
+  --save_strategy steps \
+  --bf16 \
+  --hidden_size 4096 \
+  --intermediate_size 13312 \
+  --model_max_length 4096 \
+  --max_seq_length 4096 \
+  --moe_topk 8 \
+  --num_experts 192 \
+  --moe_intermediate_size 1536 \
+  --moe_layer_num_skipped 1 \
+  --num_attention_heads 64 \
+  --num_key_value_heads 8 \
+  --num_layers 80 \
+  --use_mixed_mlp_moe \
+  --num_shared_expert 1 \
+  --use_qk_norm | tee ${log_file}
train/train_lora.sh ADDED
@@ -0,0 +1,123 @@
+#!/bin/bash
+
+NET_TYPE="high"
+export NCCL_DEBUG=WARN
+export NCCL_P2P_LEVEL=NVL
+export NCCL_IB_TIMEOUT=24
+export NCCL_NVLS_ENABLE=0
+export NCCL_MPI_PROFILE_PRIMS_ENABLE=0
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=3600
+if [[ "${NET_TYPE}" = "low" ]]; then
+  export NCCL_SOCKET_IFNAME=eth1
+  export NCCL_IB_GID_INDEX=3
+  export NCCL_IB_HCA=mlx5_2:1
+  export NCCL_IB_SL=3
+  export NCCL_CHECK_DISABLE=1
+  export NCCL_P2P_DISABLE=0
+  export NCCL_LL_THRESHOLD=16384
+  export NCCL_IB_CUDA_SUPPORT=1
+else
+  export NCCL_IB_GID_INDEX=3
+  export NCCL_IB_SL=3
+  export NCCL_CHECK_DISABLE=1
+  export NCCL_P2P_DISABLE=0
+  export NCCL_IB_DISABLE=0
+  export NCCL_LL_THRESHOLD=16384
+  export NCCL_IB_CUDA_SUPPORT=1
+  export NCCL_SOCKET_IFNAME=bond1
+  export UCX_NET_DEVICES=bond1
+  export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
+  export NCCL_COLLNET_ENABLE=0
+  export SHARP_COLL_ENABLE_SAT=0
+  export NCCL_NET_GDR_LEVEL=2
+  export NCCL_IB_QPS_PER_CONNECTION=4
+  export NCCL_IB_TC=160
+  export NCCL_PXN_DISABLE=1
+fi
+
+export HOST_GPU_NUM=8
+# IP list, comma separated. e.g. "192.168.1.1,192.168.1.2" or single node "192.168.1.1"
+IP_LIST=${IP_LIST:-"127.0.0.1"}
+
+IFS=',' read -ra IP_ARRAY <<< "$IP_LIST"
+export NODES=${#IP_ARRAY[@]}
+export LOCAL_IP=${IP_ARRAY[0]}
+NODE_IP_LIST=""
+for ip in "${IP_ARRAY[@]}"; do
+  if [ -n "$NODE_IP_LIST" ]; then
+    NODE_IP_LIST="${NODE_IP_LIST},"
+  fi
+  NODE_IP_LIST="${NODE_IP_LIST}${ip}:${HOST_GPU_NUM}"
+done
+export NODE_IP_LIST
+export NODE_NUM=$((${NODES} * ${HOST_GPU_NUM}))
+
+
+model_path=path_to_model_weight
+tokenizer_path=../models
+train_data_file=example_data.jsonl
+
+# ds_config_file=ds_zero2_no_offload.json
+# ds_config_file=ds_zero3_no_offload.json
+ds_config_file=ds_zero3_offload.json
+
+output_path=/root/hf_train_output
+
+mkdir -p ${output_path}
+
+current_time=$(date "+%Y.%m.%d-%H.%M.%S")
+log_file=${output_path}/"log_${current_time}.txt"
+
+# Build the DeepSpeed hostfile (one "<ip> slots=<gpus>" line per node) and a
+# bare-IP list for pssh from NODE_IP_LIST.
+echo $NODE_IP_LIST > env.txt
+sed "s/:/ slots=/g" env.txt | sed "s/,/\n/g" > "hostfile"
+sed "s/:.//g" env.txt | sed "s/,/\n/g" > "pssh.hosts"
+export CHIEF_IP=$LOCAL_IP
+
+if [ ${NODES} -gt 1 ]; then
+  HOST_PATH=hostfile
+  DS_ARGS="--hostfile=${HOST_PATH} --master_addr ${CHIEF_IP}"
+else
+  DS_ARGS=""
+fi
+
+echo "NODES: ${NODES}, LOCAL_IP: ${LOCAL_IP}, NODE_IP_LIST: ${NODE_IP_LIST}"
+
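+# Same launch as train/train.sh plus the LoRA flags; rank 64 with alpha 128
+# gives an adapter scaling factor of lora_alpha / lora_rank = 2.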
+deepspeed ${DS_ARGS} \
+  train.py \
+  --do_train \
+  --model_name_or_path ${model_path} \
+  --tokenizer_name_or_path ${tokenizer_path} \
+  --train_data_file ${train_data_file} \
+  --deepspeed ${ds_config_file} \
+  --output_dir ${output_path} \
+  --per_device_train_batch_size 1 \
+  --gradient_accumulation_steps 1 \
+  --gradient_checkpointing \
+  --lr_scheduler_type cosine_with_min_lr \
+  --logging_steps 1 \
+  --max_steps 200 \
+  --save_steps 100 \
+  --learning_rate 1e-5 \
+  --min_lr 1e-6 \
+  --warmup_ratio 0.01 \
+  --save_strategy steps \
+  --bf16 \
+  --use_lora \
+  --lora_rank 64 \
+  --lora_alpha 128 \
+  --lora_dropout 0.1 \
+  --hidden_size 4096 \
+  --intermediate_size 13312 \
+  --model_max_length 4096 \
+  --max_seq_length 4096 \
+  --moe_topk 8 \
+  --num_experts 192 \
+  --moe_intermediate_size 1536 \
+  --moe_layer_num_skipped 1 \
+  --num_attention_heads 64 \
+  --num_key_value_heads 8 \
+  --num_layers 80 \
+  --use_mixed_mlp_moe \
+  --num_shared_expert 1 \
+  --use_qk_norm | tee ${log_file}