Colin Zeng commited on
Commit
bd26087
·
1 Parent(s): 86098d2

Model Upload

Browse files
README.md ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ base_model:
4
+ - stepfun-ai/step-3.5-flash
5
+ library_name: transformers
6
+ ---
7
+
8
+ # Model Overview
9
+
10
+ - **Model Architecture:** Step3p5ForCausalLM
11
+ - **Input:** Text
12
+ - **Output:** Text
13
+ - **Supported Hardware Microarchitecture:** AMD MI350/MI355
14
+ - **ROCm**: 7.1.0
15
+ - **PyTorch**: 2.10.0
16
+ - **Transformers**: 4.57.6
17
+ - **Operating System(s):** Linux
18
+ - **Inference Engine:** [vLLM](https://docs.vllm.ai/en/latest/)
19
+ - **Model Optimizer:** [AMD-Quark](https://quark.docs.amd.com/latest/index.html)
20
+ - **Weight quantization:** MoE-only, OCP MXFP4, Static
21
+ - **Activation quantization:** MoE-only, OCP MXFP4, Dynamic
22
+ - **Docker Image:** rocm/vllm-dev@sha256:63f1fe04d87376bb173a1e837fba8610ab2dd77039fe7c9b97195f2a89d4d463
23
+
24
+
25
+ # Model Quantization
26
+
27
+ The model was quantized from [stepfun-ai/Step-3.5-Flash](https://huggingface.co/stepfun-ai/Step-3.5-Flash) using [AMD-Quark](https://quark.docs.amd.com/latest/index.html). The weights and activations are both quantized to MXFP4.
28
+
29
+
30
+ **Quantization scripts:**
31
+ ```
32
+ cd Quark/examples/torch/language_modeling/llm_ptq/
33
+ python3 step3p5_quantize_quark.py --model_dir $MODEL_DIR \
34
+ --num_calib_data 128 \
35
+ --multi_gpu \
36
+ --trust_remote_code \
37
+ --preset mxfp4_moe_only_no_kvcache
38
+ --output_dir $output_dir
39
+ ```
40
+ For further details or issues, please refer to the AMD-Quark documentation or contact the respective developers.
41
+
42
+ # Deployment
43
+ ### Use with vLLM
44
+
45
+ This model can be deployed efficiently using the [vLLM](https://docs.vllm.ai/en/latest/) backend.
46
+
47
+ ## Evaluation
48
+ The model was evaluated on gsm8k benchmarks using the [vLLM](https://docs.vllm.ai/en/latest/) framework.
49
+
50
+ ### Accuracy
51
+
52
+ <table>
53
+ <tr>
54
+ <td><strong>Benchmark</strong>
55
+ </td>
56
+ <td><strong>stepfun-ai/Step-3.5-Flash (bf16)</strong>
57
+ </td>
58
+ <td><strong>amd/Step-3.5-Flash-MXFP4 (this model)</strong>
59
+ </td>
60
+ <td><strong>Recovery</strong>
61
+ </td>
62
+ </tr>
63
+ <tr>
64
+ <td>gsm8k (flexible-extract)
65
+ </td>
66
+ <td>0.8939
67
+ </td>
68
+ <td>0.8726
69
+ </td>
70
+ <td>97.6%
71
+ </td>
72
+ </tr>
73
+ </table>
74
+
75
+
76
+ ### Reproduction
77
+
78
+ The GSM8K results were obtained using the vLLM framework, based on the Docker image `rocm/vllm:nightly`.
79
+
80
+ #### Note: Due to model support issues in vLLM for Step-3.5-Flash, a few patches need to be applied (specified below) in order to run inference and evaluation using vLLM.
81
+
82
+ #### Preparation in container
83
+ ```
84
+ # Reinstall vLLM
85
+ pip uninstall vllm -y
86
+ git clone https://github.com/vllm-project/vllm.git
87
+ cd vllm
88
+ git checkout de7dd634b969adc6e5f50cff0cc09c1be1711d01
89
+ pip install -r requirements/rocm.txt
90
+ python setup.py develop
91
+ cd ..
92
+ export QUARK_MXFP4_IMPL="triton"
93
+ ```
94
+ Modify `vllm/model_executor/models/step3p5.py` by adding the below packed_modules_mapping update in the model's `__init__` function:
95
+ ```
96
+ ...
97
+ @support_torch_compile
98
+ class Step3p5Model(nn.Module):
99
+ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
100
+ super().__init__()
101
+
102
+ self.vllm_config = vllm_config
103
+ config = vllm_config.model_config.hf_config
104
+ + # Update packed_modules_mapping for quantization
105
+ + if hasattr(vllm_config, "quant_config") and vllm_config.quant_config:
106
+ + vllm_config.quant_config.packed_modules_mapping.update({
107
+ + "qkv_proj": ["q_proj", "k_proj", "v_proj"],
108
+ + "gate_up_proj": ["gate_proj", "up_proj"],
109
+ + })
110
+ self.vocab_size = config.vocab_size
111
+ self.config = config
112
+ ...
113
+ ```
114
+ Additionally, modify the same file (`step3p5.py`) by adding the below MoE expert name mapping to the model's `load_weights` function:
115
+ ```
116
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
117
+ config = self.config
118
+ assert config.num_attention_groups > 1, "Only support GQA"
119
+
120
+ ...
121
+
122
+ for name, loaded_weight in weights:
123
+ if name.startswith("model."):
124
+ local_name = name[len("model.") :]
125
+ full_name = name
126
+ else:
127
+ local_name = name
128
+ full_name = f"model.{name}" if name else "model"
129
+
130
+ + # Normalize legacy MoE expert naming like ".moe.<E>.gate_proj" to
131
+ + # the ".moe.experts.<E>.gate_proj" format
132
+ + if ".moe.experts." not in local_name and ".moe." in local_name:
133
+ + parts = local_name.split(".moe.", 1)
134
+ + if len(parts) == 2 and "." in parts[1]:
135
+ + expert_and_rest = parts[1]
136
+ + expert_id, remainder = expert_and_rest.split(".", 1)
137
+ + if expert_id.isdigit():
138
+ + local_name = f"{parts[0]}.moe.experts.{expert_id}.{remainder}"
139
+
140
+ spec_layer = get_spec_layer_idx_from_weight_name(config, full_name)
141
+ if spec_layer is not None:
142
+ continue # skip spec decode layers for main model
143
+ ...
144
+ ```
145
+ Finally, modify `vllm/model_executor/layers/quantization/quark/quark_moe.py` by forcing `self.emulate` to "True":
146
+ ```
147
+ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
148
+ def __init__(...):
149
+ super().__init__(moe)
150
+ ...
151
+
152
+ self.model_type = getattr(
153
+ get_current_vllm_config().model_config.hf_config, "model_type", None
154
+ )
155
+
156
+ - self.emulate = (
157
+ - not current_platform.supports_mx()
158
+ - or not self.ocp_mx_scheme.startswith("w_mxfp4")
159
+ - ) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
160
+ + self.emulate = True
161
+
162
+ logger.warning_once(
163
+ ...
164
+ ```
165
+
166
+
167
+ #### Evaluating model using lm_eval
168
+ ```
169
+ lm_eval --model vllm --model_args 'pretrained=$MODEL_DIR,attention_backend=ROCM_AITER_UNIFIED_ATTN,quantization='quark',trust_remote_code=True' --tasks gsm8k --batch_size auto
170
+ ```
171
+
172
+
173
+ # License
174
+ Modifications Copyright(c) 2026 Advanced Micro Devices, Inc. All rights reserved.
chat_template.jinja ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% macro render_content(content) %}{% if content is none %}{{- '' }}{% elif content is string %}{{- content }}{% elif content is mapping %}{{- content['value'] if 'value' in content else content['text'] }}{% elif content is iterable %}{% for item in content %}{% if item.type == 'text' %}{{- item['value'] if 'value' in item else item['text'] }}{% elif item.type == 'image' %}<im_patch>{% endif %}{% endfor %}{% endif %}{% endmacro %}
2
+ {{bos_token}}{%- if tools %}
3
+ {{- '<|im_start|>system\n' }}
4
+ {%- if messages[0].role == 'system' %}
5
+ {{- render_content(messages[0].content) + '\n\n' }}
6
+ {%- endif %}
7
+ {{- "# Tools\n\nYou have access to the following functions in JSONSchema format:\n\n<tools>" }}
8
+ {%- for tool in tools %}
9
+ {{- "\n" }}
10
+ {{- tool | tojson(ensure_ascii=False) }}
11
+ {%- endfor %}
12
+ {{- "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...>\n...\n</function> block must be nested within <tool_call>\n...\n</tool_call> XML tags\n- Required parameters MUST be specified\n</IMPORTANT><|im_end|>\n" }}
13
+ {%- else %}
14
+ {%- if messages[0].role == 'system' %}
15
+ {{- '<|im_start|>system\n' + render_content(messages[0].content) + '<|im_end|>\n' }}
16
+ {%- endif %}
17
+ {%- endif %}
18
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
19
+ {%- for message in messages[::-1] %}
20
+ {%- set index = (messages|length - 1) - loop.index0 %}
21
+ {%- if ns.multi_step_tool and message.role == "user" and render_content(message.content) is string and not(render_content(message.content).startswith('<tool_response>') and render_content(message.content).endswith('</tool_response>')) %}
22
+ {%- set ns.multi_step_tool = false %}
23
+ {%- set ns.last_query_index = index %}
24
+ {%- endif %}
25
+ {%- endfor %}
26
+ {%- for message in messages %}
27
+ {%- set content = render_content(message.content) %}
28
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
29
+ {%- set role_name = 'observation' if (message.role == "system" and not loop.first and message.name == 'observation') else message.role %}
30
+ {{- '<|im_start|>' + role_name + '\n' + content + '<|im_end|>' + '\n' }}
31
+ {%- elif message.role == "assistant" %}
32
+ {%- if message.reasoning_content is string %}
33
+ {%- set reasoning_content = render_content(message.reasoning_content) %}
34
+ {%- else %}
35
+ {%- if '</think>' in content %}
36
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
37
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
38
+ {%- else %}
39
+ {%- set reasoning_content = '' %}
40
+ {%- endif %}
41
+ {%- endif %}
42
+ {%- if loop.index0 > ns.last_query_index %}
43
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n' + content }}
44
+ {%- else %}
45
+ {{- '<|im_start|>' + message.role + '\n' + content }}
46
+ {%- endif %}
47
+ {%- if message.tool_calls %}
48
+ {%- for tool_call in message.tool_calls %}
49
+ {%- if tool_call.function is defined %}
50
+ {%- set tool_call = tool_call.function %}
51
+ {%- endif %}
52
+ {{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
53
+ {%- if tool_call.arguments is defined %}
54
+ {%- set arguments = tool_call.arguments %}
55
+ {%- for args_name, args_value in arguments|items %}
56
+ {{- '<parameter=' + args_name + '>\n' }}
57
+ {%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
58
+ {{- args_value }}
59
+ {{- '\n</parameter>\n' }}
60
+ {%- endfor %}
61
+ {%- endif %}
62
+ {{- '</function>\n</tool_call>' }}
63
+ {%- endfor %}
64
+ {%- endif %}
65
+ {{- '<|im_end|>\n' }}
66
+ {%- elif message.role == "tool" %}
67
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
68
+ {{- '<|im_start|>tool_response\n' }}
69
+ {%- endif %}
70
+ {{- '<tool_response>' }}
71
+ {{- content }}
72
+ {{- '</tool_response>' }}
73
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
74
+ {{- '<|im_end|>\n' }}
75
+ {%- endif %}
76
+ {%- endif %}
77
+ {%- endfor %}
78
+ {%- if add_generation_prompt %}
79
+ {{- '<|im_start|>assistant\n<think>\n' }}
80
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,784 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Step3p5ForCausalLM"
4
+ ],
5
+ "att_impl_type": "GQA",
6
+ "attention_other_setting": {
7
+ "attention_type": "sliding_attention",
8
+ "head_dim": 128,
9
+ "num_attention_groups": 8,
10
+ "num_attention_heads": 96,
11
+ "true_head_dim": 128
12
+ },
13
+ "auto_map": {
14
+ "AutoConfig": "configuration_step3p5.Step3p5Config",
15
+ "AutoModelForCausalLM": "modeling_step3p5.Step3p5ForCausalLM"
16
+ },
17
+ "bos_token_id": 0,
18
+ "dtype": "bfloat16",
19
+ "eos_token_id": [
20
+ 1,
21
+ 2,
22
+ 128007
23
+ ],
24
+ "head_dim": 128,
25
+ "hidden_size": 4096,
26
+ "intermediate_size": 11264,
27
+ "layer_types": [
28
+ "full_attention",
29
+ "sliding_attention",
30
+ "sliding_attention",
31
+ "sliding_attention",
32
+ "full_attention",
33
+ "sliding_attention",
34
+ "sliding_attention",
35
+ "sliding_attention",
36
+ "full_attention",
37
+ "sliding_attention",
38
+ "sliding_attention",
39
+ "sliding_attention",
40
+ "full_attention",
41
+ "sliding_attention",
42
+ "sliding_attention",
43
+ "sliding_attention",
44
+ "full_attention",
45
+ "sliding_attention",
46
+ "sliding_attention",
47
+ "sliding_attention",
48
+ "full_attention",
49
+ "sliding_attention",
50
+ "sliding_attention",
51
+ "sliding_attention",
52
+ "full_attention",
53
+ "sliding_attention",
54
+ "sliding_attention",
55
+ "sliding_attention",
56
+ "full_attention",
57
+ "sliding_attention",
58
+ "sliding_attention",
59
+ "sliding_attention",
60
+ "full_attention",
61
+ "sliding_attention",
62
+ "sliding_attention",
63
+ "sliding_attention",
64
+ "full_attention",
65
+ "sliding_attention",
66
+ "sliding_attention",
67
+ "sliding_attention",
68
+ "full_attention",
69
+ "sliding_attention",
70
+ "sliding_attention",
71
+ "sliding_attention",
72
+ "full_attention",
73
+ "sliding_attention",
74
+ "sliding_attention",
75
+ "sliding_attention"
76
+ ],
77
+ "max_position_embeddings": 262144,
78
+ "max_seq_len": 262144,
79
+ "model_type": "step3p5",
80
+ "moe_every_n_layer": 1,
81
+ "moe_intermediate_size": 1280,
82
+ "moe_layer_offset": 0,
83
+ "moe_layers_enum": "3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44",
84
+ "moe_num_experts": 288,
85
+ "moe_router_activation": "sigmoid",
86
+ "moe_router_scaling_factor": 3.0,
87
+ "moe_top_k": 8,
88
+ "need_fp32_gate": true,
89
+ "norm_expert_weight": true,
90
+ "num_attention_groups": 8,
91
+ "num_attention_heads": 64,
92
+ "num_hidden_layers": 45,
93
+ "num_nextn_predict_layers": 3,
94
+ "partial_rotary_factor": 0.5,
95
+ "partial_rotary_factors": [
96
+ 0.5,
97
+ 1.0,
98
+ 1.0,
99
+ 1.0,
100
+ 0.5,
101
+ 1.0,
102
+ 1.0,
103
+ 1.0,
104
+ 0.5,
105
+ 1.0,
106
+ 1.0,
107
+ 1.0,
108
+ 0.5,
109
+ 1.0,
110
+ 1.0,
111
+ 1.0,
112
+ 0.5,
113
+ 1.0,
114
+ 1.0,
115
+ 1.0,
116
+ 0.5,
117
+ 1.0,
118
+ 1.0,
119
+ 1.0,
120
+ 0.5,
121
+ 1.0,
122
+ 1.0,
123
+ 1.0,
124
+ 0.5,
125
+ 1.0,
126
+ 1.0,
127
+ 1.0,
128
+ 0.5,
129
+ 1.0,
130
+ 1.0,
131
+ 1.0,
132
+ 0.5,
133
+ 1.0,
134
+ 1.0,
135
+ 1.0,
136
+ 0.5,
137
+ 1.0,
138
+ 1.0,
139
+ 1.0,
140
+ 0.5,
141
+ 1.0,
142
+ 1.0,
143
+ 1.0
144
+ ],
145
+ "quantization_config": {
146
+ "algo_config": null,
147
+ "exclude": [
148
+ "model.embed_tokens",
149
+ "model.layers.0.self_attn.q_proj",
150
+ "model.layers.0.self_attn.k_proj",
151
+ "model.layers.0.self_attn.v_proj",
152
+ "model.layers.0.self_attn.o_proj",
153
+ "model.layers.0.self_attn.g_proj",
154
+ "model.layers.0.mlp.gate_proj",
155
+ "model.layers.0.mlp.up_proj",
156
+ "model.layers.0.mlp.down_proj",
157
+ "model.layers.1.self_attn.q_proj",
158
+ "model.layers.1.self_attn.k_proj",
159
+ "model.layers.1.self_attn.v_proj",
160
+ "model.layers.1.self_attn.o_proj",
161
+ "model.layers.1.self_attn.g_proj",
162
+ "model.layers.1.mlp.gate_proj",
163
+ "model.layers.1.mlp.up_proj",
164
+ "model.layers.1.mlp.down_proj",
165
+ "model.layers.2.self_attn.q_proj",
166
+ "model.layers.2.self_attn.k_proj",
167
+ "model.layers.2.self_attn.v_proj",
168
+ "model.layers.2.self_attn.o_proj",
169
+ "model.layers.2.self_attn.g_proj",
170
+ "model.layers.2.mlp.gate_proj",
171
+ "model.layers.2.mlp.up_proj",
172
+ "model.layers.2.mlp.down_proj",
173
+ "model.layers.3.self_attn.q_proj",
174
+ "model.layers.3.self_attn.k_proj",
175
+ "model.layers.3.self_attn.v_proj",
176
+ "model.layers.3.self_attn.o_proj",
177
+ "model.layers.3.self_attn.g_proj",
178
+ "model.layers.3.moe.gate",
179
+ "model.layers.3.share_expert.gate_proj",
180
+ "model.layers.3.share_expert.up_proj",
181
+ "model.layers.3.share_expert.down_proj",
182
+ "model.layers.4.self_attn.q_proj",
183
+ "model.layers.4.self_attn.k_proj",
184
+ "model.layers.4.self_attn.v_proj",
185
+ "model.layers.4.self_attn.o_proj",
186
+ "model.layers.4.self_attn.g_proj",
187
+ "model.layers.4.moe.gate",
188
+ "model.layers.4.share_expert.gate_proj",
189
+ "model.layers.4.share_expert.up_proj",
190
+ "model.layers.4.share_expert.down_proj",
191
+ "model.layers.5.self_attn.q_proj",
192
+ "model.layers.5.self_attn.k_proj",
193
+ "model.layers.5.self_attn.v_proj",
194
+ "model.layers.5.self_attn.o_proj",
195
+ "model.layers.5.self_attn.g_proj",
196
+ "model.layers.5.moe.gate",
197
+ "model.layers.5.share_expert.gate_proj",
198
+ "model.layers.5.share_expert.up_proj",
199
+ "model.layers.5.share_expert.down_proj",
200
+ "model.layers.6.self_attn.q_proj",
201
+ "model.layers.6.self_attn.k_proj",
202
+ "model.layers.6.self_attn.v_proj",
203
+ "model.layers.6.self_attn.o_proj",
204
+ "model.layers.6.self_attn.g_proj",
205
+ "model.layers.6.moe.gate",
206
+ "model.layers.6.share_expert.gate_proj",
207
+ "model.layers.6.share_expert.up_proj",
208
+ "model.layers.6.share_expert.down_proj",
209
+ "model.layers.7.self_attn.q_proj",
210
+ "model.layers.7.self_attn.k_proj",
211
+ "model.layers.7.self_attn.v_proj",
212
+ "model.layers.7.self_attn.o_proj",
213
+ "model.layers.7.self_attn.g_proj",
214
+ "model.layers.7.moe.gate",
215
+ "model.layers.7.share_expert.gate_proj",
216
+ "model.layers.7.share_expert.up_proj",
217
+ "model.layers.7.share_expert.down_proj",
218
+ "model.layers.8.self_attn.q_proj",
219
+ "model.layers.8.self_attn.k_proj",
220
+ "model.layers.8.self_attn.v_proj",
221
+ "model.layers.8.self_attn.o_proj",
222
+ "model.layers.8.self_attn.g_proj",
223
+ "model.layers.8.moe.gate",
224
+ "model.layers.8.share_expert.gate_proj",
225
+ "model.layers.8.share_expert.up_proj",
226
+ "model.layers.8.share_expert.down_proj",
227
+ "model.layers.9.self_attn.q_proj",
228
+ "model.layers.9.self_attn.k_proj",
229
+ "model.layers.9.self_attn.v_proj",
230
+ "model.layers.9.self_attn.o_proj",
231
+ "model.layers.9.self_attn.g_proj",
232
+ "model.layers.9.moe.gate",
233
+ "model.layers.9.share_expert.gate_proj",
234
+ "model.layers.9.share_expert.up_proj",
235
+ "model.layers.9.share_expert.down_proj",
236
+ "model.layers.10.self_attn.q_proj",
237
+ "model.layers.10.self_attn.k_proj",
238
+ "model.layers.10.self_attn.v_proj",
239
+ "model.layers.10.self_attn.o_proj",
240
+ "model.layers.10.self_attn.g_proj",
241
+ "model.layers.10.moe.gate",
242
+ "model.layers.10.share_expert.gate_proj",
243
+ "model.layers.10.share_expert.up_proj",
244
+ "model.layers.10.share_expert.down_proj",
245
+ "model.layers.11.self_attn.q_proj",
246
+ "model.layers.11.self_attn.k_proj",
247
+ "model.layers.11.self_attn.v_proj",
248
+ "model.layers.11.self_attn.o_proj",
249
+ "model.layers.11.self_attn.g_proj",
250
+ "model.layers.11.moe.gate",
251
+ "model.layers.11.share_expert.gate_proj",
252
+ "model.layers.11.share_expert.up_proj",
253
+ "model.layers.11.share_expert.down_proj",
254
+ "model.layers.12.self_attn.q_proj",
255
+ "model.layers.12.self_attn.k_proj",
256
+ "model.layers.12.self_attn.v_proj",
257
+ "model.layers.12.self_attn.o_proj",
258
+ "model.layers.12.self_attn.g_proj",
259
+ "model.layers.12.moe.gate",
260
+ "model.layers.12.share_expert.gate_proj",
261
+ "model.layers.12.share_expert.up_proj",
262
+ "model.layers.12.share_expert.down_proj",
263
+ "model.layers.13.self_attn.q_proj",
264
+ "model.layers.13.self_attn.k_proj",
265
+ "model.layers.13.self_attn.v_proj",
266
+ "model.layers.13.self_attn.o_proj",
267
+ "model.layers.13.self_attn.g_proj",
268
+ "model.layers.13.moe.gate",
269
+ "model.layers.13.share_expert.gate_proj",
270
+ "model.layers.13.share_expert.up_proj",
271
+ "model.layers.13.share_expert.down_proj",
272
+ "model.layers.14.self_attn.q_proj",
273
+ "model.layers.14.self_attn.k_proj",
274
+ "model.layers.14.self_attn.v_proj",
275
+ "model.layers.14.self_attn.o_proj",
276
+ "model.layers.14.self_attn.g_proj",
277
+ "model.layers.14.moe.gate",
278
+ "model.layers.14.share_expert.gate_proj",
279
+ "model.layers.14.share_expert.up_proj",
280
+ "model.layers.14.share_expert.down_proj",
281
+ "model.layers.15.self_attn.q_proj",
282
+ "model.layers.15.self_attn.k_proj",
283
+ "model.layers.15.self_attn.v_proj",
284
+ "model.layers.15.self_attn.o_proj",
285
+ "model.layers.15.self_attn.g_proj",
286
+ "model.layers.15.moe.gate",
287
+ "model.layers.15.share_expert.gate_proj",
288
+ "model.layers.15.share_expert.up_proj",
289
+ "model.layers.15.share_expert.down_proj",
290
+ "model.layers.16.self_attn.q_proj",
291
+ "model.layers.16.self_attn.k_proj",
292
+ "model.layers.16.self_attn.v_proj",
293
+ "model.layers.16.self_attn.o_proj",
294
+ "model.layers.16.self_attn.g_proj",
295
+ "model.layers.16.moe.gate",
296
+ "model.layers.16.share_expert.gate_proj",
297
+ "model.layers.16.share_expert.up_proj",
298
+ "model.layers.16.share_expert.down_proj",
299
+ "model.layers.17.self_attn.q_proj",
300
+ "model.layers.17.self_attn.k_proj",
301
+ "model.layers.17.self_attn.v_proj",
302
+ "model.layers.17.self_attn.o_proj",
303
+ "model.layers.17.self_attn.g_proj",
304
+ "model.layers.17.moe.gate",
305
+ "model.layers.17.share_expert.gate_proj",
306
+ "model.layers.17.share_expert.up_proj",
307
+ "model.layers.17.share_expert.down_proj",
308
+ "model.layers.18.self_attn.q_proj",
309
+ "model.layers.18.self_attn.k_proj",
310
+ "model.layers.18.self_attn.v_proj",
311
+ "model.layers.18.self_attn.o_proj",
312
+ "model.layers.18.self_attn.g_proj",
313
+ "model.layers.18.moe.gate",
314
+ "model.layers.18.share_expert.gate_proj",
315
+ "model.layers.18.share_expert.up_proj",
316
+ "model.layers.18.share_expert.down_proj",
317
+ "model.layers.19.self_attn.q_proj",
318
+ "model.layers.19.self_attn.k_proj",
319
+ "model.layers.19.self_attn.v_proj",
320
+ "model.layers.19.self_attn.o_proj",
321
+ "model.layers.19.self_attn.g_proj",
322
+ "model.layers.19.moe.gate",
323
+ "model.layers.19.share_expert.gate_proj",
324
+ "model.layers.19.share_expert.up_proj",
325
+ "model.layers.19.share_expert.down_proj",
326
+ "model.layers.20.self_attn.q_proj",
327
+ "model.layers.20.self_attn.k_proj",
328
+ "model.layers.20.self_attn.v_proj",
329
+ "model.layers.20.self_attn.o_proj",
330
+ "model.layers.20.self_attn.g_proj",
331
+ "model.layers.20.moe.gate",
332
+ "model.layers.20.share_expert.gate_proj",
333
+ "model.layers.20.share_expert.up_proj",
334
+ "model.layers.20.share_expert.down_proj",
335
+ "model.layers.21.self_attn.q_proj",
336
+ "model.layers.21.self_attn.k_proj",
337
+ "model.layers.21.self_attn.v_proj",
338
+ "model.layers.21.self_attn.o_proj",
339
+ "model.layers.21.self_attn.g_proj",
340
+ "model.layers.21.moe.gate",
341
+ "model.layers.21.share_expert.gate_proj",
342
+ "model.layers.21.share_expert.up_proj",
343
+ "model.layers.21.share_expert.down_proj",
344
+ "model.layers.22.self_attn.q_proj",
345
+ "model.layers.22.self_attn.k_proj",
346
+ "model.layers.22.self_attn.v_proj",
347
+ "model.layers.22.self_attn.o_proj",
348
+ "model.layers.22.self_attn.g_proj",
349
+ "model.layers.22.moe.gate",
350
+ "model.layers.22.share_expert.gate_proj",
351
+ "model.layers.22.share_expert.up_proj",
352
+ "model.layers.22.share_expert.down_proj",
353
+ "model.layers.23.self_attn.q_proj",
354
+ "model.layers.23.self_attn.k_proj",
355
+ "model.layers.23.self_attn.v_proj",
356
+ "model.layers.23.self_attn.o_proj",
357
+ "model.layers.23.self_attn.g_proj",
358
+ "model.layers.23.moe.gate",
359
+ "model.layers.23.share_expert.gate_proj",
360
+ "model.layers.23.share_expert.up_proj",
361
+ "model.layers.23.share_expert.down_proj",
362
+ "model.layers.24.self_attn.q_proj",
363
+ "model.layers.24.self_attn.k_proj",
364
+ "model.layers.24.self_attn.v_proj",
365
+ "model.layers.24.self_attn.o_proj",
366
+ "model.layers.24.self_attn.g_proj",
367
+ "model.layers.24.moe.gate",
368
+ "model.layers.24.share_expert.gate_proj",
369
+ "model.layers.24.share_expert.up_proj",
370
+ "model.layers.24.share_expert.down_proj",
371
+ "model.layers.25.self_attn.q_proj",
372
+ "model.layers.25.self_attn.k_proj",
373
+ "model.layers.25.self_attn.v_proj",
374
+ "model.layers.25.self_attn.o_proj",
375
+ "model.layers.25.self_attn.g_proj",
376
+ "model.layers.25.moe.gate",
377
+ "model.layers.25.share_expert.gate_proj",
378
+ "model.layers.25.share_expert.up_proj",
379
+ "model.layers.25.share_expert.down_proj",
380
+ "model.layers.26.self_attn.q_proj",
381
+ "model.layers.26.self_attn.k_proj",
382
+ "model.layers.26.self_attn.v_proj",
383
+ "model.layers.26.self_attn.o_proj",
384
+ "model.layers.26.self_attn.g_proj",
385
+ "model.layers.26.moe.gate",
386
+ "model.layers.26.share_expert.gate_proj",
387
+ "model.layers.26.share_expert.up_proj",
388
+ "model.layers.26.share_expert.down_proj",
389
+ "model.layers.27.self_attn.q_proj",
390
+ "model.layers.27.self_attn.k_proj",
391
+ "model.layers.27.self_attn.v_proj",
392
+ "model.layers.27.self_attn.o_proj",
393
+ "model.layers.27.self_attn.g_proj",
394
+ "model.layers.27.moe.gate",
395
+ "model.layers.27.share_expert.gate_proj",
396
+ "model.layers.27.share_expert.up_proj",
397
+ "model.layers.27.share_expert.down_proj",
398
+ "model.layers.28.self_attn.q_proj",
399
+ "model.layers.28.self_attn.k_proj",
400
+ "model.layers.28.self_attn.v_proj",
401
+ "model.layers.28.self_attn.o_proj",
402
+ "model.layers.28.self_attn.g_proj",
403
+ "model.layers.28.moe.gate",
404
+ "model.layers.28.share_expert.gate_proj",
405
+ "model.layers.28.share_expert.up_proj",
406
+ "model.layers.28.share_expert.down_proj",
407
+ "model.layers.29.self_attn.q_proj",
408
+ "model.layers.29.self_attn.k_proj",
409
+ "model.layers.29.self_attn.v_proj",
410
+ "model.layers.29.self_attn.o_proj",
411
+ "model.layers.29.self_attn.g_proj",
412
+ "model.layers.29.moe.gate",
413
+ "model.layers.29.share_expert.gate_proj",
414
+ "model.layers.29.share_expert.up_proj",
415
+ "model.layers.29.share_expert.down_proj",
416
+ "model.layers.30.self_attn.q_proj",
417
+ "model.layers.30.self_attn.k_proj",
418
+ "model.layers.30.self_attn.v_proj",
419
+ "model.layers.30.self_attn.o_proj",
420
+ "model.layers.30.self_attn.g_proj",
421
+ "model.layers.30.moe.gate",
422
+ "model.layers.30.share_expert.gate_proj",
423
+ "model.layers.30.share_expert.up_proj",
424
+ "model.layers.30.share_expert.down_proj",
425
+ "model.layers.31.self_attn.q_proj",
426
+ "model.layers.31.self_attn.k_proj",
427
+ "model.layers.31.self_attn.v_proj",
428
+ "model.layers.31.self_attn.o_proj",
429
+ "model.layers.31.self_attn.g_proj",
430
+ "model.layers.31.moe.gate",
431
+ "model.layers.31.share_expert.gate_proj",
432
+ "model.layers.31.share_expert.up_proj",
433
+ "model.layers.31.share_expert.down_proj",
434
+ "model.layers.32.self_attn.q_proj",
435
+ "model.layers.32.self_attn.k_proj",
436
+ "model.layers.32.self_attn.v_proj",
437
+ "model.layers.32.self_attn.o_proj",
438
+ "model.layers.32.self_attn.g_proj",
439
+ "model.layers.32.moe.gate",
440
+ "model.layers.32.share_expert.gate_proj",
441
+ "model.layers.32.share_expert.up_proj",
442
+ "model.layers.32.share_expert.down_proj",
443
+ "model.layers.33.self_attn.q_proj",
444
+ "model.layers.33.self_attn.k_proj",
445
+ "model.layers.33.self_attn.v_proj",
446
+ "model.layers.33.self_attn.o_proj",
447
+ "model.layers.33.self_attn.g_proj",
448
+ "model.layers.33.moe.gate",
449
+ "model.layers.33.share_expert.gate_proj",
450
+ "model.layers.33.share_expert.up_proj",
451
+ "model.layers.33.share_expert.down_proj",
452
+ "model.layers.34.self_attn.q_proj",
453
+ "model.layers.34.self_attn.k_proj",
454
+ "model.layers.34.self_attn.v_proj",
455
+ "model.layers.34.self_attn.o_proj",
456
+ "model.layers.34.self_attn.g_proj",
457
+ "model.layers.34.moe.gate",
458
+ "model.layers.34.share_expert.gate_proj",
459
+ "model.layers.34.share_expert.up_proj",
460
+ "model.layers.34.share_expert.down_proj",
461
+ "model.layers.35.self_attn.q_proj",
462
+ "model.layers.35.self_attn.k_proj",
463
+ "model.layers.35.self_attn.v_proj",
464
+ "model.layers.35.self_attn.o_proj",
465
+ "model.layers.35.self_attn.g_proj",
466
+ "model.layers.35.moe.gate",
467
+ "model.layers.35.share_expert.gate_proj",
468
+ "model.layers.35.share_expert.up_proj",
469
+ "model.layers.35.share_expert.down_proj",
470
+ "model.layers.36.self_attn.q_proj",
471
+ "model.layers.36.self_attn.k_proj",
472
+ "model.layers.36.self_attn.v_proj",
473
+ "model.layers.36.self_attn.o_proj",
474
+ "model.layers.36.self_attn.g_proj",
475
+ "model.layers.36.moe.gate",
476
+ "model.layers.36.share_expert.gate_proj",
477
+ "model.layers.36.share_expert.up_proj",
478
+ "model.layers.36.share_expert.down_proj",
479
+ "model.layers.37.self_attn.q_proj",
480
+ "model.layers.37.self_attn.k_proj",
481
+ "model.layers.37.self_attn.v_proj",
482
+ "model.layers.37.self_attn.o_proj",
483
+ "model.layers.37.self_attn.g_proj",
484
+ "model.layers.37.moe.gate",
485
+ "model.layers.37.share_expert.gate_proj",
486
+ "model.layers.37.share_expert.up_proj",
487
+ "model.layers.37.share_expert.down_proj",
488
+ "model.layers.38.self_attn.q_proj",
489
+ "model.layers.38.self_attn.k_proj",
490
+ "model.layers.38.self_attn.v_proj",
491
+ "model.layers.38.self_attn.o_proj",
492
+ "model.layers.38.self_attn.g_proj",
493
+ "model.layers.38.moe.gate",
494
+ "model.layers.38.share_expert.gate_proj",
495
+ "model.layers.38.share_expert.up_proj",
496
+ "model.layers.38.share_expert.down_proj",
497
+ "model.layers.39.self_attn.q_proj",
498
+ "model.layers.39.self_attn.k_proj",
499
+ "model.layers.39.self_attn.v_proj",
500
+ "model.layers.39.self_attn.o_proj",
501
+ "model.layers.39.self_attn.g_proj",
502
+ "model.layers.39.moe.gate",
503
+ "model.layers.39.share_expert.gate_proj",
504
+ "model.layers.39.share_expert.up_proj",
505
+ "model.layers.39.share_expert.down_proj",
506
+ "model.layers.40.self_attn.q_proj",
507
+ "model.layers.40.self_attn.k_proj",
508
+ "model.layers.40.self_attn.v_proj",
509
+ "model.layers.40.self_attn.o_proj",
510
+ "model.layers.40.self_attn.g_proj",
511
+ "model.layers.40.moe.gate",
512
+ "model.layers.40.share_expert.gate_proj",
513
+ "model.layers.40.share_expert.up_proj",
514
+ "model.layers.40.share_expert.down_proj",
515
+ "model.layers.41.self_attn.q_proj",
516
+ "model.layers.41.self_attn.k_proj",
517
+ "model.layers.41.self_attn.v_proj",
518
+ "model.layers.41.self_attn.o_proj",
519
+ "model.layers.41.self_attn.g_proj",
520
+ "model.layers.41.moe.gate",
521
+ "model.layers.41.share_expert.gate_proj",
522
+ "model.layers.41.share_expert.up_proj",
523
+ "model.layers.41.share_expert.down_proj",
524
+ "model.layers.42.self_attn.q_proj",
525
+ "model.layers.42.self_attn.k_proj",
526
+ "model.layers.42.self_attn.v_proj",
527
+ "model.layers.42.self_attn.o_proj",
528
+ "model.layers.42.self_attn.g_proj",
529
+ "model.layers.42.moe.gate",
530
+ "model.layers.42.share_expert.gate_proj",
531
+ "model.layers.42.share_expert.up_proj",
532
+ "model.layers.42.share_expert.down_proj",
533
+ "model.layers.43.self_attn.q_proj",
534
+ "model.layers.43.self_attn.k_proj",
535
+ "model.layers.43.self_attn.v_proj",
536
+ "model.layers.43.self_attn.o_proj",
537
+ "model.layers.43.self_attn.g_proj",
538
+ "model.layers.43.moe.gate",
539
+ "model.layers.43.share_expert.gate_proj",
540
+ "model.layers.43.share_expert.up_proj",
541
+ "model.layers.43.share_expert.down_proj",
542
+ "model.layers.44.self_attn.q_proj",
543
+ "model.layers.44.self_attn.k_proj",
544
+ "model.layers.44.self_attn.v_proj",
545
+ "model.layers.44.self_attn.o_proj",
546
+ "model.layers.44.self_attn.g_proj",
547
+ "model.layers.44.moe.gate",
548
+ "model.layers.44.share_expert.gate_proj",
549
+ "model.layers.44.share_expert.up_proj",
550
+ "model.layers.44.share_expert.down_proj",
551
+ "lm_head"
552
+ ],
553
+ "export": {
554
+ "kv_cache_group": [],
555
+ "min_kv_scale": 0.0,
556
+ "pack_method": "reorder",
557
+ "weight_format": "real_quantized",
558
+ "weight_merge_groups": null
559
+ },
560
+ "global_quant_config": {
561
+ "bias": null,
562
+ "input_tensors": {
563
+ "ch_axis": -1,
564
+ "dtype": "fp4",
565
+ "group_size": 32,
566
+ "is_dynamic": true,
567
+ "is_scale_quant": false,
568
+ "mx_element_dtype": null,
569
+ "observer_cls": "PerBlockMXObserver",
570
+ "qscheme": "per_group",
571
+ "round_method": "half_even",
572
+ "scale_calculation_mode": "even",
573
+ "scale_format": "e8m0",
574
+ "scale_type": "float",
575
+ "symmetric": null
576
+ },
577
+ "output_tensors": null,
578
+ "target_device": null,
579
+ "weight": {
580
+ "ch_axis": -1,
581
+ "dtype": "fp4",
582
+ "group_size": 32,
583
+ "is_dynamic": false,
584
+ "is_scale_quant": false,
585
+ "mx_element_dtype": null,
586
+ "observer_cls": "PerBlockMXObserver",
587
+ "qscheme": "per_group",
588
+ "round_method": "half_even",
589
+ "scale_calculation_mode": "even",
590
+ "scale_format": "e8m0",
591
+ "scale_type": "float",
592
+ "symmetric": null
593
+ }
594
+ },
595
+ "kv_cache_post_rope": false,
596
+ "kv_cache_quant_config": {},
597
+ "layer_quant_config": {},
598
+ "layer_type_quant_config": {},
599
+ "quant_method": "quark",
600
+ "quant_mode": "eager_mode",
601
+ "softmax_quant_spec": null,
602
+ "version": "0.12+422c9a6d36"
603
+ },
604
+ "rms_norm_eps": 1e-05,
605
+ "rope_parameters": {
606
+ "factor": 2.0,
607
+ "high_freq_factor": 32.0,
608
+ "low_freq_factor": 1.0,
609
+ "original_max_position_embeddings": 131072,
610
+ "rope_type": "llama3"
611
+ },
612
+ "rope_scaling": {
613
+ "factor": 2.0,
614
+ "high_freq_factor": 32.0,
615
+ "low_freq_factor": 1.0,
616
+ "original_max_position_embeddings": 131072,
617
+ "rope_type": "llama3"
618
+ },
619
+ "rope_theta": [
620
+ 5000000.0,
621
+ 10000.0,
622
+ 10000.0,
623
+ 10000.0,
624
+ 5000000.0,
625
+ 10000.0,
626
+ 10000.0,
627
+ 10000.0,
628
+ 5000000.0,
629
+ 10000.0,
630
+ 10000.0,
631
+ 10000.0,
632
+ 5000000.0,
633
+ 10000.0,
634
+ 10000.0,
635
+ 10000.0,
636
+ 5000000.0,
637
+ 10000.0,
638
+ 10000.0,
639
+ 10000.0,
640
+ 5000000.0,
641
+ 10000.0,
642
+ 10000.0,
643
+ 10000.0,
644
+ 5000000.0,
645
+ 10000.0,
646
+ 10000.0,
647
+ 10000.0,
648
+ 5000000.0,
649
+ 10000.0,
650
+ 10000.0,
651
+ 10000.0,
652
+ 5000000.0,
653
+ 10000.0,
654
+ 10000.0,
655
+ 10000.0,
656
+ 5000000.0,
657
+ 10000.0,
658
+ 10000.0,
659
+ 10000.0,
660
+ 5000000.0,
661
+ 10000.0,
662
+ 10000.0,
663
+ 10000.0,
664
+ 5000000.0,
665
+ 10000.0,
666
+ 10000.0,
667
+ 10000.0
668
+ ],
669
+ "share_expert_dim": 1280,
670
+ "sink": false,
671
+ "sliding_window": 512,
672
+ "swiglu_limits": [
673
+ 0.0,
674
+ 0.0,
675
+ 0.0,
676
+ 0.0,
677
+ 0.0,
678
+ 0.0,
679
+ 0.0,
680
+ 0.0,
681
+ 0.0,
682
+ 0.0,
683
+ 0.0,
684
+ 0.0,
685
+ 0.0,
686
+ 0.0,
687
+ 0.0,
688
+ 0.0,
689
+ 0.0,
690
+ 0.0,
691
+ 0.0,
692
+ 0.0,
693
+ 0.0,
694
+ 0.0,
695
+ 0.0,
696
+ 0.0,
697
+ 0.0,
698
+ 0.0,
699
+ 0.0,
700
+ 0.0,
701
+ 0.0,
702
+ 0.0,
703
+ 0.0,
704
+ 0.0,
705
+ 0.0,
706
+ 0.0,
707
+ 0.0,
708
+ 0.0,
709
+ 0.0,
710
+ 0.0,
711
+ 0.0,
712
+ 0.0,
713
+ 0.0,
714
+ 0.0,
715
+ 0.0,
716
+ 7,
717
+ 7,
718
+ 0.0,
719
+ 0.0,
720
+ 0.0
721
+ ],
722
+ "swiglu_limits_shared": [
723
+ 0.0,
724
+ 0.0,
725
+ 0.0,
726
+ 0.0,
727
+ 0.0,
728
+ 0.0,
729
+ 0.0,
730
+ 0.0,
731
+ 0.0,
732
+ 0.0,
733
+ 0.0,
734
+ 0.0,
735
+ 0.0,
736
+ 0.0,
737
+ 0.0,
738
+ 0.0,
739
+ 0.0,
740
+ 0.0,
741
+ 0.0,
742
+ 0.0,
743
+ 0.0,
744
+ 0.0,
745
+ 0.0,
746
+ 0.0,
747
+ 0.0,
748
+ 0.0,
749
+ 0.0,
750
+ 0.0,
751
+ 0.0,
752
+ 0.0,
753
+ 0.0,
754
+ 0.0,
755
+ 0.0,
756
+ 0.0,
757
+ 0.0,
758
+ 0.0,
759
+ 0.0,
760
+ 0.0,
761
+ 0.0,
762
+ 0.0,
763
+ 0.0,
764
+ 0.0,
765
+ 0.0,
766
+ 0.0,
767
+ 16,
768
+ 0.0,
769
+ 0.0,
770
+ 0.0
771
+ ],
772
+ "transformers_version": "4.57.6",
773
+ "use_cache": false,
774
+ "use_head_wise_attn_gate": true,
775
+ "use_moe": true,
776
+ "use_moe_router_bias": true,
777
+ "use_qk_norm": true,
778
+ "use_rope_layers": [],
779
+ "vocab_size": 128896,
780
+ "yarn_only_types": [
781
+ "full_attention"
782
+ ],
783
+ "zero_centered": true
784
+ }
configuration_step3p5.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Union
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+
6
+
7
+ class Step3p5Config(PretrainedConfig):
8
+ model_type = "step3p5"
9
+ architectures = ["Step3p5ForCausalLM"]
10
+
11
+ def __init__(
12
+ self,
13
+ hidden_size: int = 4096,
14
+ intermediate_size: int = 11264,
15
+ num_attention_heads: int = 64,
16
+ num_attention_groups: int = 8,
17
+ num_hidden_layers: int = 45,
18
+ max_seq_len: int = 128000,
19
+ vocab_size: int = 128815,
20
+ rms_norm_eps: float = 1e-5,
21
+ moe_intermediate_size: int = 1280,
22
+ moe_num_experts: int = 288,
23
+ moe_top_k: int = 8,
24
+ rope_theta: float = 10000,
25
+ rope_scaling: Optional[dict[str, Any]] = None,
26
+ max_position_embeddings: int = 128000,
27
+ share_expert_dims: int = 1280,
28
+ head_dim: int = 128,
29
+ norm_expert_weight: bool = True,
30
+ layer_types: list[str] = None,
31
+ sliding_window: Optional[int] = None,
32
+ moe_layers_enum: tuple[int] = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
33
+ 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
34
+ 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
35
+ 35, 36, 37, 38, 39, 40, 41, 42, 43, 44),
36
+ **kwargs,
37
+ ) -> None:
38
+ self.hidden_size = hidden_size
39
+ self.intermediate_size = intermediate_size
40
+ self.num_attention_heads = num_attention_heads
41
+ self.num_attention_groups = num_attention_groups
42
+ self.num_hidden_layers = num_hidden_layers
43
+ self.max_seq_len = max_seq_len
44
+ self.vocab_size = vocab_size
45
+ self.rms_norm_eps = rms_norm_eps
46
+ self.moe_intermediate_size = moe_intermediate_size
47
+ self.moe_num_experts = moe_num_experts
48
+ self.moe_top_k = moe_top_k
49
+ self.rope_theta = rope_theta
50
+ self.rope_scaling = rope_scaling
51
+ self.max_position_embeddings = max_position_embeddings
52
+ self.share_expert_dim = share_expert_dims
53
+ self.head_dim = head_dim
54
+ self.norm_expert_weight = norm_expert_weight
55
+ self.moe_layers_enum = moe_layers_enum
56
+ self.layer_types = layer_types
57
+ self.sliding_window = sliding_window
58
+ super().__init__(**kwargs)
59
+
generation_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": [
5
+ 1,
6
+ 2,
7
+ 128007
8
+ ],
9
+ "transformers_version": "4.57.6",
10
+ "use_cache": false
11
+ }
model-00001-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:384546f5d09641de80e4878241ab44da14cbd90c21fa9995c334a9888bea5272
3
+ size 4997829840
model-00002-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed5cea6a5f458a8cc3c59c08403ca350927edc2f206c05e482e32f678eec0de5
3
+ size 4997967800
model-00003-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bef1ec40bbb87937f8777f6e57bd965408395f923bd9512ebd528ffbd0d99c44
3
+ size 4998486552
model-00004-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe03c3c2c53f90afa5de4c5cb905115a4556e2ac9fba2a39e3d9cb1e1e1245fb
3
+ size 4997967808
model-00005-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54775d7e4e6eb6ea1f6ac958ac9be3ff05e07aa727ba855af9eb235629c15af8
3
+ size 4998489136
model-00006-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a30e99ef04b3e91797e05986317852fd132658eab8650c34799f2d5ec1f3d6d1
3
+ size 4997971104
model-00007-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1622bc5c0bc51eeda8b4fe4a2316dcffb1c57febaf43f88a4ab43c14b7d091f4
3
+ size 4998489992
model-00008-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33e700b1470e7d3393aabb30d0a3786ae6b13920d9efcdee62e7a15c93141691
3
+ size 4997971288
model-00009-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7206bd64c85d4eaadea5bb67c988494042bf6a86a86702a913d8e36d93e62ca4
3
+ size 4998490088
model-00010-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:104a1aef751338400e8684e1283d223837b6993f847735d4b7f2656fa9052636
3
+ size 4993155312
model-00011-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29ca0e3dbd4b8b5100967845407362b203e26950a685790b7e5598613b77b19a
3
+ size 4998489800
model-00012-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e447996b9b05e4cefffe76ca8c0c30cc8b297fe14e5a8525981184f75f35f32
3
+ size 4997971104
model-00013-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0410ad3ca30cab2f3a98d5fbb0a82fd9fe55224ff7acc4706019faa2eea4da3
3
+ size 4998489800
model-00014-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ff3a1a570903c6681e62a6110b31cbd6c590e3b436ca50dd7bc367e0fc65986
3
+ size 4997971104
model-00015-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:402aa7a977a0ff3c79d718fc5077b29cd9a0d56176fc96f8ac401e89677a028f
3
+ size 4998489808
model-00016-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06ca8045ca30b180ffb580ad74a33f3eb9b8fde6c18936761bab6f669bfad400
3
+ size 4997971208
model-00017-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd45f815bf5b5a30ed5ce73850b18f21c00e39c266befb63bc4b1e9ebb43d424
3
+ size 4998490032
model-00018-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3894225b6a72e4c30b14571d88934cbd9f91f70652347311324f9c557d1d8ed7
3
+ size 4997971288
model-00019-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4609212a52dc17a26b1ffaf76e8d6b4d2b13ad07e57cf6a5ab8e39e18c052d4b
3
+ size 4986151200
model-00020-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2be2764df435b6a37e7a4207f4a881d82604e6e032587f958f8a84f5fec5a11f
3
+ size 5000277752
model-00021-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c83bd8043b4210f5b0ed7a6ec638958a253799c04e39eaea5afcc6f30803b9a
3
+ size 4998135064
model-00022-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d136448b1e20d79445f1d531a4c4ecc229cbdfc4f3f2d368783054dfcd43a20f
3
+ size 4998489800
model-00023-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec8c5e201893b3061feb026228557356348afef0aa09765c679a81961d571e8f
3
+ size 4540164960
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_step3p5.py ADDED
@@ -0,0 +1,900 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from typing import Callable, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from transformers.activations import ACT2FN
22
+ from transformers.cache_utils import Cache, DynamicCache
23
+ from transformers.generation import GenerationMixin
24
+ from transformers.masking_utils import (create_causal_mask,
25
+ create_sliding_window_causal_mask)
26
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
27
+ from transformers.modeling_layers import GradientCheckpointingLayer
28
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
29
+ from transformers.modeling_rope_utils import (ROPE_INIT_FUNCTIONS,
30
+ dynamic_rope_update)
31
+ from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
32
+ PreTrainedModel)
33
+ from transformers.processing_utils import Unpack
34
+ from transformers.utils import TransformersKwargs, can_return_tuple, logging
35
+
36
+ from .configuration_step3p5 import Step3p5Config
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ __all__ = ["Step3p5Model", "Step3p5ForCausalLM"]
41
+
42
+ class Step3p5RotaryEmbedding(nn.Module):
43
+
44
+ def __init__(self, config: Step3p5Config, device=None, layer_idx=None):
45
+ super().__init__()
46
+ # BC: "rope_type" was originally "type"
47
+ self.layer_idx = layer_idx
48
+ if config.rope_parameters is not None:
49
+ self.rope_type = config.rope_parameters.get(
50
+ "rope_type", config.rope_parameters.get("type"))
51
+ else:
52
+ self.rope_type = "default"
53
+ self.max_seq_len_cached = config.max_position_embeddings
54
+ self.original_max_seq_len = config.max_position_embeddings
55
+
56
+ partial_rotary_factors = getattr(config, "partial_rotary_factors",
57
+ None)
58
+ if partial_rotary_factors is not None:
59
+ config.partial_rotary_factor = partial_rotary_factors[
60
+ self.layer_idx]
61
+ else:
62
+ config.partial_rotary_factor = 1.0
63
+
64
+ self.rope_theta = config.rope_theta
65
+ if isinstance(config.rope_theta, list):
66
+ self.rope_theta = config.rope_theta.copy()
67
+ config.rope_theta = self.rope_theta[self.layer_idx]
68
+
69
+ self.config = config
70
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
71
+ inv_freq, self.attention_scaling = self.rope_init_fn(
72
+ self.config, device)
73
+
74
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
75
+ self.original_inv_freq = self.inv_freq
76
+ config.rope_theta = self.rope_theta
77
+
78
+ @torch.no_grad()
79
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
80
+ def forward(self, x, position_ids):
81
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
82
+ position_ids.shape[0], -1, 1).to(x.device)
83
+ position_ids_expanded = position_ids[:, None, :].float().to(x.device)
84
+
85
+ device_type = x.device.type if isinstance(
86
+ x.device.type, str) and x.device.type != "mps" else "cpu"
87
+ with torch.autocast(device_type=device_type,
88
+ enabled=False): # Force float32
89
+ freqs = (inv_freq_expanded.float()
90
+ @ position_ids_expanded.float()).transpose(1, 2)
91
+ emb = torch.cat((freqs, freqs), dim=-1)
92
+ cos = emb.cos() * self.attention_scaling
93
+ sin = emb.sin() * self.attention_scaling
94
+
95
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
96
+
97
+
98
+ def rotate_half(x):
99
+ """Rotates half the hidden dims of the input."""
100
+ x1 = x[..., :x.shape[-1] // 2]
101
+ x2 = x[..., x.shape[-1] // 2:]
102
+ return torch.cat((-x2, x1), dim=-1)
103
+
104
+
105
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
106
+ """Applies Rotary Position Embedding to the query and key tensors.
107
+
108
+ Args:
109
+ q (`torch.Tensor`): The query tensor.
110
+ k (`torch.Tensor`): The key tensor.
111
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
112
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
113
+ position_ids (`torch.Tensor`, *optional*):
114
+ Deprecated and unused.
115
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
116
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
117
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
118
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
119
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
120
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
121
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
122
+ Returns:
123
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
124
+ """
125
+ rotary_dim = cos.shape[-1]
126
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
127
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
128
+
129
+ # Apply rotary embeddings on the first half or full tensor
130
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
131
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
132
+
133
+ # Concatenate back to full shape
134
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
135
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
136
+ return q_embed, k_embed
137
+
138
+
139
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
140
+ """
141
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
142
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
143
+ """
144
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
145
+ if n_rep == 1:
146
+ return hidden_states
147
+ hidden_states = hidden_states[:, :,
148
+ None, :, :].expand(batch,
149
+ num_key_value_heads,
150
+ n_rep, slen, head_dim)
151
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
152
+ head_dim)
153
+
154
+
155
+ # Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
156
+ def eager_attention_forward(
157
+ module: nn.Module,
158
+ query: torch.Tensor,
159
+ key: torch.Tensor,
160
+ value: torch.Tensor,
161
+ attention_mask: Optional[torch.Tensor],
162
+ scaling: float,
163
+ dropout: float = 0.0,
164
+ **kwargs,
165
+ ):
166
+ key_states = repeat_kv(key, module.num_key_value_groups)
167
+ value_states = repeat_kv(value, module.num_key_value_groups)
168
+ # breakpoint()
169
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
170
+ if attention_mask is not None:
171
+ causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
172
+ attn_weights = attn_weights + causal_mask
173
+
174
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
175
+ attn_weights = nn.functional.dropout(attn_weights,
176
+ p=dropout,
177
+ training=module.training)
178
+ attn_output = torch.matmul(attn_weights, value_states)
179
+ attn_output = attn_output.transpose(1, 2).contiguous()
180
+
181
+ return attn_output, attn_weights
182
+
183
+ @dataclass
184
+ class Step3p5CausalLMOutputWithPast(ModelOutput):
185
+ r"""
186
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
187
+ Language modeling loss (for next-token prediction).
188
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
189
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
190
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
191
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
192
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
193
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
194
+ `past_key_values` input) to speed up sequential decoding.
195
+ """
196
+
197
+ loss: Optional[torch.FloatTensor] = None
198
+ last_hidden_state: Optional[torch.FloatTensor] = None
199
+ logits: torch.FloatTensor = None
200
+ past_key_values: Optional[list[torch.FloatTensor]] = None
201
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
202
+ attentions: Optional[tuple[torch.FloatTensor]] = None
203
+
204
+
205
+ class Step3p5MLP(nn.Module):
206
+
207
+ def __init__(self, config, intermediate_size=None, swiglu_limit=None):
208
+ super().__init__()
209
+ self.config = config
210
+ self.hidden_size = config.hidden_size
211
+ self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
212
+ self.gate_proj = nn.Linear(self.hidden_size,
213
+ self.intermediate_size,
214
+ bias=False)
215
+ self.up_proj = nn.Linear(self.hidden_size,
216
+ self.intermediate_size,
217
+ bias=False)
218
+ self.down_proj = nn.Linear(self.intermediate_size,
219
+ self.hidden_size,
220
+ bias=False)
221
+ self.act_fn = ACT2FN["silu"]
222
+ self.limit = swiglu_limit
223
+
224
+ def forward(self, x):
225
+ up = self.up_proj(x)
226
+ gate = self.act_fn(self.gate_proj(x))
227
+ if self.limit is not None:
228
+ gate = gate.clamp(min=None, max=self.limit)
229
+ up = up.clamp(min=-self.limit, max=self.limit)
230
+
231
+ return self.down_proj(gate * up)
232
+
233
+
234
+ def sigmoid_routing_function(gating_output: torch.Tensor, topk: int,
235
+ renormalize: bool):
236
+ gating_output = gating_output.float()
237
+ gate_prob = torch.sigmoid(gating_output)
238
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
239
+ topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1)
240
+ expert_topk_weight = topk_prob
241
+ if renormalize:
242
+ expert_topk_weight = expert_topk_weight / torch.sum(
243
+ expert_topk_weight, dim=-1, keepdim=True)
244
+ return expert_topk_weight, indices
245
+
246
+
247
+ def softmax_routing_function(gating_output: torch.Tensor, top_k: int,
248
+ renormalize: bool):
249
+ gating_output = gating_output.float()
250
+ gate_prob = torch.softmax(gating_output, dim=-1)
251
+ gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True)
252
+ topk_prob, indices = torch.topk(gate_prob, k=top_k, dim=1)
253
+ expert_topk_weight = topk_prob
254
+ if renormalize:
255
+ expert_topk_weight = expert_topk_weight / torch.sum(
256
+ expert_topk_weight, dim=-1, keepdim=True)
257
+ return expert_topk_weight, indices.to(torch.int32)
258
+
259
+
260
+ class MoELinear(nn.Module):
261
+
262
+ def __init__(self, num_experts, in_features, out_features):
263
+ super().__init__()
264
+ self.num_experts = num_experts
265
+ self.in_features = in_features
266
+ self.out_features = out_features
267
+ self.weight = nn.Parameter(
268
+ torch.empty(num_experts, out_features, in_features))
269
+
270
+ def forward(self, x, expert_id):
271
+ x = F.linear(x.float(), self.weight[expert_id].float())
272
+ return x
273
+
274
+
275
+ class Step3p5MoEMLP(nn.Module):
276
+
277
+ def __init__(self, config, swiglu_limit=None):
278
+ super().__init__()
279
+ self.num_experts = config.moe_num_experts
280
+ self.top_k = config.moe_top_k
281
+ self.hidden_size = config.hidden_size
282
+ self.moe_intermediate_size = config.moe_intermediate_size
283
+
284
+ self.use_moe_router_bias = config.use_moe_router_bias
285
+ if self.use_moe_router_bias:
286
+ self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts,
287
+ dtype=torch.float32),
288
+ requires_grad=False)
289
+ self.custom_routing_function = self.router_bias_func
290
+ elif config.moe_router_activation == "sigmoid":
291
+ self.custom_routing_function = sigmoid_routing_function
292
+ else:
293
+ self.custom_routing_function = None
294
+ self.need_fp32_gate = config.need_fp32_gate
295
+ self.routed_scaling_factor = getattr(config,
296
+ "moe_router_scaling_factor", 1.0)
297
+
298
+ # gating
299
+ self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
300
+
301
+ self.act_fn = ACT2FN["silu"]
302
+ self.limit = swiglu_limit
303
+
304
+ self.up_proj = MoELinear(self.num_experts, self.hidden_size,
305
+ self.moe_intermediate_size)
306
+ self.gate_proj = MoELinear(self.num_experts, self.hidden_size,
307
+ self.moe_intermediate_size)
308
+ self.down_proj = MoELinear(self.num_experts,
309
+ self.moe_intermediate_size,
310
+ self.hidden_size)
311
+
312
+ def router_bias_func(self, gating_output: torch.Tensor, topk: int,
313
+ renormalize: bool):
314
+ gate_prob = torch.sigmoid(gating_output.float())
315
+ gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0)
316
+ _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1)
317
+ topk_prob = torch.gather(gate_prob, 1, indices)
318
+ expert_topk_weight = topk_prob
319
+ if renormalize:
320
+ expert_topk_weight = expert_topk_weight / (
321
+ torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20)
322
+ return expert_topk_weight, indices
323
+
324
+ def get_expert_output(self, inputs: torch.Tensor, expert_id):
325
+ #if self.limit is None:
326
+ up = self.up_proj(inputs, expert_id)
327
+ gate = self.act_fn(self.gate_proj(inputs, expert_id))
328
+ if self.limit is not None:
329
+ gate = gate.clamp(min=None, max=self.limit)
330
+ up = up.clamp(min=-self.limit, max=self.limit)
331
+
332
+ return self.down_proj(gate * up, expert_id)
333
+
334
+ def forward(self, hidden_states):
335
+ """ """
336
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
337
+ hidden_states = hidden_states.view(-1, hidden_dim)
338
+ if self.need_fp32_gate:
339
+ router_logits = torch.matmul(hidden_states.to(torch.float32), self.gate.weight.t().to(torch.float32))
340
+ else:
341
+ # router_logits: (batch * sequence_length, n_experts)
342
+ router_logits = self.gate(hidden_states)
343
+
344
+ if self.custom_routing_function:
345
+ routing_weights, selected_experts = self.custom_routing_function(
346
+ router_logits, self.top_k, renormalize=True)
347
+ else:
348
+ routing_weights = F.softmax(router_logits,
349
+ dim=1,
350
+ dtype=torch.float)
351
+ routing_weights, selected_experts = torch.topk(routing_weights,
352
+ self.top_k,
353
+ dim=-1)
354
+
355
+ routing_weights = routing_weights * self.routed_scaling_factor
356
+
357
+ final_hidden_states = torch.zeros(
358
+ (batch_size * sequence_length, hidden_dim),
359
+ dtype=hidden_states.dtype,
360
+ device=hidden_states.device)
361
+
362
+ # One hot encode the selected experts to create an expert mask
363
+ # this will be used to easily index which expert is going to be sollicitated
364
+ expert_mask = torch.nn.functional.one_hot(
365
+ selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
366
+
367
+ # Loop over all available experts in the model and perform the computation on each expert
368
+ for expert_idx in range(self.num_experts):
369
+ idx, top_x = torch.where(expert_mask[expert_idx])
370
+
371
+ # Index the correct hidden states and compute the expert hidden state for
372
+ # the current expert. We need to make sure to multiply the output hidden
373
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
374
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
375
+ current_hidden_states = (
376
+ self.get_expert_output(current_state, expert_idx) *
377
+ routing_weights[top_x, idx, None])
378
+
379
+ # However `index_add_` only support torch tensors for indexing so we'll use
380
+ # the `top_x` tensor here.
381
+ final_hidden_states.index_add_(
382
+ 0, top_x, current_hidden_states.to(hidden_states.dtype))
383
+ final_hidden_states = final_hidden_states.reshape(
384
+ batch_size, sequence_length, hidden_dim)
385
+ return final_hidden_states
386
+
387
+
388
+ class Step3p5RMSNorm(nn.Module):
389
+
390
+ def __init__(
391
+ self,
392
+ hidden_size: int,
393
+ eps: float = 1e-5,
394
+ ) -> None:
395
+ super().__init__()
396
+ self.weight = nn.Parameter(torch.ones(hidden_size))
397
+ self.variance_epsilon = eps
398
+
399
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
400
+ dtype = x.dtype
401
+ x = x.float()
402
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
403
+ normed = x * torch.rsqrt(variance + self.variance_epsilon)
404
+ normed = normed * (self.weight.float() + 1)
405
+ return normed.to(dtype)
406
+ class Step3p5Attention(nn.Module):
407
+
408
+ def __init__(self, config: Step3p5Config, layer_idx):
409
+ super().__init__()
410
+ self.config = config
411
+ self.layer_idx = layer_idx
412
+ self.num_attention_heads = config.num_attention_heads
413
+ self.num_key_value_heads = config.num_attention_groups
414
+
415
+ layer_types = getattr(config, "layer_types", [])
416
+ if layer_types:
417
+ enable_sliding_window = layer_types[
418
+ self.layer_idx] == "sliding_attention"
419
+ else:
420
+ enable_sliding_window = self.layer_idx % 2 == 0
421
+
422
+ if hasattr(config, "yarn_only_types") and layer_types[
423
+ self.layer_idx] not in config.yarn_only_types:
424
+ config.rope_parameters = None
425
+ else:
426
+ config.rope_parameters = getattr(config, "rope_scaling", None)
427
+
428
+ self.sliding_window = config.sliding_window
429
+ if enable_sliding_window:
430
+ self.num_attention_heads = config.attention_other_setting[
431
+ "num_attention_heads"]
432
+ self.num_key_value_heads = config.attention_other_setting[
433
+ "num_attention_groups"]
434
+
435
+ if self.sliding_window is not None and enable_sliding_window:
436
+ self.sliding_window = (self.sliding_window)
437
+ else:
438
+ self.sliding_window = None
439
+ self.head_dim = getattr(config, "head_dim",
440
+ config.hidden_size // self.num_attention_heads)
441
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
442
+
443
+ self.rotary_emb = Step3p5RotaryEmbedding(config, layer_idx=layer_idx)
444
+
445
+ self.q_size = self.num_attention_heads * self.head_dim
446
+ self.kv_size = self.num_key_value_heads * self.head_dim
447
+ self.scaling = self.head_dim**-0.5
448
+
449
+ self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
450
+ self.k_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
451
+ self.v_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
452
+ self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
453
+ self.q_norm = Step3p5RMSNorm(self.head_dim,
454
+ eps=config.rms_norm_eps)
455
+ self.k_norm = Step3p5RMSNorm(self.head_dim,
456
+ eps=config.rms_norm_eps)
457
+
458
+ self.use_head_wise_attn_gate = config.use_head_wise_attn_gate
459
+ if self.use_head_wise_attn_gate:
460
+ self.g_proj = nn.Linear(config.hidden_size,
461
+ self.num_attention_heads,
462
+ bias=False)
463
+
464
+ self.use_rope = True
465
+ use_rope_layers = getattr(config, "use_rope_layers", None)
466
+ if use_rope_layers:
467
+ self.use_rope = use_rope_layers[self.layer_idx]
468
+
469
+ def forward(
470
+ self,
471
+ hidden_states: torch.Tensor,
472
+ attention_mask: Optional[torch.Tensor],
473
+ past_key_value: Optional[Cache] = None,
474
+ cache_position: Optional[torch.LongTensor] = None,
475
+ position_ids: Optional[torch.LongTensor] = None,
476
+ **kwargs: Unpack[FlashAttentionKwargs],
477
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
478
+ Optional[Tuple[torch.Tensor]]]:
479
+ input_shape = hidden_states.shape[:-1]
480
+ hidden_shape = (*input_shape, -1, self.head_dim)
481
+
482
+ query_states = self.q_norm(
483
+ self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
484
+ key_states = self.k_norm(
485
+ self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
486
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(
487
+ 1, 2)
488
+ if self.use_head_wise_attn_gate:
489
+ gate_states = self.g_proj(hidden_states)
490
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
491
+
492
+ # cos, sin = position_embeddings
493
+ query_states, key_states = apply_rotary_pos_emb(
494
+ query_states, key_states, cos, sin)
495
+
496
+ # query_states, key_states = apply_rotary_pos_emb(query_norm_states, key_norm_states, cos, sin)
497
+ if past_key_value is not None:
498
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
499
+ cache_kwargs = {
500
+ "sin": sin,
501
+ "cos": cos,
502
+ "cache_position": cache_position
503
+ }
504
+ key_states, value_states = past_key_value.update(
505
+ key_states, value_states, self.layer_idx, cache_kwargs)
506
+
507
+ attention_interface: Callable = eager_attention_forward
508
+ # TODO: considering FP8;
509
+ # RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
510
+ # but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
511
+ if self.config._attn_implementation != "eager":
512
+ attention_interface = ALL_ATTENTION_FUNCTIONS[
513
+ self.config._attn_implementation]
514
+
515
+ attn_output, attn_weights = attention_interface(
516
+ self,
517
+ query_states,
518
+ key_states,
519
+ value_states,
520
+ attention_mask,
521
+ dropout=0.0 if not self.training else self.attention_dropout,
522
+ scaling=self.scaling,
523
+ sliding_window=self.sliding_window, # main diff with Llama
524
+ **kwargs,
525
+ )
526
+ attn_output = attn_output.reshape(*input_shape, -1)
527
+ if self.use_head_wise_attn_gate:
528
+ output = attn_output.view(
529
+ *attn_output.shape[:-1], self.num_attention_heads,
530
+ self.head_dim) * gate_states.unsqueeze(-1).sigmoid()
531
+ attn_output = output.view(*attn_output.shape)
532
+ attn_output = self.o_proj(attn_output)
533
+
534
+ return attn_output, attn_weights
535
+
536
+
537
+ class Step3p5DecoderLayer(GradientCheckpointingLayer):
538
+
539
+ def __init__(self, config, layer_idx):
540
+ super().__init__()
541
+ self.hidden_size = config.hidden_size
542
+ self.layer_idx = layer_idx
543
+ self.self_attn = Step3p5Attention(config, layer_idx)
544
+ self.attention_type = config.layer_types[layer_idx]
545
+
546
+ moe_layers_enum = getattr(config, "moe_layers_enum", None)
547
+ if moe_layers_enum is not None:
548
+ moe_layers_idx = [
549
+ int(i) for i in moe_layers_enum.strip().split(',')
550
+ ]
551
+ else:
552
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
553
+ self.is_moe_layer = layer_idx in moe_layers_idx
554
+ self.use_moe = False
555
+
556
+ if config.swiglu_limits_shared and config.swiglu_limits_shared[
557
+ layer_idx] is not None and config.swiglu_limits_shared[
558
+ layer_idx] != 0:
559
+ swiglu_limit_shared = config.swiglu_limits_shared[layer_idx]
560
+ else:
561
+ swiglu_limit_shared = None
562
+ if config.swiglu_limits and config.swiglu_limits[
563
+ layer_idx] is not None and config.swiglu_limits[layer_idx] != 0:
564
+ swiglu_limit = config.swiglu_limits[layer_idx]
565
+ else:
566
+ swiglu_limit = None
567
+ if self.is_moe_layer:
568
+ self.moe = Step3p5MoEMLP(config, swiglu_limit=swiglu_limit) #
569
+ self.share_expert = Step3p5MLP(
570
+ config,
571
+ intermediate_size=config.share_expert_dim,
572
+ swiglu_limit=swiglu_limit_shared)
573
+ self.use_moe = True
574
+ else:
575
+ self.mlp = Step3p5MLP(config,
576
+ intermediate_size=config.intermediate_size,
577
+ swiglu_limit=swiglu_limit_shared)
578
+
579
+ self.input_layernorm = Step3p5RMSNorm(
580
+ config.hidden_size,
581
+ eps=config.rms_norm_eps)
582
+ self.post_attention_layernorm = Step3p5RMSNorm(
583
+ config.hidden_size,
584
+ eps=config.rms_norm_eps)
585
+
586
+ def forward(
587
+ self,
588
+ hidden_states: torch.Tensor,
589
+ attention_mask: Optional[torch.Tensor] = None,
590
+ position_ids: Optional[torch.LongTensor] = None,
591
+ past_key_value: Optional[tuple[torch.Tensor]] = None,
592
+ cache_position: Optional[torch.LongTensor] = None,
593
+ **kwargs: Unpack[FlashAttentionKwargs],
594
+ ) -> torch.FloatTensor:
595
+ residual = hidden_states
596
+ hidden_states = self.input_layernorm(hidden_states)
597
+ hidden_states, _ = self.self_attn(
598
+ hidden_states=hidden_states,
599
+ attention_mask=attention_mask,
600
+ position_ids=position_ids,
601
+ past_key_value=past_key_value,
602
+ cache_position=cache_position,
603
+ **kwargs,
604
+ )
605
+ hidden_states = residual + hidden_states
606
+
607
+ # Fully Connected
608
+ residual = hidden_states
609
+ hidden_states = self.post_attention_layernorm(hidden_states)
610
+ if self.use_moe:
611
+ share_output = self.share_expert(hidden_states)
612
+ moe_output = self.moe(hidden_states)
613
+ ffn_output = moe_output + share_output
614
+ else:
615
+ ffn_output = self.mlp(hidden_states)
616
+ if isinstance(ffn_output, tuple):
617
+ hidden_states, _ = ffn_output
618
+ else:
619
+ hidden_states = ffn_output
620
+
621
+ hidden_states = residual + hidden_states
622
+ return hidden_states
623
+
624
+
625
+ class Step3p5PreTrainedModel(PreTrainedModel):
626
+ # Link this model family to its configuration class so PreTrainedModel.from_pretrained
627
+ # can load the config instead of failing with a NoneType error.
628
+ config_class = Step3p5Config
629
+ supports_gradient_checkpointing = True
630
+ _skip_keys_device_placement = ["past_key_values"]
631
+ _keys_to_ignore_on_load_unexpected = [
632
+ r"model\.layers\.45\.*",
633
+ r"model\.layers\.46\.*",
634
+ r"model\.layers\.47\.*"
635
+ ]
636
+ _supports_flash_attn = False
637
+ _supports_sdpa = True
638
+ _supports_flex_attn = True
639
+ _supports_static_cache = True
640
+ _supports_attention_backend = True
641
+
642
+
643
+ class Step3p5Model(Step3p5PreTrainedModel, GenerationMixin):
644
+ _no_split_modules = ["Step3p5DecoderLayer"]
645
+ base_model_prefix = "model"
646
+ _tied_weights_keys = ["lm_head.weight"]
647
+ config: Step3p5Config
648
+ def __init__(self, config: Step3p5Config):
649
+ super().__init__(config)
650
+ self.padding_idx = config.pad_token_id
651
+ self.vocab_size = config.vocab_size
652
+
653
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
654
+ self.padding_idx)
655
+ self.layers = nn.ModuleList([
656
+ Step3p5DecoderLayer(config, layer_idx)
657
+ for layer_idx in range(config.num_hidden_layers)
658
+ ])
659
+ self.norm = Step3p5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
660
+ self.gradient_checkpointing = False
661
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
662
+
663
+ # Initialize weights and apply final processing
664
+ self.post_init()
665
+
666
+ def get_input_embeddings(self, input_ids):
667
+ return self.embed_tokens(input_ids)
668
+
669
+ @can_return_tuple
670
+ def forward(
671
+ self,
672
+ input_ids: torch.LongTensor = None,
673
+ attention_mask: Optional[torch.Tensor] = None,
674
+ position_ids: Optional[torch.LongTensor] = None,
675
+ past_key_values: Optional[Cache] = None,
676
+ inputs_embeds: Optional[torch.FloatTensor] = None,
677
+ use_cache: Optional[bool] = None,
678
+ output_attentions: Optional[bool] = None,
679
+ output_hidden_states: Optional[bool] = None,
680
+ return_dict: Optional[bool] = None,
681
+ cache_position: Optional[torch.LongTensor] = None,
682
+ **kwargs: Unpack[TransformersKwargs],
683
+ ) -> Union[tuple, BaseModelOutputWithPast]:
684
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
685
+ output_hidden_states = (output_hidden_states
686
+ if output_hidden_states is not None else
687
+ self.config.output_hidden_states)
688
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
689
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
690
+ if (input_ids is None) ^ (inputs_embeds is not None):
691
+ raise ValueError(
692
+ "You must specify exactly one of input_ids or inputs_embeds")
693
+
694
+ if self.gradient_checkpointing and self.training and use_cache:
695
+ logger.warning_once(
696
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
697
+ )
698
+ use_cache = False
699
+
700
+ if inputs_embeds is None:
701
+ inputs_embeds = self.embed_tokens(
702
+ input_ids.to(self.embed_tokens.weight.device))
703
+
704
+ if use_cache and past_key_values is None:
705
+ past_key_values = DynamicCache()
706
+
707
+ if cache_position is None:
708
+ past_seen_tokens = past_key_values.get_seq_length(
709
+ ) if past_key_values is not None else 0
710
+ cache_position = torch.arange(past_seen_tokens,
711
+ past_seen_tokens +
712
+ inputs_embeds.shape[1],
713
+ device=inputs_embeds.device)
714
+
715
+ if position_ids is None:
716
+ position_ids = cache_position.unsqueeze(0)
717
+
718
+ hidden_states = inputs_embeds
719
+
720
+ # It may already have been prepared by e.g. `generate`
721
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
722
+ # Prepare mask arguments
723
+ mask_kwargs = {
724
+ "config": self.config,
725
+ "input_embeds": inputs_embeds,
726
+ "attention_mask": attention_mask,
727
+ "cache_position": cache_position,
728
+ "past_key_values": past_key_values,
729
+ "position_ids": position_ids,
730
+ }
731
+ # Create the masks
732
+ causal_mask_mapping = {
733
+ "full_attention": create_causal_mask(**mask_kwargs),
734
+ }
735
+
736
+ # The sliding window alternating layers are not always activated depending on the config
737
+ if self.has_sliding_layers:
738
+ causal_mask_mapping[
739
+ "sliding_attention"] = create_sliding_window_causal_mask(
740
+ **mask_kwargs)
741
+
742
+ # # create position embeddings to be shared across the decoder layers
743
+ # decoder layers
744
+ all_hidden_states = () if output_hidden_states else None
745
+ all_self_attns = () if output_attentions else None
746
+ for decoder_layer in self.layers[:self.config.num_hidden_layers]:
747
+ if output_hidden_states:
748
+ all_hidden_states += (hidden_states, )
749
+
750
+ layer_outputs = decoder_layer(
751
+ hidden_states,
752
+ attention_mask=causal_mask_mapping[
753
+ decoder_layer.attention_type],
754
+ position_ids=position_ids,
755
+ past_key_value=past_key_values,
756
+ output_attentions=output_attentions,
757
+ use_cache=use_cache,
758
+ cache_position=cache_position,
759
+ **kwargs,
760
+ )
761
+
762
+ hidden_states = layer_outputs
763
+
764
+ hidden_states = self.norm(hidden_states)
765
+
766
+ return BaseModelOutputWithPast(
767
+ last_hidden_state=hidden_states,
768
+ past_key_values=past_key_values if use_cache else None,
769
+ hidden_states=all_hidden_states,
770
+ attentions=all_self_attns,
771
+ )
772
+
773
+
774
+ class Step3p5ForCausalLM(Step3p5PreTrainedModel, GenerationMixin):
775
+ _tied_weights_keys = ["lm_head.weight"]
776
+ config: Step3p5Config
777
+
778
+ def __init__(self, config: Step3p5Config):
779
+ super().__init__(config)
780
+ self.model = Step3p5Model(config)
781
+ self.lm_head = nn.Linear(config.hidden_size,
782
+ config.vocab_size,
783
+ bias=False)
784
+
785
+ self.post_init()
786
+
787
+ def get_input_embeddings(self):
788
+ return self.model.get_input_embeddings()
789
+
790
+ def set_input_embeddings(self, value):
791
+ self.model.set_input_embeddings(value)
792
+
793
+ def get_output_embeddings(self):
794
+ return self.model.get_output_embeddings()
795
+
796
+ def set_output_embeddings(self, new_embeddings):
797
+ self.model.set_output_embeddings(new_embeddings)
798
+
799
+ def set_decoder(self, decoder):
800
+ self.model.set_decoder(decoder)
801
+
802
+ def get_decoder(self):
803
+ return self.model.get_decoder()
804
+
805
+ def forward(
806
+ self,
807
+ input_ids: torch.LongTensor = None,
808
+ num_patches=None,
809
+ patch_pixel_values=None,
810
+ patch_newline_mask=None,
811
+ attention_mask: Optional[torch.Tensor] = None,
812
+ position_ids: Optional[torch.LongTensor] = None,
813
+ past_key_values: Optional[Cache] = None,
814
+ inputs_embeds: Optional[torch.FloatTensor] = None,
815
+ labels: Optional[torch.LongTensor] = None,
816
+ use_cache: Optional[bool] = None,
817
+ output_attentions: Optional[bool] = None,
818
+ output_hidden_states: Optional[bool] = None,
819
+ return_dict: Optional[bool] = None,
820
+ cache_position: Optional[torch.LongTensor] = None,
821
+ **kwargs: Unpack[TransformersKwargs],
822
+ ) -> Union[tuple, Step3p5CausalLMOutputWithPast]:
823
+ r"""
824
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
825
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
826
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
827
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
828
+ Example:
829
+ ```python
830
+ >>> from transformers import AutoTokenizer, Llama4ForCausalLM
831
+ >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
832
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
833
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
834
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
835
+ >>> # Generate
836
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
837
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
838
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
839
+ ```"""
840
+
841
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
842
+ output_hidden_states = (output_hidden_states
843
+ if output_hidden_states is not None else
844
+ self.config.output_hidden_states)
845
+ # breakpoint()
846
+ outputs = self.model(
847
+ input_ids=input_ids,
848
+ num_patches=num_patches,
849
+ patch_pixel_values=patch_pixel_values,
850
+ patch_newline_mask=patch_newline_mask,
851
+ position_ids=position_ids,
852
+ attention_mask=attention_mask,
853
+ past_key_values=past_key_values,
854
+ inputs_embeds=inputs_embeds,
855
+ use_cache=use_cache,
856
+ output_attentions=output_attentions,
857
+ output_hidden_states=output_hidden_states,
858
+ return_dict=return_dict,
859
+ cache_position=cache_position,
860
+ **kwargs,
861
+ )
862
+ hidden_states = outputs.last_hidden_state
863
+ logits = self.lm_head(hidden_states)
864
+
865
+ return Step3p5CausalLMOutputWithPast(logits=logits, )
866
+
867
+ def prepare_inputs_for_generation(
868
+ self,
869
+ input_ids,
870
+ past_key_values=None,
871
+ inputs_embeds=None,
872
+ pixel_values=None,
873
+ attention_mask=None,
874
+ cache_position=None,
875
+ logits_to_keep=None,
876
+ **kwargs,
877
+ ):
878
+
879
+ model_inputs = super().prepare_inputs_for_generation(
880
+ input_ids,
881
+ past_key_values=past_key_values,
882
+ inputs_embeds=inputs_embeds,
883
+ attention_mask=attention_mask,
884
+ cache_position=cache_position,
885
+ logits_to_keep=logits_to_keep,
886
+ **kwargs,
887
+ )
888
+
889
+ if cache_position[0] == 0:
890
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
891
+ # Otherwise we need pixel values to be passed to model
892
+ model_inputs["pixel_values"] = pixel_values
893
+
894
+ return model_inputs
895
+
896
+ def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
897
+ if key.startswith("language_model."):
898
+ return key[len("language_model."):], True
899
+
900
+ return key, False
special_tokens_map.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|begin▁of▁sentence|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|im_end|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|im_end|>"
17
+ }
step3p5_quantize_quark.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Copyright (C) 2023 - 2026 Advanced Micro Devices, Inc. All rights reserved.
4
+ # SPDX-License-Identifier: MIT
5
+ #
6
+ # Quantization script for Step-3.5-Flash with MoE layer replacement
7
+
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import os
12
+ import re
13
+ import shutil
14
+ from pathlib import Path
15
+ from types import MethodType
16
+ from typing import Any
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from quark.torch import LLMTemplate, ModelQuantizer, export_safetensors
22
+ from quark.torch.utils.llm import (
23
+ get_calib_dataloader,
24
+ get_model,
25
+ get_tokenizer,
26
+ )
27
+ from quark.common.utils.log import ScreenLogger
28
+
29
+ try:
30
+ # Needed only when the model is loaded with accelerate offload (meta tensors).
31
+ from accelerate.hooks import AlignDevicesHook, add_hook_to_module # type: ignore
32
+ from accelerate.utils import PrefixedDataset # type: ignore
33
+
34
+ _ACCELERATE_AVAILABLE = True
35
+ except Exception:
36
+ AlignDevicesHook = None # type: ignore[assignment]
37
+ add_hook_to_module = None # type: ignore[assignment]
38
+ PrefixedDataset = None # type: ignore[assignment]
39
+ _ACCELERATE_AVAILABLE = False
40
+
41
+
42
+ DEFAULT_INPUT_MODEL_PATH = "stepfun-ai/Step-3.5-Flash"
43
+ DEFAULT_OUTPUT_MODEL_PATH = "quantized_models/Step-3.5-Flash-MXFP4"
44
+
45
+ logger = ScreenLogger(__name__)
46
+
47
+ def _step35_template_exclude_layers() -> list[str]:
48
+ return [
49
+ # embeddings / lm head / norms
50
+ "model.embed_tokens*",
51
+ "*embed_tokens*",
52
+ "*lm_head*",
53
+ "*layernorm*",
54
+ "*norm*",
55
+ # Router gate
56
+ "*moe.gate",
57
+ "*moe.router_bias*",
58
+ # The first three blocks use dense FFNs
59
+ "model.layers.0.mlp.*",
60
+ "model.layers.1.mlp.*",
61
+ "model.layers.2.mlp.*",
62
+ # Shared Experts
63
+ "*share_expert*",
64
+ "*self_attn*",
65
+ ]
66
+
67
+ PRESETS: dict[str, dict[str, Any]] = {
68
+
69
+ "mxfp4_moe_only_no_kvcache": {
70
+ "quant_scheme": "mxfp4",
71
+ "exclude_layers": _step35_template_exclude_layers(),
72
+ },
73
+ }
74
+
75
+
76
+ def _copy_non_weight_files(src_dir: str, dst_dir: str) -> None:
77
+ """
78
+ Copy non-weight files from an HF model directory (json/jinja/tokenizer, etc.),
79
+ while skipping *.safetensors and model.safetensors.index.json.
80
+
81
+ Note: `export_safetensors` exports the essential HF weights and config, but the
82
+ original model directory may contain extra assets (e.g. chat_template.jinja).
83
+ We do a conservative copy here so offline inference keeps those auxiliary files.
84
+ """
85
+ src = Path(src_dir)
86
+ dst = Path(dst_dir)
87
+ dst.mkdir(parents=True, exist_ok=True)
88
+
89
+ for p in src.iterdir():
90
+ if p.is_dir():
91
+ continue
92
+ name = p.name
93
+ if name.endswith(".safetensors"):
94
+ continue
95
+ if name == "model.safetensors.index.json":
96
+ continue
97
+ # Export will (re-)write config / generation_config; copying them here is harmless
98
+ # (later writes will overwrite).
99
+ shutil.copy2(p, dst / name)
100
+
101
+
102
+ def _register_step35_flash_template() -> None:
103
+ """
104
+ Register a Quark LLMTemplate for Step-3.5-Flash (config.model_type = step3p5).
105
+ """
106
+ model_type = "step3p5"
107
+ if model_type in LLMTemplate.list_available():
108
+ return
109
+
110
+
111
+ step35_flash_template = LLMTemplate(
112
+ model_type=model_type,
113
+ kv_layers_name=["*k_proj", "*v_proj"],
114
+ q_layer_name="*q_proj",
115
+ exclude_layers_name=_step35_template_exclude_layers(),
116
+ )
117
+ LLMTemplate.register_template(step35_flash_template)
118
+ logger.info("Registered LLMTemplate: %s", model_type)
119
+
120
+
121
+ @torch.no_grad()
122
+ def replace_step35_moelinear_with_linear(moe_module: Any) -> None:
123
+ """
124
+ Convert Step3p5MoEMLP's MoELinear modules into separate Linear layers per expert.
125
+ """
126
+ if getattr(moe_module, "_step35_replaced", False):
127
+ return
128
+
129
+ logger.debug("Converting Step3p5MoEMLP experts to separate gate/up/down Linear layers...")
130
+
131
+ # Get dimensions from the module
132
+ num_experts: int = int(getattr(moe_module, "moe_num_experts", 288))
133
+ hidden_size: int = int(getattr(moe_module, "hidden_size", 4096))
134
+ moe_intermediate_size: int = int(getattr(moe_module, "moe_intermediate_size", 1280))
135
+
136
+ # Store original device and dtype from one of the MoELinear modules
137
+ original_device = moe_module.gate_proj.weight.device
138
+ original_dtype = moe_module.gate_proj.weight.dtype
139
+ # [num_experts, in, out]
140
+ # Expose common attribute names for the forward helper
141
+ moe_module.hidden_size = hidden_size
142
+ moe_module.expert_dim = moe_intermediate_size
143
+ moe_module.num_experts = num_experts
144
+
145
+ is_meta: bool = original_device == torch.device("meta")
146
+ target_device_for_new = original_device if not is_meta else torch.device("meta")
147
+
148
+ # Create individual expert modules, each containing gate_proj, up_proj, down_proj
149
+ for expert_index in range(num_experts):
150
+ expert_module = nn.Module()
151
+ expert_module.gate_proj = nn.Linear(
152
+ hidden_size, moe_intermediate_size, bias=False, device=target_device_for_new, dtype=original_dtype
153
+ )
154
+ expert_module.up_proj = nn.Linear(
155
+ hidden_size, moe_intermediate_size, bias=False, device=target_device_for_new, dtype=original_dtype
156
+ )
157
+ expert_module.down_proj = nn.Linear(
158
+ moe_intermediate_size, hidden_size, bias=False, device=target_device_for_new, dtype=original_dtype
159
+ )
160
+ setattr(moe_module, str(expert_index), expert_module)
161
+
162
+
163
+ # Sync weights from MoELinear to individual Linear modules
164
+ weights_synced = _step35_sync_weights_to_linear(moe_module)
165
+
166
+ # Replace forward method
167
+ moe_module.forward = MethodType(_step35_moe_forward, moe_module)
168
+
169
+ if weights_synced:
170
+ _step35_cleanup_fused(moe_module)
171
+
172
+ moe_module._step35_replaced = True
173
+
174
+
175
+ @torch.no_grad()
176
+ def _step35_sync_weights_to_linear(module: Any) -> bool:
177
+ """
178
+ Split MoELinear weights and copy into per-expert Linear layers.
179
+ Returns True if synced; returns False if fused weights are still on 'meta' (not materialized).
180
+ MoELinear tensors in Step3p5MoEMLP are expected to be:
181
+ - gate_proj.weight: [num_experts, moe_intermediate_size, hidden_size]
182
+ - up_proj.weight: [num_experts, moe_intermediate_size, hidden_size]
183
+ - down_proj.weight: [num_experts, hidden_size, moe_intermediate_size]
184
+ """
185
+ if getattr(module, "_weights_synced", False):
186
+ return True
187
+
188
+ W_gate = getattr(module, "gate_proj", None)
189
+ W_up = getattr(module, "up_proj", None)
190
+ W_down = getattr(module, "down_proj", None)
191
+
192
+ if W_gate is None or W_up is None or W_down is None:
193
+ return False
194
+
195
+ is_offload = getattr(W_gate.weight, "is_meta", False) or W_gate.weight.device == torch.device("meta")
196
+ if is_offload:
197
+ # Loaded with accelerate offload: tensors live in module._hf_hook.weights_map on CPU.
198
+ if not _ACCELERATE_AVAILABLE:
199
+ raise RuntimeError(
200
+ "Model appears to be loaded with accelerate offload (meta tensors), but accelerate is not available."
201
+ )
202
+ if not hasattr(module, "_hf_hook"):
203
+ return False
204
+ W_gate = module._hf_hook.weights_map["gate_proj.weight"]
205
+ W_up = module._hf_hook.weights_map["up_proj.weight"]
206
+ W_down = module._hf_hook.weights_map["down_proj.weight"]
207
+
208
+ try:
209
+ for expert_index in range(int(module.num_experts)):
210
+ expert_module = getattr(module, str(expert_index))
211
+
212
+ W_gate_current = W_gate.weight[expert_index] # [moe_intermediate_size, hidden_size]
213
+ W_up_current = W_up.weight[expert_index] # [moe_intermediate_size, hidden_size]
214
+ W_down_current = W_down.weight[expert_index] # [hidden_size, moe_intermediate_size]
215
+
216
+ if is_offload:
217
+ hook = module._hf_hook
218
+ dataset = hook.weights_map.dataset
219
+ layer_value = [W_gate_current, W_up_current, W_down_current]
220
+ for idx, layer_name in enumerate(["gate_proj", "up_proj", "down_proj"]):
221
+ prefix = f"{hook.weights_map.prefix}{expert_index}.{layer_name}."
222
+ prefixed_weights_map = PrefixedDataset(dataset, prefix)
223
+ full_name = f"{prefix}weight"
224
+ dataset.all_keys.append(full_name)
225
+ dataset.state_dict[full_name] = layer_value[idx]
226
+
227
+ quark_hook = AlignDevicesHook(
228
+ execution_device=hook.execution_device,
229
+ offload=hook.offload,
230
+ io_same_device=hook.io_same_device,
231
+ weights_map=prefixed_weights_map,
232
+ offload_buffers=hook.offload_buffers,
233
+ place_submodules=hook.place_submodules,
234
+ skip_keys=hook.skip_keys,
235
+ tied_params_map=hook.tied_params_map,
236
+ )
237
+ linear_module = getattr(expert_module, layer_name)
238
+ add_hook_to_module(linear_module, quark_hook)
239
+ else:
240
+ # No transpose needed: nn.Linear expects [out_features, in_features], which matches MoELinear tensors.
241
+ expert_module.gate_proj.weight.data.copy_(W_gate_current.to(W_gate.weight.device))
242
+ expert_module.up_proj.weight.data.copy_(W_up_current.to(W_up.weight.device))
243
+ expert_module.down_proj.weight.data.copy_(W_down_current.to(W_down.weight.device))
244
+
245
+ if is_offload:
246
+ prefix = module._hf_hook.weights_map.prefix
247
+ del module._hf_hook.weights_map.dataset.state_dict[f"{prefix}gate_proj.weight"]
248
+ del module._hf_hook.weights_map.dataset.state_dict[f"{prefix}up_proj.weight"]
249
+ del module._hf_hook.weights_map.dataset.state_dict[f"{prefix}down_proj.weight"]
250
+ module._hf_hook.weights_map.dataset.all_keys.remove(f"{prefix}gate_proj.weight")
251
+ module._hf_hook.weights_map.dataset.all_keys.remove(f"{prefix}up_proj.weight")
252
+ module._hf_hook.weights_map.dataset.all_keys.remove(f"{prefix}down_proj.weight")
253
+
254
+ module._weights_synced = True
255
+ return True
256
+ except Exception as e:
257
+ logger.warning("Failed to sync Step3.5 MoE weights: %s", e)
258
+ return False
259
+
260
+
261
+
262
+ @torch.no_grad()
263
+ def _step35_cleanup_fused(module: Any) -> None:
264
+ """Optionally remove fused MoELinear modules after replacement."""
265
+ # The original MoELinear modules should be garbage collected
266
+ # when they're replaced, but we can explicitly clear references
267
+ for proj_name in ["gate_proj", "up_proj", "down_proj"]:
268
+ # Clear any remaining references to original MoELinear
269
+ if hasattr(module, proj_name):
270
+ delattr(module, proj_name)
271
+
272
+ torch.cuda.empty_cache()
273
+ logger.debug(f"Cleaned up original MoELinear modules")
274
+
275
+
276
+ def _step35_moe_forward(self: Any, hidden_states: torch.Tensor) -> torch.Tensor:
277
+ """
278
+ Forward using per-expert gate_proj, up_proj, down_proj (nn.Linear),
279
+ matching the original Step3p5MoEMLP.forward semantics but without MoELinear.
280
+ """
281
+ synced = _step35_sync_weights_to_linear(self)
282
+ if not synced:
283
+ raise RuntimeError(
284
+ "Step3p5MoEMLP weights are on 'meta' (not materialized). "
285
+ "Move fused parameters to a real device first, then call forward."
286
+ )
287
+
288
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
289
+ hidden_states = hidden_states.view(-1, hidden_dim)
290
+
291
+ # Router/gating
292
+ if self.need_fp32_gate:
293
+ router_logits = torch.matmul(hidden_states.to(torch.float32), self.gate.weight.t().to(torch.float32))
294
+ else:
295
+ # router_logits: (batch * sequence_length, n_experts)
296
+ router_logits = self.gate(hidden_states)
297
+
298
+ # Custom routing or standard softmax + top-k
299
+ if hasattr(self, 'custom_routing_function') and self.custom_routing_function:
300
+ routing_weights, selected_experts = self.custom_routing_function(
301
+ router_logits, self.top_k, renormalize=True)
302
+ else:
303
+ routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
304
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
305
+
306
+ # Apply scaling factor
307
+ routing_weights = routing_weights * self.routed_scaling_factor
308
+
309
+ # Initialize output
310
+ final_hidden_states = torch.zeros(
311
+ (batch_size * sequence_length, hidden_dim),
312
+ dtype=hidden_states.dtype,
313
+ device=hidden_states.device)
314
+
315
+ # One hot encode the selected experts to create an expert mask
316
+ # this will be used to easily index which expert is going to be solicited
317
+ expert_mask = torch.nn.functional.one_hot(
318
+ selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
319
+
320
+ limit = getattr(self, 'limit', None)
321
+
322
+ # Loop over all available experts in the model and perform the computation on each expert
323
+ for expert_idx in range(self.num_experts):
324
+ idx, top_x = torch.where(expert_mask[expert_idx])
325
+
326
+ # Index the correct hidden states and compute the expert hidden state for
327
+ # the current expert. We need to make sure to multiply the output hidden
328
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
329
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
330
+
331
+ expert_module = getattr(self, str(expert_idx))
332
+
333
+ up = expert_module.up_proj(current_state)
334
+ gate = self.act_fn(expert_module.gate_proj(current_state))
335
+
336
+ if limit is not None:
337
+ gate = gate.clamp(min=None, max=limit)
338
+ up = up.clamp(min=-limit, max=limit)
339
+
340
+ current_hidden_states = expert_module.down_proj(gate * up) * routing_weights[top_x, idx, None]
341
+
342
+ # However `index_add_` only support torch tensors for indexing so we'll use
343
+ # the `top_x` tensor here.
344
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
345
+
346
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
347
+ return final_hidden_states
348
+
349
+ @torch.no_grad()
350
+ def patch_step35_moe(model: nn.Module) -> int:
351
+ """
352
+ Apply Step-3.5-Flash MoE replacement to all Step3p5MoEMLP modules in the model.
353
+ """
354
+ patched = 0
355
+ for name, module in model.named_modules(remove_duplicate=False):
356
+ if module.__class__.__name__ == "Step3p5MoEMLP":
357
+ replace_step35_moelinear_with_linear(module)
358
+ patched += 1
359
+ logger.debug(f"Patched MoE module: {name}")
360
+
361
+ if patched > 0:
362
+ logger.info("Patched %d Step3p5MoEMLP module(s) for quantization.", patched)
363
+ return patched
364
+
365
+
366
+ def _resolve_calib_device(device: str, model: nn.Module) -> str:
367
+ """
368
+ Resolve a torch-compatible device string for calibration inputs.
369
+ """
370
+ if device != "auto":
371
+ return str(device)
372
+
373
+ hf_map = getattr(model, "hf_device_map", None)
374
+ if isinstance(hf_map, dict):
375
+ cuda_ids: list[int] = []
376
+ for v in hf_map.values():
377
+ m = re.match(r"^cuda:(\d+)$", str(v))
378
+ if m:
379
+ cuda_ids.append(int(m.group(1)))
380
+ if cuda_ids:
381
+ return f"cuda:{min(cuda_ids)}"
382
+
383
+ if torch.cuda.is_available():
384
+ return "cuda:0"
385
+ return "cpu"
386
+
387
+
388
+ def main(args: argparse.Namespace) -> None:
389
+ os.makedirs(args.output_quantized_hf_path, exist_ok=True)
390
+
391
+ _register_step35_flash_template()
392
+
393
+ if getattr(args, "preset", None):
394
+ preset_cfg = PRESETS[args.preset]
395
+ args.quant_scheme = preset_cfg["quant_scheme"]
396
+ if getattr(args, "quant_algo", None) is None and "quant_algo" in preset_cfg:
397
+ args.quant_algo = preset_cfg["quant_algo"]
398
+ logger.info("Using preset: %s", args.preset)
399
+
400
+ logger.info("Input model: %s", args.model_dir)
401
+ logger.info("Output dir: %s", args.output_quantized_hf_path)
402
+
403
+ logger.info("Step 1/4: Loading model and tokenizer ...")
404
+ model, _ = get_model(
405
+ args.model_dir,
406
+ data_type=args.data_type,
407
+ device=args.device,
408
+ multi_gpu=args.multi_gpu,
409
+ multi_device=args.multi_device,
410
+ attn_implementation=args.model_attn_implementation,
411
+ trust_remote_code=args.trust_remote_code,
412
+ )
413
+
414
+ patch_step35_moe(model)
415
+
416
+ model_type = model.config.model_type if hasattr(model.config, "model_type") else model.config.architectures[0]
417
+ tokenizer = get_tokenizer(
418
+ args.model_dir, max_seq_len=args.seq_len, model_type=model_type, trust_remote_code=args.trust_remote_code
419
+ )
420
+
421
+ logger.info("Step 2/4: Building calibration dataloader ...")
422
+ base_device = str(model.device) if (args.multi_gpu or args.multi_device) else str(args.device)
423
+ main_device = _resolve_calib_device(base_device, model)
424
+ logger.info("Calibration dataset: %s", args.dataset)
425
+ calib_dataloader = get_calib_dataloader(
426
+ dataset_name=args.dataset,
427
+ tokenizer=tokenizer,
428
+ batch_size=args.batch_size,
429
+ num_calib_data=args.num_calib_data,
430
+ seqlen=args.seq_len,
431
+ device=main_device,
432
+ )
433
+
434
+ logger.info("Step 3/4: Quantizing ...")
435
+ template = LLMTemplate.get(model_type)
436
+ if args.exclude_layers is not None:
437
+ logger.warning(
438
+ "Ignoring --exclude_layers (%s). This script always uses "
439
+ "_register_step35_flash_template excludes for Step-3.5-Flash.",
440
+ args.exclude_layers,
441
+ )
442
+ exclude_layers = _step35_template_exclude_layers()
443
+ logger.info("Exclude layers (template): %s", exclude_layers)
444
+ if getattr(args, "quant_algo", None):
445
+ logger.info("Quantization algorithm(s): %s", args.quant_algo)
446
+
447
+ quant_config = template.get_config(
448
+ scheme=args.quant_scheme,
449
+ algorithm=args.quant_algo,
450
+ kv_cache_scheme=None,
451
+ min_kv_scale=0.0,
452
+ layer_config={},
453
+ attention_scheme=None,
454
+ exclude_layers=exclude_layers,
455
+ algo_configs=None,
456
+ )
457
+
458
+ quantizer = ModelQuantizer(quant_config, args.multi_device)
459
+ model = quantizer.quantize_model(model, calib_dataloader)
460
+
461
+ model = quantizer.freeze(model)
462
+
463
+ logger.info("Step 4/4: Exporting HF safetensors ...")
464
+ _copy_non_weight_files(args.model_dir, args.output_quantized_hf_path)
465
+ with torch.no_grad():
466
+ export_safetensors(
467
+ model=model,
468
+ output_dir=args.output_quantized_hf_path,
469
+ custom_mode="quark",
470
+ weight_format=args.export_weight_format,
471
+ pack_method=args.pack_method,
472
+ )
473
+ tokenizer.save_pretrained(args.output_quantized_hf_path)
474
+
475
+ logger.info("Export completed.")
476
+ logger.info("========== Quantization Completed Successfully ==========")
477
+
478
+
479
+ if __name__ == "__main__":
480
+ parser = argparse.ArgumentParser(
481
+ description="Offline quantization for Step-3.5-Flash with MoE layer replacement"
482
+ )
483
+ parser.add_argument("--model_dir", dest="model_dir", type=str, default=DEFAULT_INPUT_MODEL_PATH)
484
+ parser.add_argument("--output_dir", dest="output_quantized_hf_path", type=str, default=DEFAULT_OUTPUT_MODEL_PATH)
485
+
486
+ # Model loading
487
+ parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"])
488
+ parser.add_argument("--multi_gpu", dest="multi_gpu", action="store_true")
489
+ parser.add_argument("--multi_device", dest="multi_device", action="store_true")
490
+ parser.add_argument(
491
+ "--model_attn_implementation",
492
+ dest="model_attn_implementation",
493
+ type=str,
494
+ default="eager",
495
+ choices=["eager", "sdpa", "flash_attention_2"],
496
+ )
497
+ parser.add_argument(
498
+ "--data_type",
499
+ dest="data_type",
500
+ type=str,
501
+ default="auto",
502
+ choices=["auto", "float16", "bfloat16", "float32"],
503
+ )
504
+
505
+ # Calibration
506
+ parser.add_argument(
507
+ "--dataset",
508
+ dest="dataset",
509
+ type=str,
510
+ default="pileval",
511
+ help="Calibration dataset name. Default is 'pileval'.",
512
+ )
513
+ parser.add_argument("--seq_len", dest="seq_len", type=int, default=512)
514
+ parser.add_argument("--batch_size", dest="batch_size", type=int, default=1)
515
+ parser.add_argument("--num_calib_data", dest="num_calib_data", type=int, default=128)
516
+
517
+ # Quantization
518
+ parser.add_argument(
519
+ "--preset",
520
+ dest="preset",
521
+ type=str,
522
+ choices=sorted(PRESETS.keys()),
523
+ default="mxfp4_moe_only_no_kvcache",
524
+ help="Convenience preset for quantization settings.",
525
+ )
526
+ parser.add_argument(
527
+ "--quant_algo",
528
+ dest="quant_algo",
529
+ type=str,
530
+ default=None,
531
+ help="Optional quantization algorithm(s) to apply.",
532
+ )
533
+ parser.add_argument(
534
+ "--exclude_layers",
535
+ type=str,
536
+ nargs="*",
537
+ default=None,
538
+ help="Layer wildcard patterns to exclude from quantization.",
539
+ )
540
+
541
+ # Export
542
+ parser.add_argument("--pack_method", dest="pack_method", type=str, default="reorder", choices=["order", "reorder"])
543
+ parser.add_argument(
544
+ "--export_weight_format",
545
+ dest="export_weight_format",
546
+ type=str,
547
+ default="real_quantized",
548
+ choices=["fake_quantized", "real_quantized"],
549
+ )
550
+ group = parser.add_mutually_exclusive_group()
551
+ group.add_argument(
552
+ "--trust_remote_code",
553
+ action="store_true",
554
+ dest="trust_remote_code",
555
+ help="Enable execution of custom model code from the Hub (use only with repositories you fully trust).",
556
+ )
557
+ group.add_argument(
558
+ "--no_trust_remote_code",
559
+ action="store_false",
560
+ dest="trust_remote_code",
561
+ help="Disable execution of custom model code from the Hub (safer, recommended if unsure).",
562
+ )
563
+ parser.set_defaults(trust_remote_code=True)
564
+
565
+ main(parser.parse_args())
566
+
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff