OpenOneRec committed
Commit a4e273f · verified · 1 Parent(s): 928e824

Upload folder using huggingface_hub
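
For reference, a commit like this one is typically produced with the `huggingface_hub` upload API. A minimal sketch (the target `repo_id` and local folder path below are placeholders, not taken from this commit):

```python
from huggingface_hub import HfApi

api = HfApi()  # assumes `huggingface-cli login` has already been run
api.upload_folder(
    folder_path="examples/pretrain/hf_template",  # local folder to push
    repo_id="OpenOneRec/<repo-name>",             # placeholder target repo
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)
```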

README.md CHANGED
@@ -1,3 +1,28 @@
- ---
- license: apache-2.0
- ---
+ # HF Template
+ 
+ Populate this folder on the training machine with a working HF model snapshot
+ (Qwen3 + Summary Attention variant) **before** running
+ `examples/pretrain/convert/convert_muse_to_hf.sh`.
+ 
+ ## Expected contents
+ 
+ | File | Purpose |
+ |---|---|
+ | `config.json` | HF config with `summary_*` fields matching your trained model |
+ | `generation_config.json` | Default generation settings |
+ | `tokenizer.json` / `tokenizer_config.json` / `special_tokens_map.json` | Tokenizer |
+ | `vocab.json` / `merges.txt` | Tokenizer vocab (if applicable) |
+ | `modeling_qwen3*.py` | HF-compatible modeling code with SA support |
+ | `summary_context.py` | Helper module imported by the modeling code |
+ 
+ Only the **weights** come from the Muse DCP — everything else above is copied
+ verbatim into `<OUTPUT_DIR>/<STEP>/hf/` by the convert script.
+ 
+ ## Usage
+ 
+ ```bash
+ bash examples/pretrain/convert/convert_muse_to_hf.sh \
+     /path/to/muse_outputs/1b6_sa_hybrid_8k \
+     global_step5000 \
+     examples/pretrain/hf_template
+ ```
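
After conversion, the populated `<OUTPUT_DIR>/<STEP>/hf/` folder should load like any other trust-remote-code checkpoint, since `config.json` maps the architectures to the bundled `modeling_qwen3sa` module. A minimal loading sketch (the checkpoint path and prompt are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "/path/to/muse_outputs/1b6_sa_hybrid_8k/global_step5000/hf"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,  # matches "torch_dtype" in config.json
    trust_remote_code=True,      # required so auto_map -> modeling_qwen3sa.* is honoured
)

inputs = tokenizer("Hello, world", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```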
config.json ADDED
@@ -0,0 +1,46 @@
+ {
+   "architectures": [
+     "Qwen3ForCausalLM"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "bos_token_id": 151643,
+   "eos_token_id": 151643,
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "hidden_size": 2048,
+   "initializer_range": 0.02,
+   "intermediate_size": 6144,
+   "max_position_embeddings": 131072,
+   "max_window_layers": 24,
+   "mix_coeff": 1.0,
+   "model_type": "qwen3",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "num_key_value_heads": 16,
+   "rms_norm_eps": 1e-06,
+   "rope_scaling": null,
+   "rope_theta": 10000,
+   "sliding_window": null,
+   "tie_word_embeddings": true,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.51.0",
+   "use_cache": true,
+   "use_sliding_window": false,
+   "vocab_size": 151936,
+   "summary_token_begin": 151936,
+   "summary_chunk_size": 8,
+   "summary_token_num": 1,
+   "use_summary_attention": true,
+   "summary_sliding_chunk_num": "([128]*3+[1024]*1)*6",
+   "summary_chunk_position_ids_type": "origin",
+   "summary_token_position_ids_type": "last_chunk_slice_right",
+   "summary_independent_parameters": true,
+   "summary_independent_attention_layernorm": false,
+   "summary_attention_mode": "kernel",
+   "auto_map": {
+     "AutoModel": "modeling_qwen3sa.Qwen3Model",
+     "AutoModelForCausalLM": "modeling_qwen3sa.Qwen3ForCausalLM"
+   }
+ }
+ 
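
Note that `summary_sliding_chunk_num` is stored as a Python list expression rather than a literal JSON list; the bundled modeling code expands it with a small helper (`_parse_config_pattern`, visible in `modeling_qwen3.py` below) that evaluates the string. A quick sanity check of the value above, assuming that behaviour:

```python
# Expand the per-layer sliding-chunk pattern from config.json.
pattern = "([128]*3+[1024]*1)*6"
sliding_chunk_nums = eval(pattern)  # mirrors _parse_config_pattern in the modeling code

# 24 entries, one per hidden layer: three small-window layers (128 chunks)
# followed by one large-window layer (1024 chunks), repeated six times.
assert len(sliding_chunk_nums) == 24

summary_chunk_size = 8
window_sizes = [n * summary_chunk_size for n in sliding_chunk_nums]
print(window_sizes[:4])  # [1024, 1024, 1024, 8192] tokens per attention window
```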
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "bos_token_id": 151643,
+   "do_sample": false,
+   "eos_token_id": 151643,
+   "max_new_tokens": 2048,
+   "transformers_version": "4.37.0"
+ }
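
These defaults are read automatically by `generate()` when the checkpoint is loaded; they can also be inspected or overridden per call. A small sketch (path is a placeholder):

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("/path/to/hf_checkpoint")
print(gen_cfg.do_sample, gen_cfg.max_new_tokens)  # False 2048

# Per-call arguments still take precedence over these file defaults, e.g.:
# model.generate(**inputs, max_new_tokens=256)
```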
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00000-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c091f250701f2f187b3ecb7d4842d14f54d2c2e1224ec0a7b118628b6a0f39bb
+ size 4311824296
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:69b402da8cb159db5ee832f2f7350ed28c43746f7baf7d5c7ad079ac4e18d745
+ size 3733162824
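
The two `.safetensors` entries above are Git LFS pointer files: the shard bytes live in LFS storage, and each pointer records the shard's SHA-256 and size. A downloaded shard can be checked against its pointer with plain `hashlib` (the file name below is the first shard of this commit):

```python
import hashlib

def sha256_of(path: str, chunk_bytes: int = 1 << 20) -> str:
    """Stream a large file through SHA-256 without holding it all in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_bytes), b""):
            h.update(chunk)
    return h.hexdigest()

expected = "c091f250701f2f187b3ecb7d4842d14f54d2c2e1224ec0a7b118628b6a0f39bb"
assert sha256_of("model-00000-of-00002.safetensors") == expected
```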
model.safetensors.index.json ADDED
@@ -0,0 +1,405 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 8044941312
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model-00000-of-00002.safetensors",
7
+ "model.layers.0.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
8
+ "model.layers.0.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
9
+ "model.layers.0.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
10
+ "model.layers.0.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
11
+ "model.layers.0.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
12
+ "model.layers.0.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
13
+ "model.layers.0.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
14
+ "model.layers.0.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
15
+ "model.layers.0.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
16
+ "model.layers.0.input_layernorm.weight": "model-00000-of-00002.safetensors",
17
+ "model.layers.0.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
18
+ "model.layers.1.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
19
+ "model.layers.1.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
20
+ "model.layers.1.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
21
+ "model.layers.1.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
22
+ "model.layers.1.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
23
+ "model.layers.1.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
24
+ "model.layers.1.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
25
+ "model.layers.1.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
26
+ "model.layers.1.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
27
+ "model.layers.1.input_layernorm.weight": "model-00000-of-00002.safetensors",
28
+ "model.layers.1.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
29
+ "model.layers.2.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
30
+ "model.layers.2.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
31
+ "model.layers.2.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
32
+ "model.layers.2.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
33
+ "model.layers.2.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
34
+ "model.layers.2.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
35
+ "model.layers.2.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
36
+ "model.layers.2.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
37
+ "model.layers.2.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
38
+ "model.layers.2.input_layernorm.weight": "model-00000-of-00002.safetensors",
39
+ "model.layers.2.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
40
+ "model.layers.3.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
41
+ "model.layers.3.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
42
+ "model.layers.3.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
43
+ "model.layers.3.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
44
+ "model.layers.3.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
45
+ "model.layers.3.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
46
+ "model.layers.3.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
47
+ "model.layers.3.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
48
+ "model.layers.3.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
49
+ "model.layers.3.input_layernorm.weight": "model-00000-of-00002.safetensors",
50
+ "model.layers.3.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
51
+ "model.layers.4.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
52
+ "model.layers.4.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
53
+ "model.layers.4.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
54
+ "model.layers.4.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
55
+ "model.layers.4.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
56
+ "model.layers.4.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
57
+ "model.layers.4.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
58
+ "model.layers.4.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
59
+ "model.layers.4.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
60
+ "model.layers.4.input_layernorm.weight": "model-00000-of-00002.safetensors",
61
+ "model.layers.4.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
62
+ "model.layers.5.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
63
+ "model.layers.5.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
64
+ "model.layers.5.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
65
+ "model.layers.5.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
66
+ "model.layers.5.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
67
+ "model.layers.5.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
68
+ "model.layers.5.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
69
+ "model.layers.5.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
70
+ "model.layers.5.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
71
+ "model.layers.5.input_layernorm.weight": "model-00000-of-00002.safetensors",
72
+ "model.layers.5.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
73
+ "model.layers.6.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
74
+ "model.layers.6.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
75
+ "model.layers.6.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
76
+ "model.layers.6.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
77
+ "model.layers.6.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
78
+ "model.layers.6.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
79
+ "model.layers.6.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
80
+ "model.layers.6.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
81
+ "model.layers.6.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
82
+ "model.layers.6.input_layernorm.weight": "model-00000-of-00002.safetensors",
83
+ "model.layers.6.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
84
+ "model.layers.7.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
85
+ "model.layers.7.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
86
+ "model.layers.7.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
87
+ "model.layers.7.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
88
+ "model.layers.7.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
89
+ "model.layers.7.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
90
+ "model.layers.7.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
91
+ "model.layers.7.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
92
+ "model.layers.7.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
93
+ "model.layers.7.input_layernorm.weight": "model-00000-of-00002.safetensors",
94
+ "model.layers.7.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
95
+ "model.layers.8.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
96
+ "model.layers.8.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
97
+ "model.layers.8.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
98
+ "model.layers.8.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
99
+ "model.layers.8.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
100
+ "model.layers.8.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
101
+ "model.layers.8.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
102
+ "model.layers.8.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
103
+ "model.layers.8.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
104
+ "model.layers.8.input_layernorm.weight": "model-00000-of-00002.safetensors",
105
+ "model.layers.8.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
106
+ "model.layers.9.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
107
+ "model.layers.9.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
108
+ "model.layers.9.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
109
+ "model.layers.9.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
110
+ "model.layers.9.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
111
+ "model.layers.9.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
112
+ "model.layers.9.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
113
+ "model.layers.9.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
114
+ "model.layers.9.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
115
+ "model.layers.9.input_layernorm.weight": "model-00000-of-00002.safetensors",
116
+ "model.layers.9.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
117
+ "model.layers.10.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
118
+ "model.layers.10.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
119
+ "model.layers.10.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
120
+ "model.layers.10.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
121
+ "model.layers.10.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
122
+ "model.layers.10.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
123
+ "model.layers.10.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
124
+ "model.layers.10.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
125
+ "model.layers.10.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
126
+ "model.layers.10.input_layernorm.weight": "model-00000-of-00002.safetensors",
127
+ "model.layers.10.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
128
+ "model.layers.11.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
129
+ "model.layers.11.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
130
+ "model.layers.11.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
131
+ "model.layers.11.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
132
+ "model.layers.11.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
133
+ "model.layers.11.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
134
+ "model.layers.11.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
135
+ "model.layers.11.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
136
+ "model.layers.11.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
137
+ "model.layers.11.input_layernorm.weight": "model-00000-of-00002.safetensors",
138
+ "model.layers.11.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
139
+ "model.layers.12.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
140
+ "model.layers.12.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
141
+ "model.layers.12.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
142
+ "model.layers.12.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
143
+ "model.layers.12.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
144
+ "model.layers.12.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
145
+ "model.layers.12.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
146
+ "model.layers.12.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
147
+ "model.layers.12.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
148
+ "model.layers.12.input_layernorm.weight": "model-00000-of-00002.safetensors",
149
+ "model.layers.12.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
150
+ "model.layers.13.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
151
+ "model.layers.13.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
152
+ "model.layers.13.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
153
+ "model.layers.13.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
154
+ "model.layers.13.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
155
+ "model.layers.13.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
156
+ "model.layers.13.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
157
+ "model.layers.13.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
158
+ "model.layers.13.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
159
+ "model.layers.13.input_layernorm.weight": "model-00000-of-00002.safetensors",
160
+ "model.layers.13.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
161
+ "model.layers.14.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
162
+ "model.layers.14.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
163
+ "model.layers.14.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
164
+ "model.layers.14.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
165
+ "model.layers.14.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
166
+ "model.layers.14.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
167
+ "model.layers.14.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
168
+ "model.layers.14.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
169
+ "model.layers.14.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
170
+ "model.layers.14.input_layernorm.weight": "model-00000-of-00002.safetensors",
171
+ "model.layers.14.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
172
+ "model.layers.15.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
173
+ "model.layers.15.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
174
+ "model.layers.15.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
175
+ "model.layers.15.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
176
+ "model.layers.15.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
177
+ "model.layers.15.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
178
+ "model.layers.15.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
179
+ "model.layers.15.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
180
+ "model.layers.15.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
181
+ "model.layers.15.input_layernorm.weight": "model-00000-of-00002.safetensors",
182
+ "model.layers.15.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
183
+ "model.layers.16.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
184
+ "model.layers.16.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
185
+ "model.layers.16.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
186
+ "model.layers.16.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
187
+ "model.layers.16.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
188
+ "model.layers.16.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
189
+ "model.layers.16.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
190
+ "model.layers.16.mlp.down_proj.weight": "model-00000-of-00002.safetensors",
191
+ "model.layers.16.mlp.up_proj.weight": "model-00000-of-00002.safetensors",
192
+ "model.layers.16.input_layernorm.weight": "model-00000-of-00002.safetensors",
193
+ "model.layers.16.post_attention_layernorm.weight": "model-00000-of-00002.safetensors",
194
+ "model.layers.17.self_attn.q_proj.weight": "model-00000-of-00002.safetensors",
195
+ "model.layers.17.self_attn.k_proj.weight": "model-00000-of-00002.safetensors",
196
+ "model.layers.17.self_attn.v_proj.weight": "model-00000-of-00002.safetensors",
197
+ "model.layers.17.self_attn.o_proj.weight": "model-00000-of-00002.safetensors",
198
+ "model.layers.17.self_attn.q_norm.weight": "model-00000-of-00002.safetensors",
199
+ "model.layers.17.self_attn.k_norm.weight": "model-00000-of-00002.safetensors",
200
+ "model.layers.17.mlp.gate_proj.weight": "model-00000-of-00002.safetensors",
201
+ "model.layers.17.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
202
+ "model.layers.17.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
203
+ "model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
204
+ "model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
205
+ "model.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
206
+ "model.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
207
+ "model.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
208
+ "model.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
209
+ "model.layers.18.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
210
+ "model.layers.18.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
211
+ "model.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
212
+ "model.layers.18.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
213
+ "model.layers.18.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
214
+ "model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
215
+ "model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
216
+ "model.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
217
+ "model.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
218
+ "model.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
219
+ "model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
220
+ "model.layers.19.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
221
+ "model.layers.19.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
222
+ "model.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
223
+ "model.layers.19.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
224
+ "model.layers.19.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
225
+ "model.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors",
226
+ "model.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
227
+ "model.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
228
+ "model.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
229
+ "model.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
230
+ "model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
231
+ "model.layers.20.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
232
+ "model.layers.20.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
233
+ "model.layers.20.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
234
+ "model.layers.20.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
235
+ "model.layers.20.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
236
+ "model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors",
237
+ "model.layers.20.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
238
+ "model.layers.21.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
239
+ "model.layers.21.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
240
+ "model.layers.21.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
241
+ "model.layers.21.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
242
+ "model.layers.21.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
243
+ "model.layers.21.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
244
+ "model.layers.21.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
245
+ "model.layers.21.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
246
+ "model.layers.21.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
247
+ "model.layers.21.input_layernorm.weight": "model-00001-of-00002.safetensors",
248
+ "model.layers.21.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
249
+ "model.layers.22.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
250
+ "model.layers.22.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
251
+ "model.layers.22.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
252
+ "model.layers.22.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
253
+ "model.layers.22.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
254
+ "model.layers.22.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
255
+ "model.layers.22.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
256
+ "model.layers.22.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
257
+ "model.layers.22.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
258
+ "model.layers.22.input_layernorm.weight": "model-00001-of-00002.safetensors",
259
+ "model.layers.22.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
260
+ "model.layers.23.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
261
+ "model.layers.23.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
262
+ "model.layers.23.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
263
+ "model.layers.23.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
264
+ "model.layers.23.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
265
+ "model.layers.23.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
266
+ "model.layers.23.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
267
+ "model.layers.23.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
268
+ "model.layers.23.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
269
+ "model.layers.23.input_layernorm.weight": "model-00001-of-00002.safetensors",
270
+ "model.layers.23.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
271
+ "model.layers.24.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
272
+ "model.layers.24.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
273
+ "model.layers.24.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
274
+ "model.layers.24.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
275
+ "model.layers.24.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
276
+ "model.layers.24.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
277
+ "model.layers.24.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
278
+ "model.layers.24.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
279
+ "model.layers.24.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
280
+ "model.layers.24.input_layernorm.weight": "model-00001-of-00002.safetensors",
281
+ "model.layers.24.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
282
+ "model.layers.25.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
283
+ "model.layers.25.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
284
+ "model.layers.25.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
285
+ "model.layers.25.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
286
+ "model.layers.25.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
287
+ "model.layers.25.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
288
+ "model.layers.25.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
289
+ "model.layers.25.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
290
+ "model.layers.25.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
291
+ "model.layers.25.input_layernorm.weight": "model-00001-of-00002.safetensors",
292
+ "model.layers.25.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
293
+ "model.layers.26.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
294
+ "model.layers.26.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
295
+ "model.layers.26.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
296
+ "model.layers.26.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
297
+ "model.layers.26.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
298
+ "model.layers.26.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
299
+ "model.layers.26.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
300
+ "model.layers.26.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
301
+ "model.layers.26.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
302
+ "model.layers.26.input_layernorm.weight": "model-00001-of-00002.safetensors",
303
+ "model.layers.26.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
304
+ "model.layers.27.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
305
+ "model.layers.27.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
306
+ "model.layers.27.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
307
+ "model.layers.27.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
308
+ "model.layers.27.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
309
+ "model.layers.27.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
310
+ "model.layers.27.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
311
+ "model.layers.27.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
312
+ "model.layers.27.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
313
+ "model.layers.27.input_layernorm.weight": "model-00001-of-00002.safetensors",
314
+ "model.layers.27.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
315
+ "model.layers.28.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
316
+ "model.layers.28.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
317
+ "model.layers.28.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
318
+ "model.layers.28.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
319
+ "model.layers.28.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
320
+ "model.layers.28.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
321
+ "model.layers.28.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
322
+ "model.layers.28.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
323
+ "model.layers.28.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
324
+ "model.layers.28.input_layernorm.weight": "model-00001-of-00002.safetensors",
325
+ "model.layers.28.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
326
+ "model.layers.29.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
327
+ "model.layers.29.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
328
+ "model.layers.29.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
329
+ "model.layers.29.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
330
+ "model.layers.29.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
331
+ "model.layers.29.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
332
+ "model.layers.29.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
333
+ "model.layers.29.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
334
+ "model.layers.29.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
335
+ "model.layers.29.input_layernorm.weight": "model-00001-of-00002.safetensors",
336
+ "model.layers.29.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
337
+ "model.layers.30.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
338
+ "model.layers.30.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
339
+ "model.layers.30.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
340
+ "model.layers.30.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
341
+ "model.layers.30.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
342
+ "model.layers.30.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
343
+ "model.layers.30.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
344
+ "model.layers.30.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
345
+ "model.layers.30.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
346
+ "model.layers.30.input_layernorm.weight": "model-00001-of-00002.safetensors",
347
+ "model.layers.30.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
348
+ "model.layers.31.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
349
+ "model.layers.31.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
350
+ "model.layers.31.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
351
+ "model.layers.31.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
352
+ "model.layers.31.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
353
+ "model.layers.31.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
354
+ "model.layers.31.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
355
+ "model.layers.31.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
356
+ "model.layers.31.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
357
+ "model.layers.31.input_layernorm.weight": "model-00001-of-00002.safetensors",
358
+ "model.layers.31.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
359
+ "model.layers.32.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
360
+ "model.layers.32.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
361
+ "model.layers.32.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
362
+ "model.layers.32.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
363
+ "model.layers.32.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
364
+ "model.layers.32.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
365
+ "model.layers.32.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
366
+ "model.layers.32.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
367
+ "model.layers.32.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
368
+ "model.layers.32.input_layernorm.weight": "model-00001-of-00002.safetensors",
369
+ "model.layers.32.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
370
+ "model.layers.33.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
371
+ "model.layers.33.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
372
+ "model.layers.33.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
373
+ "model.layers.33.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
374
+ "model.layers.33.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
375
+ "model.layers.33.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
376
+ "model.layers.33.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
377
+ "model.layers.33.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
378
+ "model.layers.33.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
379
+ "model.layers.33.input_layernorm.weight": "model-00001-of-00002.safetensors",
380
+ "model.layers.33.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
381
+ "model.layers.34.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
382
+ "model.layers.34.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
383
+ "model.layers.34.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
384
+ "model.layers.34.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
385
+ "model.layers.34.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
386
+ "model.layers.34.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
387
+ "model.layers.34.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
388
+ "model.layers.34.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
389
+ "model.layers.34.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
390
+ "model.layers.34.input_layernorm.weight": "model-00001-of-00002.safetensors",
391
+ "model.layers.34.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
392
+ "model.layers.35.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
393
+ "model.layers.35.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
394
+ "model.layers.35.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
395
+ "model.layers.35.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
396
+ "model.layers.35.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
397
+ "model.layers.35.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
398
+ "model.layers.35.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
399
+ "model.layers.35.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
400
+ "model.layers.35.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
401
+ "model.layers.35.input_layernorm.weight": "model-00001-of-00002.safetensors",
402
+ "model.layers.35.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
403
+ "model.norm.weight": "model-00001-of-00002.safetensors"
404
+ }
405
+ }
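
The `weight_map` above is what lets `from_pretrained` route each parameter name to one of the two shards; it can also be used directly to pull a single tensor without materialising the whole model. A hedged sketch using the `safetensors` library (the snapshot path is a placeholder for a local download of this repo):

```python
import json
from safetensors import safe_open

repo_dir = "/path/to/local/snapshot"  # placeholder
with open(f"{repo_dir}/model.safetensors.index.json") as f:
    index = json.load(f)

name = "model.layers.17.mlp.down_proj.weight"  # per the index, lives in shard 00001
shard = index["weight_map"][name]
with safe_open(f"{repo_dir}/{shard}", framework="pt", device="cpu") as st:
    tensor = st.get_tensor(name)
print(shard, tuple(tensor.shape))
```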
modeling_qwen3.py ADDED
@@ -0,0 +1,1577 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_qwen3.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from typing import Any, Callable, Optional, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+ import torch.nn.functional as F
27
+ #from flash_attn import flash_attn_func
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.generation import GenerationMixin
32
+ from transformers.integrations import use_kernel_forward_from_hub
33
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
34
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
+ from transformers.modeling_layers import (
36
+ GenericForQuestionAnswering,
37
+ GenericForSequenceClassification,
38
+ GenericForTokenClassification,
39
+ GradientCheckpointingLayer,
40
+ )
41
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
42
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
43
+ from transformers.modeling_utils import PreTrainedModel
44
+ from transformers.processing_utils import Unpack
45
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
46
+ from transformers.utils.deprecation import deprecate_kwarg
47
+ from transformers.utils.generic import check_model_inputs
48
+ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
49
+ # InfinityLM imports for summary attention
50
+ from .summary_context import SummaryBatchContext, build_summary_context, build_summary_sliding_context
51
+ from summary_attn import summary_attn_func
52
+
53
+
54
+
55
+ def _parse_config_pattern(val):
56
+ """Parse a config value that may be an int, list, or Python pattern string like '([4096]*1+[128]*3)*9'."""
57
+ if isinstance(val, list):
58
+ return val
59
+ if isinstance(val, str):
60
+ return eval(val)
61
+ return val
62
+
63
+
64
+ @use_kernel_forward_from_hub("RMSNorm")
65
+ class Qwen3RMSNorm(nn.Module):
66
+ def __init__(self, hidden_size, eps: float = 1e-6) -> None:
67
+ """
68
+ Qwen3RMSNorm is equivalent to T5LayerNorm
69
+ """
70
+ super().__init__()
71
+ self.weight = nn.Parameter(torch.ones(hidden_size))
72
+ self.variance_epsilon = eps
73
+
74
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
75
+ input_dtype = hidden_states.dtype
76
+ hidden_states = hidden_states.to(torch.float32)
77
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
78
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
79
+ return self.weight * hidden_states.to(input_dtype)
80
+
81
+ def extra_repr(self):
82
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
83
+
84
+
85
+ class Qwen3RingBufferCache:
86
+ """
87
+ Ring buffer KV cache with summary support.
88
+
89
+ Two strategies based on per-layer sliding_chunk_num:
90
+ - Large window layers (is_large_window=True): append-only buffer storing only text KV.
91
+ Summary KV is NOT stored since text tokens attend to all text KV directly.
92
+ - Small window layers (is_large_window=False):
93
+ Three buffers:
94
+ 1. key_cache: [ring(ws) | old_summaries(growing) | chunk_mirror(≤C)]
95
+ → attention input, steady state is a single contiguous slice
96
+ 2. new_summary_buf: ring buffer of size scn, stores summaries whose text
97
+ is still in the window (not needed for attention)
98
+ 3. chunk_buf: size C, holds current chunk's text KV
99
+
100
+ RoPE position information is baked into KV, so physical order doesn't matter.
101
+ """
102
+
103
+ is_compileable = False
104
+ _SUMMARY_INIT_CAP = 512
105
+ _APPEND_HEADROOM = 1024
106
+
107
+ def __init__(self, config: Qwen3Config, sliding_chunk_nums: list[int]):
108
+ super().__init__()
109
+ self.summary_chunk_size = getattr(config, "summary_chunk_size", 0)
110
+ self.summary_token_num = getattr(config, "summary_token_num", 0)
111
+ self.num_hidden_layers = config.num_hidden_layers
112
+
113
+ self.sliding_chunk_nums = sliding_chunk_nums
114
+ large_window_threshold = min(sliding_chunk_nums) * self.summary_chunk_size
115
+ self.is_large_window = [sv * self.summary_chunk_size > large_window_threshold for sv in sliding_chunk_nums]
116
+ self.window_sizes = [sv * self.summary_chunk_size for sv in sliding_chunk_nums]
117
+
118
+ self.key_cache = [None for _ in range(config.num_hidden_layers)]
119
+ self.value_cache = [None for _ in range(config.num_hidden_layers)]
120
+
121
+ # Large window: append-only
122
+ self._text_len = [0] * config.num_hidden_layers
123
+ self._capacity = [0] * config.num_hidden_layers
124
+
125
+ # Small window: ring buffer + summary
126
+ self._window_write_ptr = [0] * config.num_hidden_layers
127
+ self._n_valid_window = [0] * config.num_hidden_layers
128
+ self._old_summary_len = [0] * config.num_hidden_layers # old summaries in key_cache
129
+ self._old_summary_cap = [0] * config.num_hidden_layers
130
+
131
+ # New summary ring buffer (small window only): summaries whose text is still in window
132
+ self._new_sum_key_buf = [None for _ in range(config.num_hidden_layers)]
133
+ self._new_sum_value_buf = [None for _ in range(config.num_hidden_layers)]
134
+ self._new_sum_len = [0] * config.num_hidden_layers # how many filled (≤ scn)
135
+ self._new_sum_write_ptr = [0] * config.num_hidden_layers # ring write pointer
136
+
137
+ # Current chunk buffer (small window only): holds partial chunk text KV
138
+ self._chunk_key_buf = [None for _ in range(config.num_hidden_layers)]
139
+ self._chunk_value_buf = [None for _ in range(config.num_hidden_layers)]
140
+ self._chunk_buf_len = [0] * config.num_hidden_layers
141
+
142
+ # Common
143
+ self.cur_chunk_sizes = [0] * config.num_hidden_layers
144
+ self.true_tokens = [0] * config.num_hidden_layers
145
+ self._total_chunks = [0] * config.num_hidden_layers # completed chunks count
146
+ self._reorganized = False
147
+
148
+ def __len__(self):
149
+ return self.num_hidden_layers
150
+
151
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
152
+ """Returns nonzero when cache is populated (used to detect prefill vs decode)."""
153
+ if layer_idx >= self.num_hidden_layers:
154
+ return 0
155
+ if self.is_large_window[layer_idx]:
156
+ return self._text_len[layer_idx]
157
+ else:
158
+ return (self._n_valid_window[layer_idx] + self._chunk_buf_len[layer_idx]
159
+ + self._old_summary_len[layer_idx] + self._new_sum_len[layer_idx])
160
+
161
+ def get_cur_chunk_size(self, layer_idx: Optional[int] = None) -> int:
162
+ if layer_idx is None:
163
+ layer_idx = self.num_hidden_layers - 1
164
+ return self.cur_chunk_sizes[layer_idx]
165
+
166
+ def get_true_token_num(self, layer_idx: Optional[int] = None) -> int:
167
+ if layer_idx is None:
168
+ layer_idx = self.num_hidden_layers - 1
169
+ return self.true_tokens[layer_idx]
170
+
171
+ # ── Prefill: standard append (before reorganize) ──
172
+
173
+ def update(
174
+ self,
175
+ key_states: torch.Tensor,
176
+ value_states: torch.Tensor,
177
+ layer_idx: int,
178
+ cache_kwargs: Optional[dict[str, Any]] = None,
179
+ ) -> tuple[torch.Tensor, torch.Tensor]:
180
+ """Append KV during prefill (before reorganize). Returns full KV for prefill attention."""
181
+ add_len = key_states.shape[-2]
182
+ cur_len = self._text_len[layer_idx]
183
+ new_len = cur_len + add_len
184
+
185
+ if self.key_cache[layer_idx] is None:
186
+ cap = new_len + self._APPEND_HEADROOM
187
+ bsz, heads, _, head_dim = key_states.shape
188
+ self.key_cache[layer_idx] = torch.empty(
189
+ bsz, heads, cap, head_dim, dtype=key_states.dtype, device=key_states.device)
190
+ self.value_cache[layer_idx] = torch.empty(
191
+ bsz, heads, cap, head_dim, dtype=value_states.dtype, device=value_states.device)
192
+ self._capacity[layer_idx] = cap
193
+ elif new_len > self._capacity[layer_idx]:
194
+ cap = max(new_len + self._APPEND_HEADROOM, self._capacity[layer_idx] * 2)
195
+ old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx]
196
+ bsz, heads, _, head_dim = old_k.shape
197
+ new_k = torch.empty(bsz, heads, cap, head_dim, dtype=old_k.dtype, device=old_k.device)
198
+ new_v = torch.empty(bsz, heads, cap, head_dim, dtype=old_v.dtype, device=old_v.device)
199
+ new_k[:, :, :cur_len, :].copy_(old_k[:, :, :cur_len, :])
200
+ new_v[:, :, :cur_len, :].copy_(old_v[:, :, :cur_len, :])
201
+ self.key_cache[layer_idx] = new_k
202
+ self.value_cache[layer_idx] = new_v
203
+ self._capacity[layer_idx] = cap
204
+
205
+ self.key_cache[layer_idx][:, :, cur_len:new_len, :].copy_(key_states)
206
+ self.value_cache[layer_idx][:, :, cur_len:new_len, :].copy_(value_states)
207
+ self._text_len[layer_idx] = new_len
208
+
209
+ if self.summary_chunk_size > 0:
210
+ if cache_kwargs and 'summary_mask' in cache_kwargs:
211
+ text_count = add_len - cache_kwargs['summary_mask'][0].sum().item()
212
+ else:
213
+ text_count = add_len
214
+ self.cur_chunk_sizes[layer_idx] += add_len
215
+ self.true_tokens[layer_idx] += text_count
216
+
217
+ return self.key_cache[layer_idx][:, :, :new_len, :], self.value_cache[layer_idx][:, :, :new_len, :]
218
+
219
+ # ── Reorganize after prefill ──
220
+
221
+ def reorganize_after_prefill(self, summary_mask: torch.Tensor):
222
+ """Reorganize all layers from prefill block layout to ring buffer layout."""
223
+ if self._reorganized:
224
+ return
225
+ self._reorganized = True
226
+
227
+ text_mask = ~summary_mask[0]
228
+
229
+ for layer_idx in range(self.num_hidden_layers):
230
+ prefill_len = self._text_len[layer_idx]
231
+ prefill_k = self.key_cache[layer_idx][:, :, :prefill_len, :]
232
+ prefill_v = self.value_cache[layer_idx][:, :, :prefill_len, :]
233
+ bsz, heads, _, head_dim = prefill_k.shape
234
+ device, dtype = prefill_k.device, prefill_k.dtype
235
+
236
+ text_k = prefill_k[:, :, text_mask, :]
237
+ text_v = prefill_v[:, :, text_mask, :]
238
+ n_text = text_k.shape[2]
239
+
240
+ if self.is_large_window[layer_idx]:
241
+ # Large window: keep only text KV
242
+ cap = n_text + self._APPEND_HEADROOM
243
+ new_k = torch.empty(bsz, heads, cap, head_dim, dtype=dtype, device=device)
244
+ new_v = torch.empty(bsz, heads, cap, head_dim, dtype=dtype, device=device)
245
+ new_k[:, :, :n_text, :].copy_(text_k)
246
+ new_v[:, :, :n_text, :].copy_(text_v)
247
+ self.key_cache[layer_idx] = new_k
248
+ self.value_cache[layer_idx] = new_v
249
+ self._text_len[layer_idx] = n_text
250
+ self._capacity[layer_idx] = cap
251
+ else:
252
+ # Small window: split summaries into old (evicted) and new (in window)
253
+ all_summary_k = prefill_k[:, :, summary_mask[0], :]
254
+ all_summary_v = prefill_v[:, :, summary_mask[0], :]
255
+ n_summary = all_summary_k.shape[2]
256
+
257
+ C = self.summary_chunk_size
258
+ ws = self.window_sizes[layer_idx]
259
+ scn = self.sliding_chunk_nums[layer_idx]
260
+
261
+ # Split text into complete chunks + partial remainder
262
+ n_complete_chunks = n_text // C
263
+ n_partial = n_text % C
264
+ n_complete_text = n_complete_chunks * C
265
+
266
+ # Window: last scn complete chunks (or all if fewer)
267
+ n_window_chunks = min(scn, n_complete_chunks)
268
+ n_window_text = n_window_chunks * C
269
+ window_start = n_complete_text - n_window_text
270
+
271
+ # Split summaries: old (text evicted from ring) vs new (text in ring)
272
+ n_old = max(0, n_summary - n_window_chunks)
273
+ n_new = n_summary - n_old
274
+
275
+ # key_cache: [ring(ws) | old_summaries | chunk_mirror(≤C)]
276
+ old_s_cap = max(self._SUMMARY_INIT_CAP, (n_old + 1) * 2)
277
+ total_cap = ws + old_s_cap + C
278
+ new_k = torch.empty(bsz, heads, total_cap, head_dim, dtype=dtype, device=device)
279
+ new_v = torch.empty(bsz, heads, total_cap, head_dim, dtype=dtype, device=device)
280
+
281
+ if n_window_text > 0:
282
+ new_k[:, :, :n_window_text, :].copy_(text_k[:, :, window_start:n_complete_text, :])
283
+ new_v[:, :, :n_window_text, :].copy_(text_v[:, :, window_start:n_complete_text, :])
284
+ self._n_valid_window[layer_idx] = n_window_text
285
+ self._window_write_ptr[layer_idx] = n_window_text % ws
286
+
287
+ # Old summaries go into key_cache after ring
288
+ if n_old > 0:
289
+ new_k[:, :, ws:ws + n_old, :].copy_(all_summary_k[:, :, :n_old, :])
290
+ new_v[:, :, ws:ws + n_old, :].copy_(all_summary_v[:, :, :n_old, :])
291
+ self._old_summary_len[layer_idx] = n_old
292
+ self._old_summary_cap[layer_idx] = old_s_cap
293
+
294
+ # Mirror partial chunk into key_cache after old_summaries
295
+ if n_partial > 0:
296
+ mirror_start = ws + n_old
297
+ new_k[:, :, mirror_start:mirror_start + n_partial, :].copy_(
298
+ text_k[:, :, n_complete_text:, :])
299
+ new_v[:, :, mirror_start:mirror_start + n_partial, :].copy_(
300
+ text_v[:, :, n_complete_text:, :])
301
+
302
+ self.key_cache[layer_idx] = new_k
303
+ self.value_cache[layer_idx] = new_v
304
+ self._capacity[layer_idx] = total_cap
305
+ self._text_len[layer_idx] = 0
306
+
307
+ # New summary ring buffer
308
+ ns_buf_k = torch.empty(bsz, heads, scn, head_dim, dtype=dtype, device=device)
309
+ ns_buf_v = torch.empty(bsz, heads, scn, head_dim, dtype=dtype, device=device)
310
+ if n_new > 0:
311
+ ns_buf_k[:, :, :n_new, :].copy_(all_summary_k[:, :, n_old:, :])
312
+ ns_buf_v[:, :, :n_new, :].copy_(all_summary_v[:, :, n_old:, :])
313
+ self._new_sum_key_buf[layer_idx] = ns_buf_k
314
+ self._new_sum_value_buf[layer_idx] = ns_buf_v
315
+ self._new_sum_len[layer_idx] = n_new
316
+ self._new_sum_write_ptr[layer_idx] = n_new % scn
317
+
318
+ # Chunk buffer for partial remainder
319
+ chunk_buf_k = torch.empty(bsz, heads, C, head_dim, dtype=dtype, device=device)
320
+ chunk_buf_v = torch.empty(bsz, heads, C, head_dim, dtype=dtype, device=device)
321
+ if n_partial > 0:
322
+ chunk_buf_k[:, :, :n_partial, :].copy_(text_k[:, :, n_complete_text:, :])
323
+ chunk_buf_v[:, :, :n_partial, :].copy_(text_v[:, :, n_complete_text:, :])
324
+ self._chunk_key_buf[layer_idx] = chunk_buf_k
325
+ self._chunk_value_buf[layer_idx] = chunk_buf_v
326
+ self._chunk_buf_len[layer_idx] = n_partial
327
+
328
+ block = self.summary_chunk_size + self.summary_token_num
329
+ for layer_idx in range(self.num_hidden_layers):
330
+ self.cur_chunk_sizes[layer_idx] = self.cur_chunk_sizes[layer_idx] % block
331
+ self._total_chunks[layer_idx] = (
332
+ self._old_summary_len[layer_idx] + self._new_sum_len[layer_idx]
333
+ if not self.is_large_window[layer_idx]
334
+ else (self.true_tokens[layer_idx] // self.summary_chunk_size)
335
+ )
336
+
337
+ # ── Decode: text token update ──
338
+
339
+ def update_text(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
340
+ """Write a single text token KV during decode."""
341
+ if self.is_large_window[layer_idx]:
342
+ cur = self._text_len[layer_idx]
343
+ new_len = cur + 1
344
+ if new_len > self._capacity[layer_idx]:
345
+ cap = max(new_len + self._APPEND_HEADROOM, self._capacity[layer_idx] * 2)
346
+ old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx]
347
+ bsz, heads, _, head_dim = old_k.shape
348
+ new_k = torch.empty(bsz, heads, cap, head_dim, dtype=old_k.dtype, device=old_k.device)
349
+ new_v = torch.empty(bsz, heads, cap, head_dim, dtype=old_v.dtype, device=old_v.device)
350
+ new_k[:, :, :cur, :].copy_(old_k[:, :, :cur, :])
351
+ new_v[:, :, :cur, :].copy_(old_v[:, :, :cur, :])
352
+ self.key_cache[layer_idx] = new_k
353
+ self.value_cache[layer_idx] = new_v
354
+ self._capacity[layer_idx] = cap
355
+ self.key_cache[layer_idx][:, :, cur:new_len, :].copy_(key_states)
356
+ self.value_cache[layer_idx][:, :, cur:new_len, :].copy_(value_states)
357
+ self._text_len[layer_idx] = new_len
358
+ else:
359
+ # Write only to the key_cache mirror region; the separate chunk buffer is not used on this path
360
+ ws = self.window_sizes[layer_idx]
361
+ n_old = self._old_summary_len[layer_idx]
362
+ pos = self._chunk_buf_len[layer_idx]
363
+ mirror_pos = ws + n_old + pos
364
+ self.key_cache[layer_idx][:, :, mirror_pos:mirror_pos+1, :].copy_(key_states)
365
+ self.value_cache[layer_idx][:, :, mirror_pos:mirror_pos+1, :].copy_(value_states)
366
+ self._chunk_buf_len[layer_idx] = pos + 1
367
+
368
+ self.cur_chunk_sizes[layer_idx] += 1
369
+ self.true_tokens[layer_idx] += 1
370
+
371
+ # ── Decode: summary token update ──
372
+
373
+ def update_summary(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
374
+ """Write summary token KV during decode (chunk boundary).
375
+
376
+ Large window: skip.
377
+ Small window (order matters — flush mirror before evict to avoid clobbering):
378
+ 1. Flush mirror region → ring
379
+ 2. Evict oldest new_summary → old_summary in key_cache (if full)
380
+ 3. Write new summary → new_summary_buf
381
+ """
382
+ n_summary = key_states.shape[2]
383
+
384
+ if self.is_large_window[layer_idx]:
385
+ self.cur_chunk_sizes[layer_idx] += n_summary
386
+ self._total_chunks[layer_idx] += n_summary
387
+ return
388
+
389
+ C = self.summary_chunk_size
390
+ ws = self.window_sizes[layer_idx]
391
+ scn = self.sliding_chunk_nums[layer_idx]
392
+ cbl = self._chunk_buf_len[layer_idx]
393
+ ptr = self._window_write_ptr[layer_idx]
394
+ n_old = self._old_summary_len[layer_idx]
395
+
396
+ # Step 1: Flush mirror region → ring (must happen before evict touches mirror[0])
397
+ mirror_start = ws + n_old
398
+ if ptr + cbl <= ws:
399
+ self.key_cache[layer_idx][:, :, ptr:ptr + cbl, :].copy_(
400
+ self.key_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :])
401
+ self.value_cache[layer_idx][:, :, ptr:ptr + cbl, :].copy_(
402
+ self.value_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :])
403
+ else:
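+ # Wrap-around: fill the tail of the ring up to ws, then continue from index 0.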
404
+ first = ws - ptr
405
+ self.key_cache[layer_idx][:, :, ptr:ws, :].copy_(
406
+ self.key_cache[layer_idx][:, :, mirror_start:mirror_start + first, :])
407
+ self.value_cache[layer_idx][:, :, ptr:ws, :].copy_(
408
+ self.value_cache[layer_idx][:, :, mirror_start:mirror_start + first, :])
409
+ rest = cbl - first
410
+ self.key_cache[layer_idx][:, :, :rest, :].copy_(
411
+ self.key_cache[layer_idx][:, :, mirror_start + first:mirror_start + cbl, :])
412
+ self.value_cache[layer_idx][:, :, :rest, :].copy_(
413
+ self.value_cache[layer_idx][:, :, mirror_start + first:mirror_start + cbl, :])
414
+
415
+ self._window_write_ptr[layer_idx] = (ptr + cbl) % ws
416
+ if self._n_valid_window[layer_idx] < ws:
417
+ self._n_valid_window[layer_idx] = min(ws, self._n_valid_window[layer_idx] + cbl)
418
+ self._chunk_buf_len[layer_idx] = 0
419
+
420
+ # Step 2: Evict oldest new_summary → old_summary (now safe — mirror already flushed)
421
+ if self._new_sum_len[layer_idx] >= scn:
422
+ read_ptr = self._new_sum_write_ptr[layer_idx]
423
+ old_dst = ws + n_old # == mirror_start, but mirror data is already in ring
424
+
425
+ # Check capacity for old_summary growth
426
+ needed = old_dst + 1 + C # +1 for new old_sum, +C for future chunk mirror
427
+ if needed > self._capacity[layer_idx]:
428
+ new_s_cap = max(self._old_summary_cap[layer_idx] * 2, n_old + self._SUMMARY_INIT_CAP)
429
+ new_total = ws + new_s_cap + C
430
+ old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx]
431
+ bsz, heads, _, head_dim = old_k.shape
432
+ nk = torch.empty(bsz, heads, new_total, head_dim, dtype=old_k.dtype, device=old_k.device)
433
+ nv = torch.empty(bsz, heads, new_total, head_dim, dtype=old_v.dtype, device=old_v.device)
434
+ copy_len = ws + n_old
435
+ nk[:, :, :copy_len, :].copy_(old_k[:, :, :copy_len, :])
436
+ nv[:, :, :copy_len, :].copy_(old_v[:, :, :copy_len, :])
437
+ self.key_cache[layer_idx] = nk
438
+ self.value_cache[layer_idx] = nv
439
+ self._old_summary_cap[layer_idx] = new_s_cap
440
+ self._capacity[layer_idx] = new_total
441
+
442
+ self.key_cache[layer_idx][:, :, old_dst:old_dst+1, :].copy_(
443
+ self._new_sum_key_buf[layer_idx][:, :, read_ptr:read_ptr+1, :])
444
+ self.value_cache[layer_idx][:, :, old_dst:old_dst+1, :].copy_(
445
+ self._new_sum_value_buf[layer_idx][:, :, read_ptr:read_ptr+1, :])
446
+ self._old_summary_len[layer_idx] += 1
447
+
448
+ # Step 3: Write new summary to new_summary_buf (overwrite oldest slot)
449
+ w_ptr = self._new_sum_write_ptr[layer_idx]
450
+ self._new_sum_key_buf[layer_idx][:, :, w_ptr:w_ptr+1, :].copy_(key_states)
451
+ self._new_sum_value_buf[layer_idx][:, :, w_ptr:w_ptr+1, :].copy_(value_states)
452
+ self._new_sum_write_ptr[layer_idx] = (w_ptr + 1) % scn
453
+ if self._new_sum_len[layer_idx] < scn:
454
+ self._new_sum_len[layer_idx] += 1
455
+
456
+ self.cur_chunk_sizes[layer_idx] += n_summary
457
+ self._total_chunks[layer_idx] += n_summary
458
+
459
+ # ── Decode: get KV for attention ──
460
+
461
+ def get_attention_kv(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
462
+ """Get full KV for text token attention.
463
+
464
+ Large window: buffer[:text_len]
465
+ Small window (steady state): key_cache[:ws + n_old + cbl] — single slice, zero cat
466
+ """
467
+ if self.is_large_window[layer_idx]:
468
+ tl = self._text_len[layer_idx]
469
+ return (self.key_cache[layer_idx][:, :, :tl, :],
470
+ self.value_cache[layer_idx][:, :, :tl, :])
471
+
472
+ ws = self.window_sizes[layer_idx]
473
+ nv = self._n_valid_window[layer_idx]
474
+ cbl = self._chunk_buf_len[layer_idx]
475
+ n_old = self._old_summary_len[layer_idx]
476
+
477
+ # Steady state: ring full → [ring(ws) | old_sums(n_old) | chunk_mirror(cbl)] contiguous
478
+ if nv == ws:
479
+ end = ws + n_old + cbl
480
+ return (self.key_cache[layer_idx][:, :, :end, :],
481
+ self.value_cache[layer_idx][:, :, :end, :])
482
+
483
+ # Warmup: ring not full, [nv:ws] is gap → cat
484
+ parts_k, parts_v = [], []
485
+ if nv > 0:
486
+ parts_k.append(self.key_cache[layer_idx][:, :, :nv, :])
487
+ parts_v.append(self.value_cache[layer_idx][:, :, :nv, :])
488
+ if cbl > 0:
489
+ mirror_start = ws + n_old
490
+ parts_k.append(self.key_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :])
491
+ parts_v.append(self.value_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :])
492
+ if n_old > 0:
493
+ parts_k.append(self.key_cache[layer_idx][:, :, ws:ws + n_old, :])
494
+ parts_v.append(self.value_cache[layer_idx][:, :, ws:ws + n_old, :])
495
+ if len(parts_k) == 1:
496
+ return parts_k[0], parts_v[0]
497
+ return torch.cat(parts_k, dim=2), torch.cat(parts_v, dim=2)
498
+
499
+ def get_current_chunk_kv(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
500
+ """Get KV of the current chunk's C text tokens for summary token attention."""
501
+ C = self.summary_chunk_size
502
+ if self.is_large_window[layer_idx]:
503
+ tl = self._text_len[layer_idx]
504
+ return (self.key_cache[layer_idx][:, :, tl - C:tl, :],
505
+ self.value_cache[layer_idx][:, :, tl - C:tl, :])
506
+ else:
507
+ ws = self.window_sizes[layer_idx]
508
+ n_old = self._old_summary_len[layer_idx]
509
+ cbl = self._chunk_buf_len[layer_idx]
510
+ mirror_start = ws + n_old
511
+ return (self.key_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :],
512
+ self.value_cache[layer_idx][:, :, mirror_start:mirror_start + cbl, :])
513
+
514
+ def reset_chunk_counter(self):
515
+ """Reset chunk counters after a chunk boundary step completes."""
516
+ block = self.summary_chunk_size + self.summary_token_num
517
+ for layer_idx in range(self.num_hidden_layers):
518
+ if self.cur_chunk_sizes[layer_idx] >= block:
519
+ self.cur_chunk_sizes[layer_idx] %= block
520
+
521
+
522
+ class Qwen3MLP(nn.Module):
523
+ def __init__(self, config):
524
+ super().__init__()
525
+ self.config = config
526
+ self.hidden_size = config.hidden_size
527
+ self.intermediate_size = config.intermediate_size
528
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
529
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
530
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
531
+ self.act_fn = ACT2FN[config.hidden_act]
532
+
533
+ def forward(self, x):
534
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
535
+ return down_proj
536
+
537
+
538
+ def rotate_half(x):
539
+ """Rotates half the hidden dims of the input."""
540
+ x1 = x[..., : x.shape[-1] // 2]
541
+ x2 = x[..., x.shape[-1] // 2 :]
542
+ return torch.cat((-x2, x1), dim=-1)
543
+
544
+
545
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
546
+ """Applies Rotary Position Embedding to the query and key tensors.
547
+
548
+ Args:
549
+ q (`torch.Tensor`): The query tensor.
550
+ k (`torch.Tensor`): The key tensor.
551
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
552
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
553
+ position_ids (`torch.Tensor`, *optional*):
554
+ Deprecated and unused.
555
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
556
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
557
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
558
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
559
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
560
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
561
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
562
+ Returns:
563
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
564
+ """
565
+ cos = cos.unsqueeze(unsqueeze_dim)
566
+ sin = sin.unsqueeze(unsqueeze_dim)
567
+ q_embed = (q * cos) + (rotate_half(q) * sin)
568
+ k_embed = (k * cos) + (rotate_half(k) * sin)
569
+ return q_embed, k_embed
570
+
571
+
572
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
573
+ """
574
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
575
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
576
+ """
577
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
578
+ if n_rep == 1:
579
+ return hidden_states
580
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
581
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
582
+
583
+
584
+ def _sdpa_attention_forward(
585
+ module: nn.Module,
586
+ query: torch.Tensor,
587
+ key: torch.Tensor,
588
+ value: torch.Tensor,
589
+ attention_mask: Optional[torch.Tensor],
590
+ scaling: float,
591
+ dropout: float = 0.0,
592
+ **kwargs: Unpack[TransformersKwargs],
593
+ ):
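+ # Note: no mask is applied here (attn_mask=None, is_causal=False); this helper appears to
+ # assume the caller already restricts key/value to the positions the query may attend to
+ # (e.g. decode-time KV returned by the ring-buffer cache).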
594
+ key_states = repeat_kv(key, module.num_key_value_groups)
595
+ value_states = repeat_kv(value, module.num_key_value_groups)
596
+ attn_output = F.scaled_dot_product_attention(
597
+ query,
598
+ key_states,
599
+ value_states,
600
+ attn_mask=None,
601
+ dropout_p=dropout,
602
+ is_causal=False,
603
+ )
604
+ attn_output = attn_output.transpose(1, 2).contiguous()
605
+ return attn_output, None
606
+
607
+
608
+
609
+ class Qwen3Attention(nn.Module):
610
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
611
+
612
+ def __init__(self, config: Qwen3Config, layer_idx: int):
613
+ super().__init__()
614
+ self.config = config
615
+ self.layer_idx = layer_idx
616
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
617
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
618
+ self.scaling = self.head_dim**-0.5
619
+ self.attention_dropout = config.attention_dropout
620
+
621
+ self.q_proj = nn.Linear(
622
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
623
+ )
624
+ self.k_proj = nn.Linear(
625
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
626
+ )
627
+ self.v_proj = nn.Linear(
628
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
629
+ )
630
+ self.o_proj = nn.Linear(
631
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
632
+ )
633
+ self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
634
+ self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
635
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
636
+
637
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
638
+ def forward(
639
+ self,
640
+ hidden_states: torch.Tensor,
641
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
642
+ attention_mask: Optional[torch.Tensor],
643
+ past_key_values: Optional[Cache] = None,
644
+ cache_position: Optional[torch.LongTensor] = None,
645
+ **kwargs: Unpack[FlashAttentionKwargs],
646
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
647
+ input_shape = hidden_states.shape[:-1]
648
+ hidden_shape = (*input_shape, -1, self.head_dim)
649
+
650
+
651
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
652
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
653
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
654
+
655
+ cos, sin = position_embeddings
656
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
657
+
658
+ if past_key_values is not None:
659
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
660
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
661
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
662
+
663
+ attn_output, attn_weights = _sdpa_attention_forward(
664
+ self,
665
+ query_states,
666
+ key_states,
667
+ value_states,
668
+ attention_mask,
669
+ dropout=0.0 if not self.training else self.attention_dropout,
670
+ scaling=self.scaling,
671
+ **kwargs,
672
+ )
673
+
674
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
675
+ attn_output = self.o_proj(attn_output)
676
+ return attn_output, attn_weights
677
+
678
+
679
+ class Qwen3SummaryAttention(Qwen3Attention):
680
+ """
681
+ Summary-aware variant of Qwen3Attention: uses a sliding summary mask.
682
+ """
683
+
684
+ def __init__(self, config: Qwen3Config, layer_idx: int):
685
+ super().__init__(config, layer_idx)
686
+ self.summary_chunk_size = getattr(self.config, "summary_chunk_size", 0)
687
+ self.summary_token_num = getattr(self.config, "summary_token_num", 0)
688
+
689
+ # Cache sliding_chunk_num to avoid eval() on every forward call
690
+ val = getattr(config, "summary_sliding_chunk_num", 0) or 0
691
+ val = _parse_config_pattern(val)
692
+ if isinstance(val, list):
693
+ self._sliding_chunk_num = val[layer_idx]
694
+ else:
695
+ self._sliding_chunk_num = int(val)
696
+
697
+ if config.summary_independent_parameters:
698
+ self.q_proj_summary = nn.Linear(
699
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
700
+ )
701
+ self.k_proj_summary = nn.Linear(
702
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
703
+ )
704
+ self.v_proj_summary = nn.Linear(
705
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
706
+ )
707
+
708
+ def _get_sliding_chunk_num(self):
709
+ return self._sliding_chunk_num
710
+
711
+ def get_query_key_value_tensors(self, hidden_states):
712
+ input_shape = hidden_states.shape[:-1]
713
+ hidden_shape = (*input_shape, -1, self.head_dim)
714
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
715
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
716
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
717
+
718
+ return query_states, key_states, value_states
719
+
720
+ def get_query_key_value_tensors_summary(self, hidden_states):
721
+ input_shape = hidden_states.shape[:-1]
722
+ hidden_shape = (*input_shape, -1, self.head_dim)
723
+ query_states = self.q_norm(self.q_proj_summary(hidden_states).view(hidden_shape)).transpose(1, 2)
724
+ key_states = self.k_norm(self.k_proj_summary(hidden_states).view(hidden_shape)).transpose(1, 2)
725
+ value_states = self.v_proj_summary(hidden_states).view(hidden_shape).transpose(1, 2)
726
+
727
+ return query_states, key_states, value_states
728
+
729
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
730
+ def forward(
731
+ self,
732
+ hidden_states: torch.Tensor,
733
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
734
+ attention_mask: Optional[torch.Tensor] = None,
735
+ past_key_values: Optional[Cache] = None,
736
+ cache_position: Optional[torch.LongTensor] = None,
737
+ summary_ctx: Optional[SummaryBatchContext] = None,
738
+ **kwargs,
739
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
740
+ input_shape = hidden_states.shape[:-1]
741
+ if hidden_states.size(0) != 1:
742
+ raise ValueError("Summary sliding attention only supports batch size=1.")
743
+
744
+ # Compute q/k/v for the full sequence once.
745
+ if self.config.summary_independent_parameters:
746
+ if summary_ctx is None:
747
+ raise ValueError("summary_ctx is required when using summary_independent_parameters.")
748
+ summary_mask = summary_ctx.summary_mask
749
+ summary_pos = summary_mask[0]
750
+ assert (summary_mask == summary_mask[0:1]).all()
751
+
752
+ if self.config.mix_coeff == 0:
753
+ # When mix_coeff=0, summary projections have no effect — skip clone + extra linear
754
+ query_states, key_states, value_states = self.get_query_key_value_tensors(hidden_states)
755
+ else:
756
+ query, key, value = self.get_query_key_value_tensors(hidden_states)
757
+
758
+ query_states = query.clone()
759
+ key_states = key.clone()
760
+ value_states = value.clone()
761
+
762
+ hs_summary = hidden_states[:, summary_pos, :]
763
+ if hs_summary.size(1) > 0:
764
+ base_q_summary = query[:, :, summary_pos, :]
765
+ base_k_summary = key[:, :, summary_pos, :]
766
+ base_v_summary = value[:, :, summary_pos, :]
767
+
768
+ q_s, k_s, v_s = self.get_query_key_value_tensors_summary(hs_summary)
769
+
770
+ q_s = self.config.mix_coeff * q_s + (1.0 - self.config.mix_coeff) * base_q_summary
771
+ k_s = self.config.mix_coeff * k_s + (1.0 - self.config.mix_coeff) * base_k_summary
772
+ v_s = self.config.mix_coeff * v_s + (1.0 - self.config.mix_coeff) * base_v_summary
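+ # Linear blend between the summary-specific and shared projections: mix_coeff=1.0 uses only
+ # q/k/v_proj_summary, while mix_coeff=0.0 would reduce to the shared projections (that case
+ # is short-circuited above).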
773
+
774
+ query_states[:, :, summary_pos, :] = q_s
775
+ key_states[:, :, summary_pos, :] = k_s
776
+ value_states[:, :, summary_pos, :] = v_s
777
+ else:
778
+ query_states, key_states, value_states = self.get_query_key_value_tensors(hidden_states)
779
+
780
+ cos, sin = position_embeddings
781
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
782
+
783
+ query_len = query_states.shape[2]
784
+ is_prefill = past_key_values is None or not past_key_values._reorganized
785
+
786
+ if is_prefill:
787
+ # Prefill: use standard append and summary_attn_func
788
+ if past_key_values is not None:
789
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
790
+ if summary_ctx is not None:
791
+ cache_kwargs["summary_mask"] = summary_ctx.summary_mask
792
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
793
+
794
+ with torch.cuda.device(query_states.device):
795
+ attn_output, attn_weights = summary_attn_func(
796
+ query_states.transpose(1,2).contiguous(),
797
+ key_states.transpose(1,2).contiguous(),
798
+ value_states.transpose(1,2).contiguous(),
799
+ self.summary_chunk_size,
800
+ self.summary_token_num,
801
+ self._get_sliding_chunk_num(),
802
+ summary_pos=summary_ctx.summary_mask.squeeze()
803
+ )
804
+ elif query_len == 1:
805
+ # Single text token decode: write to cache, attend to full buffer
806
+ past_key_values.update_text(key_states, value_states, self.layer_idx)
807
+ k_full, v_full = past_key_values.get_attention_kv(self.layer_idx)
808
+ attn_output, attn_weights = _sdpa_attention_forward(
809
+ self,
810
+ query_states,
811
+ k_full,
812
+ v_full,
813
+ None,
814
+ dropout=0.0 if not self.training else self.attention_dropout,
815
+ scaling=self.scaling,
816
+ sliding_window=self.sliding_window,
817
+ **kwargs,
818
+ )
819
+ else:
820
+ # Chunk boundary: query = [text_token, summary_token(s)]
821
+ # Split into text (first token) and summary (remaining tokens)
822
+ q_text = query_states[:, :, :1, :]
823
+ q_summary = query_states[:, :, 1:, :]
824
+ k_text = key_states[:, :, :1, :]
825
+ v_text = value_states[:, :, :1, :]
826
+ k_summary = key_states[:, :, 1:, :]
827
+ v_summary = value_states[:, :, 1:, :]
828
+
829
+ # 1. Write text token to cache, get full KV, run text attention
830
+ past_key_values.update_text(k_text, v_text, self.layer_idx)
831
+ k_full, v_full = past_key_values.get_attention_kv(self.layer_idx)
832
+ text_out, _ = _sdpa_attention_forward(
833
+ self,
834
+ q_text,
835
+ k_full,
836
+ v_full,
837
+ None,
838
+ dropout=0.0 if not self.training else self.attention_dropout,
839
+ scaling=self.scaling,
840
+ sliding_window=self.sliding_window,
841
+ **kwargs,
842
+ )
843
+
844
+ # 2. Summary attention: attend to current chunk's C text tokens + own KV (self-attention)
845
+ # The original model includes the summary token's own KV in its attention
846
+ # (causal within summary positions). With S=1, this is just self-attention.
847
+ k_chunk, v_chunk = past_key_values.get_current_chunk_kv(self.layer_idx)
848
+ k_chunk_with_self = torch.cat([k_chunk, k_summary], dim=2)
849
+ v_chunk_with_self = torch.cat([v_chunk, v_summary], dim=2)
850
+ summary_out, _ = _sdpa_attention_forward(
851
+ self,
852
+ q_summary,
853
+ k_chunk_with_self,
854
+ v_chunk_with_self,
855
+ None,
856
+ dropout=0.0 if not self.training else self.attention_dropout,
857
+ scaling=self.scaling,
858
+ sliding_window=self.sliding_window,
859
+ **kwargs,
860
+ )
861
+
862
+ # 3. Write summary KV to cache
863
+ past_key_values.update_summary(k_summary, v_summary, self.layer_idx)
864
+
865
+ attn_output = torch.cat([text_out, summary_out], dim=2)
866
+ attn_weights = None
867
+
868
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
869
+ attn_output = self.o_proj(attn_output)
870
+ return attn_output, attn_weights
871
+
872
+
873
+ class Qwen3DecoderLayer(GradientCheckpointingLayer):
874
+ def __init__(self, config: Qwen3Config, layer_idx: int):
875
+ super().__init__()
876
+ self.config = config
877
+ self.hidden_size = config.hidden_size
878
+
879
+ # Use SummaryAttention if enabled in config
880
+ if getattr(config, "use_summary_attention", False) is True and config.summary_layer_freq[layer_idx] == 1:
881
+ self.self_attn = Qwen3SummaryAttention(config=config, layer_idx=layer_idx)
882
+ elif getattr(config, "use_summary_attention", False) is False and config.summary_layer_freq[layer_idx] == 0:
883
+ self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
884
+ else:
885
+ raise ValueError(f'Inconsistent config: summary_layer_freq[{layer_idx}]={config.summary_layer_freq[layer_idx]} does not match use_summary_attention={config.use_summary_attention}')
886
+
887
+ self.mlp = Qwen3MLP(config)
888
+ self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
889
+ if getattr(config, "summary_independent_attention_layernorm", False):
890
+ self.input_layernorm_summary = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
891
+ self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
892
+ self.attention_type = config.layer_types[layer_idx]
893
+
894
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
895
+ def forward(
896
+ self,
897
+ hidden_states: torch.Tensor,
898
+ attention_mask: Optional[torch.Tensor] = None,
899
+ position_ids: Optional[torch.LongTensor] = None,
900
+ past_key_values: Optional[Cache] = None,
901
+ use_cache: Optional[bool] = False,
902
+ cache_position: Optional[torch.LongTensor] = None,
903
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
904
+ summary_ctx: Optional[SummaryBatchContext] = None,
905
+ **kwargs: Unpack[TransformersKwargs],
906
+ ) -> torch.Tensor:
907
+ residual = hidden_states
908
+ if getattr(self.config, "summary_independent_attention_layernorm", False):
909
+ summary_mask = summary_ctx.summary_mask
910
+ assert (summary_mask == summary_mask[0:1]).all(), \
911
+ "summary_mask must be identical across batch"
912
+ hidden_states = self.input_layernorm(hidden_states)
913
+ if summary_mask.any():
914
+ hidden_summary = residual[:, summary_mask[0].to(residual.device), :]
915
+ hidden_summary = self.input_layernorm_summary(hidden_summary)
916
+ hidden_states[:, summary_mask[0], :] = hidden_summary
917
+ else:
918
+ hidden_states = self.input_layernorm(hidden_states)
919
+
920
+ # Self Attention - pass summary_ctx if using summary attention
921
+ attn_kwargs = {
922
+ "hidden_states": hidden_states,
923
+ "attention_mask": attention_mask,
924
+ "position_ids": position_ids,
925
+ "past_key_values": past_key_values,
926
+ "use_cache": use_cache,
927
+ "cache_position": cache_position,
928
+ "position_embeddings": position_embeddings,
929
+ **kwargs,
930
+ }
931
+ if isinstance(self.self_attn, Qwen3SummaryAttention):
932
+ attn_kwargs["summary_ctx"] = summary_ctx
933
+
934
+ hidden_states, _ = self.self_attn(**attn_kwargs)
935
+ hidden_states = residual + hidden_states
936
+
937
+ # Fully Connected
938
+ residual = hidden_states
939
+ hidden_states = self.post_attention_layernorm(hidden_states)
940
+ hidden_states = self.mlp(hidden_states)
941
+ hidden_states = residual + hidden_states
942
+ return hidden_states
943
+
944
+
945
+ @auto_docstring
946
+ class Qwen3PreTrainedModel(PreTrainedModel):
947
+ config: Qwen3Config
948
+ base_model_prefix = "model"
949
+ supports_gradient_checkpointing = True
950
+ _no_split_modules = ["Qwen3DecoderLayer"]
951
+ _skip_keys_device_placement = ["past_key_values"]
952
+ _supports_flash_attn = True
953
+ _supports_sdpa = True
954
+ _supports_flex_attn = True
955
+
956
+ _can_compile_fullgraph = True
957
+ _supports_attention_backend = True
958
+ _can_record_outputs = {
959
+ "hidden_states": Qwen3DecoderLayer,
960
+ "attentions": Qwen3Attention,
961
+ }
962
+
963
+
964
+ class Qwen3RotaryEmbedding(nn.Module):
965
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
966
+
967
+ def __init__(self, config: Qwen3Config, device=None):
968
+ super().__init__()
969
+ self.max_seq_len_cached = config.max_position_embeddings
970
+ self.original_max_seq_len = config.max_position_embeddings
971
+
972
+ self.config = config
973
+
974
+ self.rope_type = self.config.rope_parameters["rope_type"]
975
+ rope_init_fn: Callable = self.compute_default_rope_parameters
976
+ if self.rope_type != "default":
977
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
978
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
979
+
980
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
981
+ self.original_inv_freq = inv_freq
982
+
983
+ @staticmethod
984
+ def compute_default_rope_parameters(
985
+ config: Optional[Qwen3Config] = None,
986
+ device: Optional["torch.device"] = None,
987
+ seq_len: Optional[int] = None,
988
+ ) -> tuple["torch.Tensor", float]:
989
+ """
990
+ Computes the inverse frequencies according to the original RoPE implementation
991
+ Args:
992
+ config ([`~transformers.PreTrainedConfig`]):
993
+ The model configuration.
994
+ device (`torch.device`):
995
+ The device to use for initialization of the inverse frequencies.
996
+ seq_len (`int`, *optional*):
997
+ The current sequence length. Unused for this type of RoPE.
998
+ Returns:
999
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
1000
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
1001
+ """
1002
+ base = config.rope_parameters["rope_theta"]
1003
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
1004
+
1005
+ attention_factor = 1.0 # Unused in this type of RoPE
1006
+
1007
+ # Compute the inverse frequencies
1008
+ inv_freq = 1.0 / (
1009
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
1010
+ )
1011
+ return inv_freq, attention_factor
1012
+
1013
+ @torch.no_grad()
1014
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
1015
+ def forward(self, x, position_ids):
1016
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
1017
+ position_ids_expanded = position_ids[:, None, :].float()
1018
+
1019
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
1020
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
1021
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
1022
+ emb = torch.cat((freqs, freqs), dim=-1)
1023
+ cos = emb.cos() * self.attention_scaling
1024
+ sin = emb.sin() * self.attention_scaling
1025
+
1026
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
1027
+
1028
+
1029
+ @auto_docstring
1030
+ class Qwen3Model(Qwen3PreTrainedModel):
1031
+ def __init__(self, config: Qwen3Config):
1032
+ super().__init__(config)
1033
+ self.padding_idx = config.pad_token_id
1034
+ self.vocab_size = config.vocab_size
1035
+ if not getattr(config, "summary_layer_freq", False):
1036
+ if config.use_summary_attention:
1037
+ config.summary_layer_freq = [1]*config.num_hidden_layers
1038
+ else:
1039
+ config.summary_layer_freq = [0]*config.num_hidden_layers
1040
+ import warnings  # local import: `warnings` may not be imported at module level in this file
+ warnings.warn(f'config.summary_layer_freq was not set; defaulting to {config.summary_layer_freq}')
1041
+ else:
1042
+ config.summary_layer_freq = _parse_config_pattern(config.summary_layer_freq)
1043
+
1044
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1045
+ self.layers = nn.ModuleList(
1046
+ [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1047
+ )
1048
+ self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1049
+ self.rotary_emb = Qwen3RotaryEmbedding(config=config)
1050
+ self.gradient_checkpointing = False
1051
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
1052
+
1053
+ # Cache per-layer sliding_chunk_nums for KV cache eviction
1054
+ _sv = _parse_config_pattern(getattr(config, "summary_sliding_chunk_num", 0) or 0)
1055
+ if isinstance(_sv, list):
1056
+ self._sliding_chunk_nums = [int(v) for v in _sv]
1057
+ else:
1058
+ self._sliding_chunk_nums = [int(_sv)] * config.num_hidden_layers
1059
+
1060
+ # Initialize weights and apply final processing
1061
+ self.post_init()
1062
+
1063
+ def _expand_input_with_summary_tokens(self, input_ids):
1064
+ """Expand input_ids with summary tokens for prefill phase (vectorized).
1065
+
1066
+ Returns:
1067
+ Tuple of (expanded_input_ids, position_ids, text_only_mask)
1068
+ """
1069
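+ # Illustrative layout (example only, assuming summary_chunk_size=4, summary_token_num=1,
+ # summary token id S = summary_token_begin): input [t0 .. t9] expands to
+ #   [t0 t1 t2 t3 S | t4 t5 t6 t7 S | t8 t9]
+ # with text_only_mask True at every t position and False at the S positions.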
+ summary_chunk = self.config.summary_chunk_size
1070
+ summary_num = self.config.summary_token_num
1071
+ summary_begin = self.config.summary_token_begin
1072
+
1073
+ if summary_chunk == 0 or summary_num == 0:
1074
+ return input_ids, None, None
1075
+
1076
+ batch_size, seq_len = input_ids.shape
1077
+ device = input_ids.device
1078
+ dtype = input_ids.dtype
1079
+ block = summary_chunk + summary_num
1080
+
1081
+ # Number of full chunks and remainder
1082
+ n_full_chunks = seq_len // summary_chunk
1083
+ remainder = seq_len % summary_chunk
1084
+ has_remainder = remainder > 0
1085
+
1086
+ # Total expanded length: full_chunks * block + remainder
1087
+ expanded_len = n_full_chunks * block + (remainder if has_remainder else 0)
1088
+
1089
+ # --- Build expanded_input_ids ---
1090
+ expanded_ids = torch.empty((batch_size, expanded_len), dtype=dtype, device=device)
1091
+ text_only_mask = torch.zeros((batch_size, expanded_len), dtype=torch.bool, device=device)
1092
+
1093
+ # Compute text positions: for chunk i, text goes to [i*block, i*block+summary_chunk)
1094
+ # Summary positions: [i*block+summary_chunk, (i+1)*block)
1095
+ if n_full_chunks > 0:
1096
+ chunk_indices = torch.arange(n_full_chunks, device=device)
1097
+ # Text source positions in original input_ids
1098
+ text_src_offsets = (chunk_indices * summary_chunk).unsqueeze(1) + torch.arange(summary_chunk, device=device).unsqueeze(0) # [n_full_chunks, summary_chunk]
1099
+ # Text dest positions in expanded
1100
+ text_dst_offsets = (chunk_indices * block).unsqueeze(1) + torch.arange(summary_chunk, device=device).unsqueeze(0) # [n_full_chunks, summary_chunk]
1101
+ # Summary dest positions
1102
+ summary_dst_offsets = (chunk_indices * block + summary_chunk).unsqueeze(1) + torch.arange(summary_num, device=device).unsqueeze(0) # [n_full_chunks, summary_num]
1103
+
1104
+ text_src_flat = text_src_offsets.reshape(-1)
1105
+ text_dst_flat = text_dst_offsets.reshape(-1)
1106
+ summary_dst_flat = summary_dst_offsets.reshape(-1)
1107
+
1108
+ # Copy text tokens
1109
+ expanded_ids[:, text_dst_flat] = input_ids[:, text_src_flat]
1110
+ text_only_mask[:, text_dst_flat] = True
1111
+
1112
+ # Fill summary tokens
1113
+ summary_ids_val = torch.arange(summary_num, device=device, dtype=dtype) + summary_begin
1114
+ expanded_ids[:, summary_dst_flat] = summary_ids_val.repeat(n_full_chunks).unsqueeze(0).expand(batch_size, -1)
1115
+
1116
+ # Handle remainder (last partial chunk, no summary tokens)
1117
+ if has_remainder:
1118
+ rem_start_src = n_full_chunks * summary_chunk
1119
+ rem_start_dst = n_full_chunks * block
1120
+ rem_offsets = torch.arange(remainder, device=device)
1121
+ expanded_ids[:, rem_start_dst + rem_offsets] = input_ids[:, rem_start_src + rem_offsets]
1122
+ text_only_mask[:, rem_start_dst + rem_offsets] = True
1123
+
1124
+ # --- Build position_ids ---
1125
+ position_ids = torch.empty((batch_size, expanded_len), dtype=torch.long, device=device)
1126
+
1127
+ if n_full_chunks > 0:
1128
+ # Text position IDs
1129
+ if self.config.summary_chunk_position_ids_type == 'origin':
1130
+ text_pos = text_src_flat.unsqueeze(0).expand(batch_size, -1)
1131
+ elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
1132
+ inner_pos = torch.arange(summary_chunk, device=device).repeat(n_full_chunks)
1133
+ text_pos = inner_pos.unsqueeze(0).expand(batch_size, -1)
1134
+ else:
1135
+ raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
1136
+ position_ids[:, text_dst_flat] = text_pos
1137
+
1138
+ # Summary position IDs
1139
+ if self.config.summary_token_position_ids_type == 'zeros':
1140
+ position_ids[:, summary_dst_flat] = 0
1141
+ elif self.config.summary_token_position_ids_type in ('last_chunk_slice_left', 'last_chunk_slice_right'):
1142
+ # Vectorized slice_ends computation for all chunks at once
1143
+ if self.config.summary_token_position_ids_type == 'last_chunk_slice_left':
1144
+ idx = torch.arange(0, summary_num, device=device, dtype=torch.long)
1145
+ else:
1146
+ idx = torch.arange(1, summary_num + 1, device=device, dtype=torch.long)
1147
+ # For each chunk i: prev_text_end = i * summary_chunk
1148
+ prev_ends = (chunk_indices * summary_chunk).unsqueeze(1) # [n_full_chunks, 1]
1149
+ slice_ends = prev_ends + (idx.unsqueeze(0) * summary_chunk) // summary_num - 1 # [n_full_chunks, summary_num]
1150
+ slice_ends = slice_ends.clamp(min=0)
1151
+ # Clamp per-chunk: min is prev_text_end for that chunk
1152
+ slice_ends = torch.max(slice_ends, prev_ends)
1153
+ position_ids[:, summary_dst_flat] = slice_ends.reshape(-1).unsqueeze(0).expand(batch_size, -1)
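+ # Worked example (assuming summary_chunk_size=8, summary_token_num=1): for chunk i,
+ # 'last_chunk_slice_right' yields position i*8 + 7 (the chunk's last text token), while
+ # 'last_chunk_slice_left' yields i*8 after clamping to the chunk start.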
1154
+ else:
1155
+ raise ValueError(f'Unknown summary_token_position_ids_type: {self.config.summary_token_position_ids_type}')
1156
+
1157
+ # Remainder position IDs
1158
+ if has_remainder:
1159
+ if self.config.summary_chunk_position_ids_type == 'origin':
1160
+ rem_pos = (rem_start_src + rem_offsets).unsqueeze(0).expand(batch_size, -1)
1161
+ elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
1162
+ rem_pos = rem_offsets.unsqueeze(0).expand(batch_size, -1)
1163
+ else:
1164
+ raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
1165
+ position_ids[:, rem_start_dst + rem_offsets] = rem_pos
1166
+
1167
+ return expanded_ids, position_ids, text_only_mask
1168
+
1169
+ def _build_summary_context(self, input_ids, position_ids, is_prefill, use_cache):
1170
+ """Build summary context for attention layers."""
1171
+ summary_chunk = self.config.summary_chunk_size
1172
+ summary_num = self.config.summary_token_num
1173
+ summary_begin = self.config.summary_token_begin
1174
+
1175
+ if summary_chunk > 0 and summary_num > 0:
1176
+ return build_summary_sliding_context(
1177
+ input_ids=input_ids,
1178
+ position_ids=position_ids,
1179
+ summary_token_num=summary_num,
1180
+ summary_token_begin=summary_begin,
1181
+ )
1182
+ return None
1183
+
1184
+ def _filter_summary_tokens(self, hidden_states, text_only_mask, use_summary, is_decode):
1185
+ """Filter out summary tokens from output hidden states."""
1186
+ if text_only_mask is not None:
1187
+ # Prefill: vectorized filtering using boolean mask
1188
+ batch_size, _, hidden_size = hidden_states.shape
1189
+ text_length = text_only_mask[0].sum().item()
1190
+ return hidden_states[text_only_mask.to(hidden_states.device)].reshape(batch_size, text_length, hidden_size)
1191
+ elif use_summary and is_decode and hidden_states.size(1) > 1:
1192
+ # Decode: if we have multiple tokens, only return the first (text token)
1193
+ return hidden_states[:, :1, :]
1194
+ return hidden_states
1195
+
1196
+ @check_model_inputs()
1197
+ @auto_docstring
1198
+ def forward(
1199
+ self,
1200
+ input_ids: Optional[torch.LongTensor] = None,
1201
+ attention_mask: Optional[torch.Tensor] = None,
1202
+ position_ids: Optional[torch.LongTensor] = None,
1203
+ past_key_values: Optional[Cache] = None,
1204
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1205
+ use_cache: Optional[bool] = None,
1206
+ cache_position: Optional[torch.LongTensor] = None,
1207
+ summary_ctx: Optional[SummaryBatchContext] = None,
1208
+ **kwargs: Unpack[TransformersKwargs],
1209
+ ) -> BaseModelOutputWithPast:
1210
+ if (input_ids is None) ^ (inputs_embeds is not None):
1211
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1212
+ use_summary = getattr(self.config, "use_summary_attention", False)
1213
+ is_prefill = past_key_values is None or past_key_values.get_seq_length() == 0
1214
+
1215
+ # Prefill phase with summary attention: expand input_ids with summary tokens
1216
+ text_only_mask = None
1217
+ if use_summary and input_ids is not None and inputs_embeds is None and is_prefill:
1218
+ input_ids, position_ids, text_only_mask = self._expand_input_with_summary_tokens(input_ids)
1219
+
1220
+ if inputs_embeds is None:
1221
+ inputs_embeds = self.embed_tokens(input_ids)
1222
+
1223
+ # Initialize cache
1224
+ if use_cache and past_key_values is None:
1225
+ if use_summary:
1226
+ past_key_values = Qwen3RingBufferCache(
1227
+ config=self.config, sliding_chunk_nums=self._sliding_chunk_nums)
1228
+ else:
1229
+ past_key_values = DynamicCache(config=self.config)
1230
+
1231
+ if cache_position is None:
1232
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1233
+ cache_position = torch.arange(
1234
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1235
+ )
1236
+
1237
+ if position_ids is None:
1238
+ position_ids = cache_position.unsqueeze(0)
1239
+
1240
+ # Build summary context if needed
1241
+ if use_summary and summary_ctx is None and input_ids is not None:
1242
+ summary_ctx = self._build_summary_context(input_ids, position_ids, is_prefill, use_cache)
1243
+
1244
+ causal_mask_mapping = attention_mask
1245
+ if not isinstance(causal_mask_mapping, (dict, list)):
1246
+ if summary_ctx and summary_ctx.enabled:
1247
+ seq_len = inputs_embeds.shape[1]
1248
+ # During prefill, Qwen3SummaryAttention uses summary_attn_func
1249
+ # which does not need a dense mask. Skip expensive mask construction.
1250
+ # During decode, prepare_inputs_for_generation already computed
1251
+ # per-layer keep_indices and passed them as attention_mask (list).
1252
+ # If we reach here with a non-list, it means no mask is needed.
1253
+ causal_mask_mapping = None
1254
+ else:
1255
+ # Prepare mask arguments
1256
+ mask_kwargs = {
1257
+ "config": self.config,
1258
+ "input_embeds": inputs_embeds,
1259
+ "attention_mask": attention_mask,
1260
+ "cache_position": cache_position,
1261
+ "past_key_values": past_key_values,
1262
+ "position_ids": position_ids,
1263
+ }
1264
+ # Create the masks - disable causal mask when summary context is enabled
1265
+ causal_mask_mapping = {
1266
+ "full_attention": create_causal_mask(**mask_kwargs),
1267
+ }
1268
+ # The sliding window alternating layers are not always activated depending on the config
1269
+ if self.has_sliding_layers:
1270
+ causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
1271
+
1272
+ hidden_states = inputs_embeds
1273
+
1274
+ # create position embeddings to be shared across the decoder layers
1275
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1276
+
1277
+ for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
1278
+ if causal_mask_mapping is None:
1279
+ layer_mask = None
1280
+ elif isinstance(causal_mask_mapping, list):
1281
+ layer_mask = causal_mask_mapping[layer_idx]
1282
+ else:
1283
+ layer_mask = causal_mask_mapping[decoder_layer.attention_type]
1284
+ hidden_states = decoder_layer(
1285
+ hidden_states,
1286
+ attention_mask=layer_mask,
1287
+ position_ids=position_ids,
1288
+ past_key_values=past_key_values,
1289
+ use_cache=use_cache,
1290
+ cache_position=cache_position,
1291
+ position_embeddings=position_embeddings,
1292
+ summary_ctx=summary_ctx,
1293
+ **kwargs,
1294
+ )
1295
+
1296
+ hidden_states = self.norm(hidden_states)
1297
+
1298
+ # After prefill: reorganize cache to ring buffer layout
1299
+ if use_cache and use_summary and past_key_values is not None and is_prefill:
1300
+ if hasattr(past_key_values, 'reorganize_after_prefill') and summary_ctx is not None:
1301
+ past_key_values.reorganize_after_prefill(summary_ctx.summary_mask)
1302
+
1303
+ # After chunk boundary decode: reset chunk counters
1304
+ if use_cache and use_summary and past_key_values is not None and not is_prefill:
1305
+ if hasattr(past_key_values, 'reset_chunk_counter'):
1306
+ past_key_values.reset_chunk_counter()
1307
+
1308
+ # Filter out summary tokens from output
1309
+ hidden_states = self._filter_summary_tokens(hidden_states, text_only_mask, use_summary,
1310
+ past_key_values is not None and past_key_values.get_seq_length() > 0)
1311
+
1312
+ return BaseModelOutputWithPast(
1313
+ last_hidden_state=hidden_states,
1314
+ past_key_values=past_key_values if use_cache else None,
1315
+ )
1316
+
1317
+
1318
+ @auto_docstring
1319
+ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
1320
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
1321
+ _tp_plan = {"lm_head": "colwise_rep"}
1322
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
1323
+
1324
+ def __init__(self, config):
1325
+ super().__init__(config)
1326
+ self.model = Qwen3Model(config)
1327
+ self.vocab_size = config.vocab_size
1328
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1329
+
1330
+ # Initialize weights and apply final processing
1331
+ self.post_init()
1332
+
1333
+ @can_return_tuple
1334
+ @auto_docstring
1335
+ def forward(
1336
+ self,
1337
+ input_ids: Optional[torch.LongTensor] = None,
1338
+ attention_mask: Optional[torch.Tensor] = None,
1339
+ position_ids: Optional[torch.LongTensor] = None,
1340
+ past_key_values: Optional[Cache] = None,
1341
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1342
+ labels: Optional[torch.LongTensor] = None,
1343
+ use_cache: Optional[bool] = None,
1344
+ cache_position: Optional[torch.LongTensor] = None,
1345
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1346
+ summary_ctx: Optional[SummaryBatchContext] = None,
1347
+ **kwargs: Unpack[TransformersKwargs],
1348
+ ) -> CausalLMOutputWithPast:
1349
+ r"""
1350
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1351
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1352
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1353
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1354
+
1355
+ Example:
1356
+
1357
+ ```python
1358
+ >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
1359
+
1360
+ >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
1361
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
1362
+
1363
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1364
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1365
+
1366
+ >>> # Generate
1367
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1368
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1369
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1370
+ ```"""
1371
+ outputs: BaseModelOutputWithPast = self.model(
1372
+ input_ids=input_ids,
1373
+ attention_mask=attention_mask,
1374
+ position_ids=position_ids,
1375
+ past_key_values=past_key_values,
1376
+ inputs_embeds=inputs_embeds,
1377
+ use_cache=use_cache,
1378
+ cache_position=cache_position,
1379
+ summary_ctx=summary_ctx,
1380
+ **kwargs,
1381
+ )
1382
+
1383
+ hidden_states = outputs.last_hidden_state
1384
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1385
+ if isinstance(logits_to_keep, int) and logits_to_keep == 0 and labels is None:
1386
+ # Inference: only need last token's logits to avoid OOM from [seq_len, vocab_size]
1387
+ logits = self.lm_head(hidden_states[:, -1:, :])
1388
+ else:
1389
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1390
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1391
+
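+ # Presumably this keeps only the text-vocab logits so summary tokens
+ # (ids >= summary_token_begin) can never be sampled; the default matches the text vocab size.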
1392
+ truncate_n = getattr(self.config, "truncate_predict_nums", 151936)
1393
+ if truncate_n > 0:
1394
+ logits = logits[..., :truncate_n]
1395
+
1396
+ loss = None
1397
+ if labels is not None:
1398
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=logits.shape[-1], **kwargs)
1399
+
1400
+ return CausalLMOutputWithPast(
1401
+ loss=loss,
1402
+ logits=logits,
1403
+ past_key_values=outputs.past_key_values,
1404
+ hidden_states=outputs.hidden_states,
1405
+ attentions=outputs.attentions,
1406
+ )
1407
+
1408
+ def _build_summary_attention_mask_for_generation(
1409
+ self,
1410
+ *,
1411
+ input_ids: torch.LongTensor,
1412
+ past_key_values: Optional[Cache],
1413
+ attention_mask: Optional[torch.Tensor],
1414
+ ) -> Optional[torch.Tensor]:
1415
+ """Ring buffer cache handles attention internally — no mask needed for decode."""
1416
+ if isinstance(past_key_values, Qwen3RingBufferCache):
1417
+ return None
1418
+ return attention_mask
1419
+
1420
+ def prepare_inputs_for_generation(
1421
+ self,
1422
+ input_ids: torch.LongTensor,
1423
+ past_key_values: Optional[Cache] = None,
1424
+ attention_mask: Optional[torch.LongTensor] = None,
1425
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1426
+ cache_position: Optional[torch.LongTensor] = None,
1427
+ position_ids: Optional[torch.LongTensor] = None,
1428
+ **kwargs,
1429
+ ):
1430
+ use_summary = getattr(self.config, "use_summary_attention", False)
1431
+
1432
+ # If not using summary attention, use standard behavior
1433
+ if not use_summary:
1434
+ return super().prepare_inputs_for_generation(
1435
+ input_ids=input_ids,
1436
+ past_key_values=past_key_values,
1437
+ attention_mask=attention_mask,
1438
+ inputs_embeds=inputs_embeds,
1439
+ cache_position=cache_position,
1440
+ position_ids=position_ids,
1441
+ **kwargs,
1442
+ )
1443
+
1444
+ # For summary attention: handle cache-based input slicing
1445
+ summary_chunk_size = getattr(self.config, "summary_chunk_size", 0)
1446
+ summary_token_num = getattr(self.config, "summary_token_num", 0)
1447
+ summary_token_begin = getattr(self.config, "summary_token_begin", 0)
1448
+
1449
+ # Prefill phase: pass full sequence, forward() will handle summary token insertion
1450
+ if past_key_values is None or past_key_values.get_seq_length() == 0:
1451
+ if cache_position is None:
1452
+ cache_position = torch.arange(0, input_ids.shape[1], device=input_ids.device)
1453
+
1454
+ return {
1455
+ "input_ids": input_ids,
1456
+ "attention_mask": attention_mask,
1457
+ "position_ids": position_ids,
1458
+ "past_key_values": past_key_values,
1459
+ "cache_position": cache_position,
1460
+ "use_cache": kwargs.get("use_cache"),
1461
+ }
1462
+
1463
+ # Decode phase: only pass new tokens not in cache
1464
+ # Get current chunk size (number of text tokens in current chunk)
1465
+ cur_chunk = past_key_values.get_cur_chunk_size() if hasattr(past_key_values, "get_cur_chunk_size") else 0
1466
+ true_token_num = past_key_values.get_true_token_num()
1467
+
1468
+ # Only take the new tokens that haven't been processed
1469
+ if input_ids.shape[1] > 1:
1470
+ # Slice to get only new tokens
1471
+ new_token_count = input_ids.shape[1] - true_token_num
1472
+ assert new_token_count > 0, f'new_token_count={new_token_count} should be greater than 0'
1473
+ input_ids = input_ids[:, -new_token_count:]
1474
+ device = input_ids.device
1475
+ # Check if we need to insert summary tokens
1476
+ # If this token completes the current chunk (cur_chunk == summary_chunk_size - 1), also emit summary tokens
1477
+ if cur_chunk == summary_chunk_size - 1:
1478
+ # Insert summary tokens
1479
+ batch_size = input_ids.shape[0]
1480
+ summary_ids = (
1481
+ torch.arange(summary_token_num, device=device, dtype=input_ids.dtype)
1482
+ + summary_token_begin
1483
+ ).unsqueeze(0).repeat(batch_size, 1)
1484
+
1485
+ # Concatenate: [text_token, summary_tokens]
1486
+ input_ids = torch.cat([input_ids, summary_ids], dim=1)
1487
+
1488
+ # Position IDs: the text token follows summary_chunk_position_ids_type; summary tokens follow summary_token_position_ids_type
1489
+ if self.config.summary_chunk_position_ids_type == 'origin':
1490
+ text_pos = torch.full((batch_size, 1), past_key_values.get_true_token_num(), device=device, dtype=torch.long)
1491
+ elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
1492
+ text_pos = torch.full((batch_size, 1), cur_chunk, device=device, dtype=torch.long)
1493
+ else:
1494
+ raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
1495
+
1496
+ if self.config.summary_token_position_ids_type == 'zeros':
1497
+ summary_pos = torch.zeros((batch_size, summary_token_num), device=device, dtype=torch.long)
1498
+ elif self.config.summary_token_position_ids_type == 'last_chunk_slice_left':
1499
+ # Split the chunk into summary_token_num equal slices; each summary token takes the end of its slice
1500
+ prev_text_end = past_key_values.get_true_token_num()+1-summary_chunk_size
1501
+ cur_text_end = past_key_values.get_true_token_num()+1
1502
+ chunk_len = cur_text_end - prev_text_end
1503
+
1504
+ idx = torch.arange(0, summary_token_num, device=device, dtype=torch.long,)
1505
+
1506
+ # End of each slice (global position)
1507
+ slice_ends = prev_text_end + (idx * chunk_len) // summary_token_num - 1
1508
+ slice_ends = slice_ends.clamp(min=prev_text_end)
1509
+
1510
+ summary_pos = slice_ends.to(dtype=torch.long, device=device).unsqueeze(0)
1511
+ elif self.config.summary_token_position_ids_type == 'last_chunk_slice_right':
1512
+ # Split the chunk into summary_token_num equal slices; each summary token takes the end of its slice
1513
+ prev_text_end = past_key_values.get_true_token_num()+1-summary_chunk_size
1514
+ cur_text_end = past_key_values.get_true_token_num()+1
1515
+ chunk_len = cur_text_end - prev_text_end
1516
+
1517
+ idx = torch.arange(1, summary_token_num + 1, device=device, dtype=torch.long,)
1518
+
1519
+ # End of each slice (global position)
1520
+ slice_ends = prev_text_end + (idx * chunk_len) // summary_token_num - 1
1521
+ slice_ends = slice_ends.clamp(min=prev_text_end)
1522
+
1523
+ summary_pos = slice_ends.to(dtype=torch.long, device=device).unsqueeze(0)
1524
+
1525
+ else:
1526
+ raise ValueError(f'Unknown summary_token_position_ids_type: {self.config.summary_token_position_ids_type}')
1527
+
1528
+ position_ids = torch.cat([text_pos, summary_pos], dim=1)
1529
+ else:
1530
+ # Normal decode: just the new text token; its position follows summary_chunk_position_ids_type
1531
+ if position_ids is None:
1532
+ batch_size = input_ids.shape[0]
1533
+ if self.config.summary_chunk_position_ids_type == 'origin':
1534
+ position_ids = torch.full((batch_size, input_ids.shape[1]), past_key_values.get_true_token_num(), device=input_ids.device, dtype=torch.long)
1535
+ elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
1536
+ position_ids = torch.full((batch_size, input_ids.shape[1]), cur_chunk, device=input_ids.device, dtype=torch.long)
1537
+ else:
1538
+ raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
1539
+ return {
1540
+ "input_ids": input_ids,
1541
+ "attention_mask": self._build_summary_attention_mask_for_generation(
1542
+ input_ids=input_ids,
1543
+ past_key_values=past_key_values,
1544
+ attention_mask=attention_mask,
1545
+ ),
1546
+ "position_ids": position_ids,
1547
+ "past_key_values": past_key_values,
1548
+ "cache_position": cache_position,
1549
+ "use_cache": kwargs.get("use_cache"),
1550
+ }
1551
+
1552
+
1553
+ class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel):
1554
+ pass
1555
+
1556
+
1557
+ class Qwen3ForTokenClassification(GenericForTokenClassification, Qwen3PreTrainedModel):
1558
+ pass
1559
+
1560
+
1561
+ class Qwen3ForQuestionAnswering(GenericForQuestionAnswering, Qwen3PreTrainedModel):
1562
+ base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
1563
+
1564
+
1565
+ __all__ = [
1566
+ "Qwen3ForCausalLM",
1567
+ "Qwen3ForQuestionAnswering",
1568
+ "Qwen3PreTrainedModel",
1569
+ "Qwen3Model",
1570
+ "Qwen3ForSequenceClassification",
1571
+ "Qwen3ForTokenClassification",
1572
+ "Qwen3RingBufferCache",
1573
+ "Qwen3SummaryAttention",
1574
+ "SummaryBatchContext",
1575
+ "build_summary_context",
1576
+ "build_summary_sliding_context",
1577
+ ]
modeling_qwen3sa.py ADDED
@@ -0,0 +1,1587 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_qwen3.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from typing import Any, Callable, Optional, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+ import torch.nn.functional as F
27
+ from torch.nn.attention import SDPBackend, sdpa_kernel
28
+ from flash_attn import flash_attn_func
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.cache_utils import Cache, DynamicCache
32
+ from transformers.generation import GenerationMixin
33
+ from transformers.integrations import use_kernel_forward_from_hub
34
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
35
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
36
+ from transformers.modeling_layers import (
37
+ GenericForQuestionAnswering,
38
+ GenericForSequenceClassification,
39
+ GenericForTokenClassification,
40
+ GradientCheckpointingLayer,
41
+ )
42
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
43
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
44
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
45
+ from transformers.processing_utils import Unpack
46
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
47
+ from transformers.utils.deprecation import deprecate_kwarg
48
+ from transformers.utils.generic import check_model_inputs
49
+ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
50
+
51
+ from .summary_context import SummaryBatchContext, build_summary_context, build_summary_sliding_context
52
+ from summary_attn import summary_attn_func
53
+
54
+
55
+ def _parse_config_pattern(val):
56
+ """Parse a config value that may be an int, list, or Python pattern string like '([4096]*1+[128]*3)*9'."""
57
+ if isinstance(val, list):
58
+ return val
59
+ if isinstance(val, str):
60
+ return eval(val)
61
+ return val
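A quick illustration of `_parse_config_pattern`, using the pattern string from its docstring (the function relies on `eval`, so the pattern is expected to come from a trusted config):

```python
# The pattern string is ordinary Python list arithmetic, expanded per layer.
per_layer = _parse_config_pattern("([4096]*1+[128]*3)*9")
assert len(per_layer) == 36                       # 4 entries per group x 9 groups
assert per_layer[:4] == [4096, 128, 128, 128]
```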
62
+
63
+
64
+ @use_kernel_forward_from_hub("RMSNorm")
65
+ class Qwen3RMSNorm(nn.Module):
66
+ def __init__(self, hidden_size, eps: float = 1e-6) -> None:
67
+ """
68
+ Qwen3RMSNorm is equivalent to T5LayerNorm
69
+ """
70
+ super().__init__()
71
+ self.weight = nn.Parameter(torch.ones(hidden_size))
72
+ self.variance_epsilon = eps
73
+
74
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
75
+ input_dtype = hidden_states.dtype
76
+ hidden_states = hidden_states.to(torch.float32)
77
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
78
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
79
+ return self.weight * hidden_states.to(input_dtype)
80
+
81
+ def extra_repr(self):
82
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
83
+
84
+
85
+ class Qwen3RingBufferCache:
86
+ """
87
+ Ring buffer KV cache with summary support.
88
+
89
+ Two strategies based on per-layer sliding_chunk_num:
90
+ - Large window layers (is_large_window=True): append-only buffer storing only text KV.
91
+ Summary KV is NOT stored since text tokens attend to all text KV directly.
92
+ - Small window layers (is_large_window=False): single buffer layout:
93
+ [ scratch (S) | chunk_text (C) ←fill | ring_text (ws) | summaries →fill | headroom ]
94
+ ^0 ^S ^S+C-cbl ^S+C ^S+C+ws ^S+C+ws+n_sum
95
+ scratch: temp area for summary self-KV at chunk boundaries (avoids cat).
96
+ chunk_text fills right-to-left; summaries append left-to-right.
97
+ get_attention_kv returns a single contiguous slice.
98
+ get_summary_attention_kv writes summary KV to scratch, returns [0 : S+C].
99
+
100
+ RoPE position information is baked into KV, so physical order doesn't matter.
101
+ """
102
+
103
+ is_compileable = False
104
+ _SUMMARY_INIT_CAP = 512
105
+ _APPEND_HEADROOM = 1024
106
+
107
+ def __init__(self, config: Qwen3Config, sliding_chunk_nums: list[int]):
108
+ super().__init__()
109
+ self.summary_chunk_size = getattr(config, "summary_chunk_size", 0)
110
+ self.summary_token_num = getattr(config, "summary_token_num", 0)
111
+ self.num_hidden_layers = config.num_hidden_layers
112
+
113
+ self.sliding_chunk_nums = sliding_chunk_nums
114
+ large_window_threshold = min(sliding_chunk_nums) * self.summary_chunk_size
115
+ self.is_large_window = [sv * self.summary_chunk_size > large_window_threshold for sv in sliding_chunk_nums]
116
+ self.window_sizes = [sv * self.summary_chunk_size for sv in sliding_chunk_nums]
117
+
118
+ self.key_cache = [None for _ in range(config.num_hidden_layers)]
119
+ self.value_cache = [None for _ in range(config.num_hidden_layers)]
120
+
121
+ # Large window: append-only
122
+ self._text_len = [0] * config.num_hidden_layers
123
+ self._capacity = [0] * config.num_hidden_layers
124
+
125
+ # Small window: [ scratch (S) | chunk_text (C) ←fill | ring_text (ws) | summaries →fill | headroom ]
126
+ self._window_write_ptr = [0] * config.num_hidden_layers
127
+ self._n_valid_window = [0] * config.num_hidden_layers
128
+ self._chunk_buf_len = [0] * config.num_hidden_layers
129
+ self._n_summaries = [0] * config.num_hidden_layers # number of summaries stored
130
+
131
+ # Common
132
+ self.cur_chunk_sizes = [0] * config.num_hidden_layers
133
+ self.true_tokens = [0] * config.num_hidden_layers
134
+ self._total_chunks = [0] * config.num_hidden_layers # completed chunks count
135
+ self._reorganized = False
136
+
137
+ def __len__(self):
138
+ return self.num_hidden_layers
139
+
140
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
141
+ """Returns nonzero when cache is populated (used to detect prefill vs decode)."""
142
+ if layer_idx >= self.num_hidden_layers:
143
+ return 0
144
+ if self.is_large_window[layer_idx]:
145
+ return self._text_len[layer_idx]
146
+ else:
147
+ return self._n_valid_window[layer_idx] + self._chunk_buf_len[layer_idx] + self._n_summaries[layer_idx]
148
+
149
+ def get_cur_chunk_size(self, layer_idx: Optional[int] = None) -> int:
150
+ if layer_idx is None:
151
+ layer_idx = self.num_hidden_layers - 1
152
+ return self.cur_chunk_sizes[layer_idx]
153
+
154
+ def get_true_token_num(self, layer_idx: Optional[int] = None) -> int:
155
+ if layer_idx is None:
156
+ layer_idx = self.num_hidden_layers - 1
157
+ return self.true_tokens[layer_idx]
158
+
159
+ # ── Prefill: standard append (before reorganize) ──
160
+
161
+ def update(
162
+ self,
163
+ key_states: torch.Tensor,
164
+ value_states: torch.Tensor,
165
+ layer_idx: int,
166
+ cache_kwargs: Optional[dict[str, Any]] = None,
167
+ ) -> tuple[torch.Tensor, torch.Tensor]:
168
+ """Append KV during prefill (before reorganize). Returns full KV for prefill attention."""
169
+ add_len = key_states.shape[-2]
170
+ cur_len = self._text_len[layer_idx]
171
+ new_len = cur_len + add_len
172
+
173
+ if self.key_cache[layer_idx] is None:
174
+ cap = new_len + self._APPEND_HEADROOM
175
+ bsz, heads, _, head_dim = key_states.shape
176
+ self.key_cache[layer_idx] = torch.empty(
177
+ bsz, heads, cap, head_dim, dtype=key_states.dtype, device=key_states.device)
178
+ self.value_cache[layer_idx] = torch.empty(
179
+ bsz, heads, cap, head_dim, dtype=value_states.dtype, device=value_states.device)
180
+ self._capacity[layer_idx] = cap
181
+ elif new_len > self._capacity[layer_idx]:
182
+ cap = max(new_len + self._APPEND_HEADROOM, self._capacity[layer_idx] * 2)
183
+ old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx]
184
+ bsz, heads, _, head_dim = old_k.shape
185
+ new_k = torch.empty(bsz, heads, cap, head_dim, dtype=old_k.dtype, device=old_k.device)
186
+ new_v = torch.empty(bsz, heads, cap, head_dim, dtype=old_v.dtype, device=old_v.device)
187
+ new_k[:, :, :cur_len, :].copy_(old_k[:, :, :cur_len, :])
188
+ new_v[:, :, :cur_len, :].copy_(old_v[:, :, :cur_len, :])
189
+ self.key_cache[layer_idx] = new_k
190
+ self.value_cache[layer_idx] = new_v
191
+ self._capacity[layer_idx] = cap
192
+
193
+ self.key_cache[layer_idx][:, :, cur_len:new_len, :].copy_(key_states)
194
+ self.value_cache[layer_idx][:, :, cur_len:new_len, :].copy_(value_states)
195
+ self._text_len[layer_idx] = new_len
196
+
197
+ if self.summary_chunk_size > 0:
198
+ if cache_kwargs and 'summary_mask' in cache_kwargs:
199
+ text_count = add_len - cache_kwargs['summary_mask'][0].sum().item()
200
+ else:
201
+ text_count = add_len
202
+ self.cur_chunk_sizes[layer_idx] += add_len
203
+ self.true_tokens[layer_idx] += text_count
204
+
205
+ return self.key_cache[layer_idx][:, :, :new_len, :], self.value_cache[layer_idx][:, :, :new_len, :]
206
+
207
+ # ── Reorganize after prefill ──
208
+
209
+ def reorganize_after_prefill(self, summary_mask: torch.Tensor):
210
+ """Reorganize all layers from prefill block layout to ring buffer layout.
211
+
212
+ Args:
213
+ summary_mask: bool tensor [bsz, prefill_seq_len] where True = summary position.
214
+ """
215
+ if self._reorganized:
216
+ return
217
+ self._reorganized = True
218
+
219
+ text_mask = ~summary_mask[0]
220
+
221
+ for layer_idx in range(self.num_hidden_layers):
222
+ prefill_len = self._text_len[layer_idx]
223
+ prefill_k = self.key_cache[layer_idx][:, :, :prefill_len, :]
224
+ prefill_v = self.value_cache[layer_idx][:, :, :prefill_len, :]
225
+ bsz, heads, _, head_dim = prefill_k.shape
226
+ device, dtype = prefill_k.device, prefill_k.dtype
227
+
228
+ text_k = prefill_k[:, :, text_mask, :]
229
+ text_v = prefill_v[:, :, text_mask, :]
230
+ n_text = text_k.shape[2]
231
+
232
+ if self.is_large_window[layer_idx]:
233
+ # Large window: keep only text KV
234
+ cap = n_text + self._APPEND_HEADROOM
235
+ new_k = torch.empty(bsz, heads, cap, head_dim, dtype=dtype, device=device)
236
+ new_v = torch.empty(bsz, heads, cap, head_dim, dtype=dtype, device=device)
237
+ new_k[:, :, :n_text, :].copy_(text_k)
238
+ new_v[:, :, :n_text, :].copy_(text_v)
239
+ self.key_cache[layer_idx] = new_k
240
+ self.value_cache[layer_idx] = new_v
241
+ self._text_len[layer_idx] = n_text
242
+ self._capacity[layer_idx] = cap
243
+ else:
244
+ # Small window: [ scratch (S) | chunk_text (C) ←fill | ring_text (ws) | summaries →fill | headroom ]
245
+ summary_k = prefill_k[:, :, summary_mask[0], :]
246
+ summary_v = prefill_v[:, :, summary_mask[0], :]
247
+ n_summary = summary_k.shape[2]
248
+
249
+ C = self.summary_chunk_size
250
+ S = self.summary_token_num
251
+ ws = self.window_sizes[layer_idx]
252
+ scn = self.sliding_chunk_nums[layer_idx]
253
+
254
+ # Split text into complete chunks + partial remainder
255
+ n_complete_chunks = n_text // C
256
+ n_partial = n_text % C
257
+ n_complete_text = n_complete_chunks * C
258
+
259
+ # Window: last scn complete chunks (or all if fewer)
260
+ n_window_chunks = min(scn, n_complete_chunks)
261
+ n_window_text = n_window_chunks * C
262
+ window_start = n_complete_text - n_window_text
263
+
264
+ # Layout: [ scratch (S) | chunk_text (C) | ring_text (ws) | summaries | headroom ]
265
+ summary_headroom = max(self._SUMMARY_INIT_CAP, n_summary + 256)
266
+ total_cap = S + C + ws + n_summary + summary_headroom
267
+
268
+ new_k = torch.empty(bsz, heads, total_cap, head_dim, dtype=dtype, device=device)
269
+ new_v = torch.empty(bsz, heads, total_cap, head_dim, dtype=dtype, device=device)
270
+
271
+ # Copy partial chunk text to [S+C-n_partial : S+C] (left-filled)
272
+ if n_partial > 0:
273
+ new_k[:, :, S + C - n_partial:S + C, :].copy_(
274
+ text_k[:, :, n_complete_text:, :])
275
+ new_v[:, :, S + C - n_partial:S + C, :].copy_(
276
+ text_v[:, :, n_complete_text:, :])
277
+
278
+ # Copy window text to [S+C : S+C+n_window_text]
279
+ if n_window_text > 0:
280
+ new_k[:, :, S + C:S + C + n_window_text, :].copy_(
281
+ text_k[:, :, window_start:n_complete_text, :])
282
+ new_v[:, :, S + C:S + C + n_window_text, :].copy_(
283
+ text_v[:, :, window_start:n_complete_text, :])
284
+ self._n_valid_window[layer_idx] = n_window_text
285
+ self._window_write_ptr[layer_idx] = n_window_text % ws
286
+
287
+ # Copy summaries to [S+C+ws : S+C+ws+n_summary]
288
+ if n_summary > 0:
289
+ new_k[:, :, S + C + ws:S + C + ws + n_summary, :].copy_(summary_k)
290
+ new_v[:, :, S + C + ws:S + C + ws + n_summary, :].copy_(summary_v)
291
+
292
+ self.key_cache[layer_idx] = new_k
293
+ self.value_cache[layer_idx] = new_v
294
+ self._n_summaries[layer_idx] = n_summary
295
+ self._capacity[layer_idx] = total_cap
296
+ self._text_len[layer_idx] = 0
297
+ self._chunk_buf_len[layer_idx] = n_partial
298
+
299
+ block = self.summary_chunk_size + self.summary_token_num
300
+ for layer_idx in range(self.num_hidden_layers):
301
+ self.cur_chunk_sizes[layer_idx] = self.cur_chunk_sizes[layer_idx] % block
302
+ self._total_chunks[layer_idx] = self._n_summaries[layer_idx] if not self.is_large_window[layer_idx] else (self.true_tokens[layer_idx] // self.summary_chunk_size)
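A small worked example of the chunk/remainder split performed by `reorganize_after_prefill`; the counts below are assumptions picked only to make the arithmetic concrete:

```python
# Assumed: chunk size C=8, window of scn=3 chunks, 29 text tokens after
# removing summary positions from the prefill sequence.
C, scn, n_text = 8, 3, 29

n_complete_chunks = n_text // C                  # 3 complete chunks
n_partial = n_text % C                           # 5 tokens -> chunk_text region
n_complete_text = n_complete_chunks * C          # 24

n_window_chunks = min(scn, n_complete_chunks)    # 3 chunks kept in the ring
n_window_text = n_window_chunks * C              # 24 tokens -> ring_text region
window_start = n_complete_text - n_window_text   # 0 (ring starts at the first complete chunk)
```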
303
+
304
+ # ── Decode: text token update ──
305
+
306
+ def update_text(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
307
+ """Write a single text token KV during decode."""
308
+ if self.is_large_window[layer_idx]:
309
+ cur = self._text_len[layer_idx]
310
+ new_len = cur + 1
311
+ if new_len > self._capacity[layer_idx]:
312
+ cap = max(new_len + self._APPEND_HEADROOM, self._capacity[layer_idx] * 2)
313
+ old_k, old_v = self.key_cache[layer_idx], self.value_cache[layer_idx]
314
+ bsz, heads, _, head_dim = old_k.shape
315
+ new_k = torch.empty(bsz, heads, cap, head_dim, dtype=old_k.dtype, device=old_k.device)
316
+ new_v = torch.empty(bsz, heads, cap, head_dim, dtype=old_v.dtype, device=old_v.device)
317
+ new_k[:, :, :cur, :].copy_(old_k[:, :, :cur, :])
318
+ new_v[:, :, :cur, :].copy_(old_v[:, :, :cur, :])
319
+ self.key_cache[layer_idx] = new_k
320
+ self.value_cache[layer_idx] = new_v
321
+ self._capacity[layer_idx] = cap
322
+ self.key_cache[layer_idx][:, :, cur:new_len, :].copy_(key_states)
323
+ self.value_cache[layer_idx][:, :, cur:new_len, :].copy_(value_states)
324
+ self._text_len[layer_idx] = new_len
325
+ else:
326
+ # Write to chunk_text region, left-filled: position S+C-1-cbl
327
+ C = self.summary_chunk_size
328
+ S = self.summary_token_num
329
+ cbl = self._chunk_buf_len[layer_idx]
330
+ dst = S + C - 1 - cbl
331
+ self.key_cache[layer_idx][:, :, dst:dst+1, :].copy_(key_states)
332
+ self.value_cache[layer_idx][:, :, dst:dst+1, :].copy_(value_states)
333
+ self._chunk_buf_len[layer_idx] = cbl + 1
334
+
335
+ self.cur_chunk_sizes[layer_idx] += 1
336
+ self.true_tokens[layer_idx] += 1
337
+
338
+ # ── Decode: summary token update ──
339
+
340
+ def update_summary(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
341
+ """Write summary token KV during decode (chunk boundary).
342
+
343
+ Large window: skip. Small window: flush chunk to ring, append summary rightward.
344
+ """
345
+ n_summary = key_states.shape[2]
346
+
347
+ if self.is_large_window[layer_idx]:
348
+ self.cur_chunk_sizes[layer_idx] += n_summary
349
+ self._total_chunks[layer_idx] += n_summary
350
+ return
351
+
352
+ # ── Small window: boundary processing ──
353
+ C = self.summary_chunk_size
354
+ S = self.summary_token_num
355
+ ws = self.window_sizes[layer_idx]
356
+ cbl = self._chunk_buf_len[layer_idx]
357
+ ptr = self._window_write_ptr[layer_idx]
358
+
359
+ # Step A: Flush chunk_text to ring_text
360
+ # chunk_text lives at [S+C-cbl : S+C], left-filled
361
+ chunk_src = S + C - cbl
362
+ if cbl > 0:
363
+ ring_dst = S + C + ptr
364
+ if ptr + cbl <= ws:
365
+ self.key_cache[layer_idx][:, :, ring_dst:ring_dst + cbl, :].copy_(
366
+ self.key_cache[layer_idx][:, :, chunk_src:chunk_src + cbl, :])
367
+ self.value_cache[layer_idx][:, :, ring_dst:ring_dst + cbl, :].copy_(
368
+ self.value_cache[layer_idx][:, :, chunk_src:chunk_src + cbl, :])
369
+ else:
370
+ first = ws - ptr
371
+ self.key_cache[layer_idx][:, :, ring_dst:ring_dst + first, :].copy_(
372
+ self.key_cache[layer_idx][:, :, chunk_src:chunk_src + first, :])
373
+ self.value_cache[layer_idx][:, :, ring_dst:ring_dst + first, :].copy_(
374
+ self.value_cache[layer_idx][:, :, chunk_src:chunk_src + first, :])
375
+ rest = cbl - first
376
+ self.key_cache[layer_idx][:, :, S + C:S + C + rest, :].copy_(
377
+ self.key_cache[layer_idx][:, :, chunk_src + first:chunk_src + cbl, :])
378
+ self.value_cache[layer_idx][:, :, S + C:S + C + rest, :].copy_(
379
+ self.value_cache[layer_idx][:, :, chunk_src + first:chunk_src + cbl, :])
380
+
381
+ self._window_write_ptr[layer_idx] = (ptr + cbl) % ws
382
+ if self._n_valid_window[layer_idx] < ws:
383
+ self._n_valid_window[layer_idx] = min(ws, self._n_valid_window[layer_idx] + cbl)
384
+ self._chunk_buf_len[layer_idx] = 0
385
+
386
+ # Step B: Append summary rightward
387
+ n_sum = self._n_summaries[layer_idx]
388
+ sum_dst = S + C + ws + n_sum
389
+ if sum_dst + n_summary > self._capacity[layer_idx]:
390
+ self._grow_buffer_right(layer_idx)
391
+ sum_dst = S + C + ws + self._n_summaries[layer_idx]
392
+
393
+ self.key_cache[layer_idx][:, :, sum_dst:sum_dst + n_summary, :].copy_(key_states)
394
+ self.value_cache[layer_idx][:, :, sum_dst:sum_dst + n_summary, :].copy_(value_states)
395
+ self._n_summaries[layer_idx] += n_summary
396
+
397
+ self.cur_chunk_sizes[layer_idx] += n_summary
398
+ self._total_chunks[layer_idx] += n_summary
399
+
400
+ # ── Decode: get KV for attention ──
401
+
402
+ def get_attention_kv(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
403
+ """Get full KV for text token attention.
404
+
405
+ Large window: buffer[:text_len]
406
+ Small window: always a single contiguous slice.
407
+ - ring full: [S+C-cbl : S+C+ws+n_old_sum]
408
+ - ring not full: [S+C-cbl : S+C+nv]
409
+ """
410
+ if self.is_large_window[layer_idx]:
411
+ tl = self._text_len[layer_idx]
412
+ return (self.key_cache[layer_idx][:, :, :tl, :],
413
+ self.value_cache[layer_idx][:, :, :tl, :])
414
+
415
+ C = self.summary_chunk_size
416
+ S = self.summary_token_num
417
+ ws = self.window_sizes[layer_idx]
418
+ nv = self._n_valid_window[layer_idx]
419
+ cbl = self._chunk_buf_len[layer_idx]
420
+
421
+ start = S + C - cbl
422
+
423
+ if nv >= ws:
424
+ # Ring full: include old summaries (skip in-window ones)
425
+ scn = self.sliding_chunk_nums[layer_idx]
426
+ n_summaries = self._n_summaries[layer_idx]
427
+ skip = min(scn * S, n_summaries)
428
+ end = S + C + ws + (n_summaries - skip)
429
+ else:
430
+ # Ring not full: all summaries are in-window, skip them all
431
+ end = S + C + nv
432
+
433
+ return (self.key_cache[layer_idx][:, :, start:end, :],
434
+ self.value_cache[layer_idx][:, :, start:end, :])
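A numeric sketch of the slice returned for a small-window layer once the ring is full (all values are assumptions for illustration):

```python
# Assumed: S=1 summary token, C=8 chunk size, ws=24 window, ring full (nv=24),
# scn=3 in-window chunks, 10 stored summaries, 5 tokens in the current chunk.
S, C, ws, scn = 1, 8, 24, 3
nv, n_summaries, cbl = 24, 10, 5

start = S + C - cbl                        # 4: current chunk text (left-filled)
skip = min(scn * S, n_summaries)           # 3 summaries still covered by the window
end = S + C + ws + (n_summaries - skip)    # 40: chunk text + ring text + older summaries
```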
435
+
436
+ def get_current_chunk_kv(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
437
+ """Get KV of the current chunk's C text tokens for summary token attention."""
438
+ C = self.summary_chunk_size
439
+ if self.is_large_window[layer_idx]:
440
+ tl = self._text_len[layer_idx]
441
+ return (self.key_cache[layer_idx][:, :, tl - C:tl, :],
442
+ self.value_cache[layer_idx][:, :, tl - C:tl, :])
443
+ else:
444
+ S = self.summary_token_num
445
+ cbl = self._chunk_buf_len[layer_idx]
446
+ return (self.key_cache[layer_idx][:, :, S + C - cbl:S + C, :],
447
+ self.value_cache[layer_idx][:, :, S + C - cbl:S + C, :])
448
+
449
+ def get_summary_attention_kv(
450
+ self,
451
+ layer_idx: int,
452
+ k_summary: torch.Tensor,
453
+ v_summary: torch.Tensor,
454
+ ) -> tuple[torch.Tensor, torch.Tensor]:
455
+ """Write summary KV to scratch area [0:S], return contiguous [0 : S+C] for summary attention.
456
+
457
+ This avoids a torch.cat by using the pre-reserved scratch region.
458
+ For large window layers, falls back to cat (no scratch area).
459
+ """
460
+ C = self.summary_chunk_size
461
+ S = self.summary_token_num
462
+ if self.is_large_window[layer_idx]:
463
+ tl = self._text_len[layer_idx]
464
+ k_chunk = self.key_cache[layer_idx][:, :, tl - C:tl, :]
465
+ v_chunk = self.value_cache[layer_idx][:, :, tl - C:tl, :]
466
+ return (torch.cat([k_chunk, k_summary], dim=2),
467
+ torch.cat([v_chunk, v_summary], dim=2))
468
+ else:
469
+ # Write summary KV into scratch area [0:S]
470
+ self.key_cache[layer_idx][:, :, 0:S, :].copy_(k_summary)
471
+ self.value_cache[layer_idx][:, :, 0:S, :].copy_(v_summary)
472
+ # Return contiguous [scratch | chunk_text] = [0 : S+C]
473
+ return (self.key_cache[layer_idx][:, :, 0:S + C, :],
474
+ self.value_cache[layer_idx][:, :, 0:S + C, :])
475
+
476
+ def _grow_buffer_right(self, layer_idx: int):
477
+ """Grow buffer rightward when summary headroom is exhausted (doubling strategy).
478
+
479
+ Only the tail (headroom) is extended; chunk_text and ring_text positions are unchanged.
480
+ """
481
+ old_k = self.key_cache[layer_idx]
482
+ old_v = self.value_cache[layer_idx]
483
+ bsz, heads, old_cap, head_dim = old_k.shape
484
+
485
+ extra = max(self._SUMMARY_INIT_CAP, self._n_summaries[layer_idx])
486
+ new_cap = max(old_cap + extra, old_cap * 2)
487
+
488
+ new_k = torch.empty(bsz, heads, new_cap, head_dim, dtype=old_k.dtype, device=old_k.device)
489
+ new_v = torch.empty(bsz, heads, new_cap, head_dim, dtype=old_v.dtype, device=old_v.device)
490
+
491
+ # Copy all existing data in place — positions unchanged
492
+ new_k[:, :, :old_cap, :].copy_(old_k)
493
+ new_v[:, :, :old_cap, :].copy_(old_v)
494
+
495
+ self.key_cache[layer_idx] = new_k
496
+ self.value_cache[layer_idx] = new_v
497
+ self._capacity[layer_idx] = new_cap
498
+
499
+ def reset_chunk_counter(self):
500
+ """Reset chunk counters after a chunk boundary step completes."""
501
+ block = self.summary_chunk_size + self.summary_token_num
502
+ for layer_idx in range(self.num_hidden_layers):
503
+ if self.cur_chunk_sizes[layer_idx] >= block:
504
+ self.cur_chunk_sizes[layer_idx] %= block
505
+
506
+
507
+ class Qwen3MLP(nn.Module):
508
+ def __init__(self, config):
509
+ super().__init__()
510
+ self.config = config
511
+ self.hidden_size = config.hidden_size
512
+ self.intermediate_size = config.intermediate_size
513
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
514
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
515
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
516
+ self.act_fn = ACT2FN[config.hidden_act]
517
+
518
+ def forward(self, x):
519
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
520
+ return down_proj
521
+
522
+
523
+ def rotate_half(x):
524
+ """Rotates half the hidden dims of the input."""
525
+ x1 = x[..., : x.shape[-1] // 2]
526
+ x2 = x[..., x.shape[-1] // 2 :]
527
+ return torch.cat((-x2, x1), dim=-1)
528
+
529
+
530
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
531
+ """Applies Rotary Position Embedding to the query and key tensors.
532
+
533
+ Args:
534
+ q (`torch.Tensor`): The query tensor.
535
+ k (`torch.Tensor`): The key tensor.
536
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
537
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
538
+ position_ids (`torch.Tensor`, *optional*):
539
+ Deprecated and unused.
540
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
541
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
542
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
543
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
544
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
545
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
546
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
547
+ Returns:
548
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
549
+ """
550
+ cos = cos.unsqueeze(unsqueeze_dim)
551
+ sin = sin.unsqueeze(unsqueeze_dim)
552
+ q_embed = (q * cos) + (rotate_half(q) * sin)
553
+ k_embed = (k * cos) + (rotate_half(k) * sin)
554
+ return q_embed, k_embed
555
+
556
+
557
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
558
+ """
559
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
560
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
561
+ """
562
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
563
+ if n_rep == 1:
564
+ return hidden_states
565
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
566
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
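A minimal shape check for `repeat_kv` (tensor sizes are arbitrary examples):

```python
import torch

kv = torch.zeros(1, 4, 16, 128)          # (batch, num_key_value_heads, seq, head_dim)
out = repeat_kv(kv, n_rep=4)
assert out.shape == (1, 16, 16, 128)     # key/value heads repeated to match query heads
```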
567
+
568
+
569
+ def eager_attention_forward(
570
+ module: nn.Module,
571
+ query: torch.Tensor,
572
+ key: torch.Tensor,
573
+ value: torch.Tensor,
574
+ attention_mask: Optional[torch.Tensor],
575
+ scaling: float,
576
+ dropout: float = 0.0,
577
+ **kwargs: Unpack[TransformersKwargs],
578
+ ):
579
+ key_states = repeat_kv(key, module.num_key_value_groups)
580
+ value_states = repeat_kv(value, module.num_key_value_groups)
581
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
582
+ if attention_mask is not None:
583
+ attn_weights = attn_weights + attention_mask
584
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
585
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
586
+ attn_output = torch.matmul(attn_weights, value_states)
587
+ attn_output = attn_output.transpose(1, 2).contiguous()
588
+ return attn_output, attn_weights
589
+
590
+
591
+ def _sdpa_attention_forward(
592
+ module: nn.Module,
593
+ query: torch.Tensor,
594
+ key: torch.Tensor,
595
+ value: torch.Tensor,
596
+ attention_mask: Optional[torch.Tensor],
597
+ scaling: float,
598
+ dropout: float = 0.0,
599
+ **kwargs: Unpack[TransformersKwargs],
600
+ ):
601
+ key_states = repeat_kv(key, module.num_key_value_groups)
602
+ value_states = repeat_kv(value, module.num_key_value_groups)
603
+ attn_output = F.scaled_dot_product_attention(
604
+ query,
605
+ key_states,
606
+ value_states,
607
+ attn_mask=None,
608
+ dropout_p=dropout,
609
+ is_causal=False,
610
+ )
611
+ attn_output = attn_output.transpose(1, 2).contiguous()
612
+ return attn_output, None
613
+
614
+
615
+
616
+ class Qwen3Attention(nn.Module):
617
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
618
+
619
+ def __init__(self, config: Qwen3Config, layer_idx: int):
620
+ super().__init__()
621
+ self.config = config
622
+ self.layer_idx = layer_idx
623
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
624
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
625
+ self.scaling = self.head_dim**-0.5
626
+ self.attention_dropout = config.attention_dropout
627
+
628
+ self.q_proj = nn.Linear(
629
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
630
+ )
631
+ self.k_proj = nn.Linear(
632
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
633
+ )
634
+ self.v_proj = nn.Linear(
635
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
636
+ )
637
+ self.o_proj = nn.Linear(
638
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
639
+ )
640
+ self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
641
+ self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
642
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
643
+ if getattr(config, "_attn_implementation", None) == "eager":
644
+ self._decode_attn_fn = eager_attention_forward
645
+ else:
646
+ self._decode_attn_fn = _sdpa_attention_forward
647
+
648
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
649
+ def forward(
650
+ self,
651
+ hidden_states: torch.Tensor,
652
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
653
+ attention_mask: Optional[torch.Tensor],
654
+ past_key_values: Optional[Cache] = None,
655
+ cache_position: Optional[torch.LongTensor] = None,
656
+ **kwargs: Unpack[FlashAttentionKwargs],
657
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
658
+ input_shape = hidden_states.shape[:-1]
659
+ hidden_shape = (*input_shape, -1, self.head_dim)
660
+
661
+
662
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
663
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
664
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
665
+
666
+ cos, sin = position_embeddings
667
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
668
+
669
+ if past_key_values is not None:
670
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
671
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
672
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
673
+
674
+ attn_output, attn_weights = self._decode_attn_fn(
675
+ self,
676
+ query_states,
677
+ key_states,
678
+ value_states,
679
+ attention_mask,
680
+ dropout=0.0 if not self.training else self.attention_dropout,
681
+ scaling=self.scaling,
682
+ sliding_window=self.sliding_window, # diff with Llama
683
+ **kwargs,
684
+ )
685
+
686
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
687
+ attn_output = self.o_proj(attn_output)
688
+ return attn_output, attn_weights
689
+
690
+
691
+ class Qwen3SummaryAttention(Qwen3Attention):
692
+ """
693
+ Summary-aware variant of Qwen3Attention: uses a sliding summary mask.
694
+ """
695
+
696
+ def __init__(self, config: Qwen3Config, layer_idx: int):
697
+ super().__init__(config, layer_idx)
698
+ self.summary_chunk_size = getattr(self.config, "summary_chunk_size", 0)
699
+ self.summary_token_num = getattr(self.config, "summary_token_num", 0)
700
+
701
+ # Cache sliding_chunk_num to avoid eval() on every forward call
702
+ val = getattr(config, "summary_sliding_chunk_num", 0) or 0
703
+ val = _parse_config_pattern(val)
704
+ if isinstance(val, list):
705
+ self._sliding_chunk_num = val[layer_idx]
706
+ else:
707
+ self._sliding_chunk_num = int(val)
708
+
709
+ if config.summary_independent_parameters:
710
+ self.q_proj_summary = nn.Linear(
711
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
712
+ )
713
+ self.k_proj_summary = nn.Linear(
714
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
715
+ )
716
+ self.v_proj_summary = nn.Linear(
717
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
718
+ )
719
+
720
+ def _get_sliding_chunk_num(self):
721
+ return self._sliding_chunk_num
722
+
723
+ def get_query_key_value_tensors(self, hidden_states):
724
+ input_shape = hidden_states.shape[:-1]
725
+ hidden_shape = (*input_shape, -1, self.head_dim)
726
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
727
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
728
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
729
+
730
+ return query_states, key_states, value_states
731
+
732
+ def get_query_key_value_tensors_summary(self, hidden_states):
733
+ input_shape = hidden_states.shape[:-1]
734
+ hidden_shape = (*input_shape, -1, self.head_dim)
735
+ query_states = self.q_norm(self.q_proj_summary(hidden_states).view(hidden_shape)).transpose(1, 2)
736
+ key_states = self.k_norm(self.k_proj_summary(hidden_states).view(hidden_shape)).transpose(1, 2)
737
+ value_states = self.v_proj_summary(hidden_states).view(hidden_shape).transpose(1, 2)
738
+
739
+ return query_states, key_states, value_states
740
+
741
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
742
+ def forward(
743
+ self,
744
+ hidden_states: torch.Tensor,
745
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
746
+ attention_mask: Optional[torch.Tensor] = None,
747
+ past_key_values: Optional[Cache] = None,
748
+ cache_position: Optional[torch.LongTensor] = None,
749
+ summary_ctx: Optional[SummaryBatchContext] = None,
750
+ **kwargs,
751
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
752
+ input_shape = hidden_states.shape[:-1]
753
+ if hidden_states.size(0) != 1:
754
+ raise ValueError("Summary sliding attention only supports batch size=1.")
755
+
756
+ # Compute q/k/v for the full sequence once.
757
+ if self.config.summary_independent_parameters:
758
+ if summary_ctx is None:
759
+ raise ValueError("summary_ctx is required when using summary_independent_parameters.")
760
+ summary_mask = summary_ctx.summary_mask
761
+ summary_pos = summary_mask[0]
762
+ assert (summary_mask == summary_mask[0:1]).all()
763
+
764
+ if self.config.mix_coeff == 0:
765
+ # When mix_coeff=0, summary projections have no effect — skip clone + extra linear
766
+ query_states, key_states, value_states = self.get_query_key_value_tensors(hidden_states)
767
+ else:
768
+ query, key, value = self.get_query_key_value_tensors(hidden_states)
769
+
770
+ query_states = query.clone()
771
+ key_states = key.clone()
772
+ value_states = value.clone()
773
+
774
+ hs_summary = hidden_states[:, summary_pos, :]
775
+ if hs_summary.size(1) > 0:
776
+ base_q_summary = query[:, :, summary_pos, :]
777
+ base_k_summary = key[:, :, summary_pos, :]
778
+ base_v_summary = value[:, :, summary_pos, :]
779
+
780
+ q_s, k_s, v_s = self.get_query_key_value_tensors_summary(hs_summary)
781
+
782
+ q_s = self.config.mix_coeff * q_s + (1.0 - self.config.mix_coeff) * base_q_summary
783
+ k_s = self.config.mix_coeff * k_s + (1.0 - self.config.mix_coeff) * base_k_summary
784
+ v_s = self.config.mix_coeff * v_s + (1.0 - self.config.mix_coeff) * base_v_summary
785
+
786
+ query_states[:, :, summary_pos, :] = q_s
787
+ key_states[:, :, summary_pos, :] = k_s
788
+ value_states[:, :, summary_pos, :] = v_s
789
+ else:
790
+ query_states, key_states, value_states = self.get_query_key_value_tensors(hidden_states)
791
+
792
+ cos, sin = position_embeddings
793
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
794
+
795
+ query_len = query_states.shape[2]
796
+ is_prefill = past_key_values is None or not past_key_values._reorganized
797
+
798
+ if is_prefill:
799
+ # Prefill: use standard append and summary_attn_func
800
+ if past_key_values is not None:
801
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
802
+ if summary_ctx is not None:
803
+ cache_kwargs["summary_mask"] = summary_ctx.summary_mask
804
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
805
+
806
+ with torch.cuda.device(query_states.device):
807
+ attn_output, attn_weights = summary_attn_func(
808
+ query_states.transpose(1,2).contiguous(),
809
+ key_states.transpose(1,2).contiguous(),
810
+ value_states.transpose(1,2).contiguous(),
811
+ self.summary_chunk_size,
812
+ self.summary_token_num,
813
+ self._get_sliding_chunk_num(),
814
+ summary_pos=summary_ctx.summary_mask.squeeze()
815
+ )
816
+ elif query_len == 1:
817
+ # Single text token decode: write to cache, attend to full buffer
818
+ past_key_values.update_text(key_states, value_states, self.layer_idx)
819
+ k_full, v_full = past_key_values.get_attention_kv(self.layer_idx)
820
+ attn_output, attn_weights = self._decode_attn_fn(
821
+ self,
822
+ query_states,
823
+ k_full,
824
+ v_full,
825
+ None,
826
+ dropout=0.0 if not self.training else self.attention_dropout,
827
+ scaling=self.scaling,
828
+ sliding_window=self.sliding_window,
829
+ **kwargs,
830
+ )
831
+ else:
832
+ # Chunk boundary: query = [text_token, summary_token(s)]
833
+ # Split into text (first token) and summary (remaining tokens)
834
+ q_text = query_states[:, :, :1, :]
835
+ q_summary = query_states[:, :, 1:, :]
836
+ k_text = key_states[:, :, :1, :]
837
+ v_text = value_states[:, :, :1, :]
838
+ k_summary = key_states[:, :, 1:, :]
839
+ v_summary = value_states[:, :, 1:, :]
840
+
841
+ # 1. Write text token to cache, get full KV, run text attention
842
+ past_key_values.update_text(k_text, v_text, self.layer_idx)
843
+ k_full, v_full = past_key_values.get_attention_kv(self.layer_idx)
844
+ text_out, _ = self._decode_attn_fn(
845
+ self,
846
+ q_text,
847
+ k_full,
848
+ v_full,
849
+ None,
850
+ dropout=0.0 if not self.training else self.attention_dropout,
851
+ scaling=self.scaling,
852
+ sliding_window=self.sliding_window,
853
+ **kwargs,
854
+ )
855
+
856
+ # 2. Summary attention: attend to current chunk's C text tokens + own KV (self-attention)
857
+ # Uses scratch area [0:S] for summary self-KV, contiguous with chunk [S:S+C].
858
+ k_chunk_with_self, v_chunk_with_self = past_key_values.get_summary_attention_kv(
859
+ self.layer_idx, k_summary, v_summary)
860
+ summary_out, _ = self._decode_attn_fn(
861
+ self,
862
+ q_summary,
863
+ k_chunk_with_self,
864
+ v_chunk_with_self,
865
+ None,
866
+ dropout=0.0 if not self.training else self.attention_dropout,
867
+ scaling=self.scaling,
868
+ sliding_window=self.sliding_window,
869
+ **kwargs,
870
+ )
871
+
872
+ # 3. Write summary KV to cache
873
+ past_key_values.update_summary(k_summary, v_summary, self.layer_idx)
874
+
875
+ attn_output = torch.cat([text_out, summary_out], dim=2)
876
+ attn_weights = None
877
+
878
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
879
+ attn_output = self.o_proj(attn_output)
880
+ return attn_output, attn_weights
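The projection blend applied to summary tokens above reduces to a simple interpolation; a sketch with a made-up `mix_coeff`:

```python
import torch

# mix_coeff = 1.0 -> only the dedicated *_proj_summary projections are used;
# mix_coeff = 0.0 -> summary tokens fall back to the shared text projections.
mix_coeff = 0.5
base = torch.randn(1, 16, 4, 128)        # q/k/v of summary positions, shared projections
dedicated = torch.randn(1, 16, 4, 128)   # q/k/v of summary positions, summary projections
blended = mix_coeff * dedicated + (1.0 - mix_coeff) * base
```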
881
+
882
+
883
+ class Qwen3DecoderLayer(GradientCheckpointingLayer):
884
+ def __init__(self, config: Qwen3Config, layer_idx: int):
885
+ super().__init__()
886
+ self.config = config
887
+ self.hidden_size = config.hidden_size
888
+
889
+ # Use SummaryAttention if enabled in config
890
+ if getattr(config, "use_summary_attention", False) is True and config.summary_layer_freq[layer_idx] == 1:
891
+ self.self_attn = Qwen3SummaryAttention(config=config, layer_idx=layer_idx)
892
+ elif getattr(config, "use_summary_attention", False) is False and config.summary_layer_freq[layer_idx] == 0:
893
+ self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
894
+ else:
895
+ raise ValueError(f'Check config.summary_layer_freq {config.summary_layer_freq} and config.use_summary_attention {config.use_summary_attention}')
896
+
897
+ self.mlp = Qwen3MLP(config)
898
+ self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
899
+ if getattr(config, "summary_independent_attention_layernorm", False):
900
+ self.input_layernorm_summary = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
901
+ self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
902
+ self.attention_type = config.layer_types[layer_idx]
903
+
904
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
905
+ def forward(
906
+ self,
907
+ hidden_states: torch.Tensor,
908
+ attention_mask: Optional[torch.Tensor] = None,
909
+ position_ids: Optional[torch.LongTensor] = None,
910
+ past_key_values: Optional[Cache] = None,
911
+ use_cache: Optional[bool] = False,
912
+ cache_position: Optional[torch.LongTensor] = None,
913
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
914
+ summary_ctx: Optional[SummaryBatchContext] = None,
915
+ **kwargs: Unpack[TransformersKwargs],
916
+ ) -> torch.Tensor:
917
+ residual = hidden_states
918
+ if getattr(self.config, "summary_independent_attention_layernorm", False):
919
+ summary_mask = summary_ctx.summary_mask
920
+ assert (summary_mask == summary_mask[0:1]).all(), \
921
+ "summary_mask must be identical across batch"
922
+ hidden_states = self.input_layernorm(hidden_states)
923
+ if summary_mask.any():
924
+ hidden_summary = residual[:, summary_mask[0].to(residual.device), :]
925
+ hidden_summary = self.input_layernorm_summary(hidden_summary)
926
+ hidden_states[:, summary_mask[0], :] = hidden_summary
927
+ else:
928
+ hidden_states = self.input_layernorm(hidden_states)
929
+
930
+ # Self Attention - pass summary_ctx if using summary attention
931
+ attn_kwargs = {
932
+ "hidden_states": hidden_states,
933
+ "attention_mask": attention_mask,
934
+ "position_ids": position_ids,
935
+ "past_key_values": past_key_values,
936
+ "use_cache": use_cache,
937
+ "cache_position": cache_position,
938
+ "position_embeddings": position_embeddings,
939
+ **kwargs,
940
+ }
941
+ if isinstance(self.self_attn, Qwen3SummaryAttention):
942
+ attn_kwargs["summary_ctx"] = summary_ctx
943
+
944
+ hidden_states, _ = self.self_attn(**attn_kwargs)
945
+ hidden_states = residual + hidden_states
946
+
947
+ # Fully Connected
948
+ residual = hidden_states
949
+ hidden_states = self.post_attention_layernorm(hidden_states)
950
+ hidden_states = self.mlp(hidden_states)
951
+ hidden_states = residual + hidden_states
952
+ return hidden_states
953
+
954
+
955
+ @auto_docstring
956
+ class Qwen3PreTrainedModel(PreTrainedModel):
957
+ config: Qwen3Config
958
+ base_model_prefix = "model"
959
+ supports_gradient_checkpointing = True
960
+ _no_split_modules = ["Qwen3DecoderLayer"]
961
+ _skip_keys_device_placement = ["past_key_values"]
962
+ _supports_flash_attn = True
963
+ _supports_sdpa = True
964
+ _supports_flex_attn = True
965
+
966
+ _can_compile_fullgraph = True
967
+ _supports_attention_backend = True
968
+ _can_record_outputs = {
969
+ "hidden_states": Qwen3DecoderLayer,
970
+ "attentions": Qwen3Attention,
971
+ }
972
+
973
+
974
+ class Qwen3RotaryEmbedding(nn.Module):
975
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
976
+
977
+ def __init__(self, config: Qwen3Config, device=None):
978
+ super().__init__()
979
+ self.max_seq_len_cached = config.max_position_embeddings
980
+ self.original_max_seq_len = config.max_position_embeddings
981
+
982
+ self.config = config
983
+
984
+ self.rope_type = self.config.rope_parameters["rope_type"]
985
+ rope_init_fn: Callable = self.compute_default_rope_parameters
986
+ if self.rope_type != "default":
987
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
988
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
989
+
990
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
991
+ self.original_inv_freq = inv_freq
992
+
993
+ @staticmethod
994
+ def compute_default_rope_parameters(
995
+ config: Optional[Qwen3Config] = None,
996
+ device: Optional["torch.device"] = None,
997
+ seq_len: Optional[int] = None,
998
+ ) -> tuple["torch.Tensor", float]:
999
+ """
1000
+ Computes the inverse frequencies according to the original RoPE implementation
1001
+ Args:
1002
+ config ([`~transformers.PreTrainedConfig`]):
1003
+ The model configuration.
1004
+ device (`torch.device`):
1005
+ The device to use for initialization of the inverse frequencies.
1006
+ seq_len (`int`, *optional*):
1007
+ The current sequence length. Unused for this type of RoPE.
1008
+ Returns:
1009
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
1010
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
1011
+ """
1012
+ base = config.rope_parameters["rope_theta"]
1013
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
1014
+
1015
+ attention_factor = 1.0 # Unused in this type of RoPE
1016
+
1017
+ # Compute the inverse frequencies
1018
+ inv_freq = 1.0 / (
1019
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
1020
+ )
1021
+ return inv_freq, attention_factor
1022
+
1023
+ @torch.no_grad()
1024
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
1025
+ def forward(self, x, position_ids):
1026
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
1027
+ position_ids_expanded = position_ids[:, None, :].float()
1028
+
1029
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
1030
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
1031
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
1032
+ emb = torch.cat((freqs, freqs), dim=-1)
1033
+ cos = emb.cos() * self.attention_scaling
1034
+ sin = emb.sin() * self.attention_scaling
1035
+
1036
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
1037
+
1038
+
1039
+ @auto_docstring
1040
+ class Qwen3Model(Qwen3PreTrainedModel):
1041
+ def __init__(self, config: Qwen3Config):
1042
+ super().__init__(config)
1043
+ self.padding_idx = config.pad_token_id
1044
+ self.vocab_size = config.vocab_size
1045
+ if not getattr(config, "summary_layer_freq", False):
1046
+ if config.use_summary_attention:
1047
+ config.summary_layer_freq = [1]*config.num_hidden_layers
1048
+ else:
1049
+ config.summary_layer_freq = [0]*config.num_hidden_layers
1050
+ import warnings; warnings.warn(f'config.summary_layer_freq was not set; defaulting to {config.summary_layer_freq}')
1051
+ else:
1052
+ config.summary_layer_freq = _parse_config_pattern(config.summary_layer_freq)
1053
+
1054
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1055
+ self.layers = nn.ModuleList(
1056
+ [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1057
+ )
1058
+ self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1059
+ self.rotary_emb = Qwen3RotaryEmbedding(config=config)
1060
+ self.gradient_checkpointing = False
1061
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
1062
+
1063
+ # Cache per-layer sliding_chunk_nums for KV cache eviction
1064
+ _sv = _parse_config_pattern(getattr(config, "summary_sliding_chunk_num", 0) or 0)
1065
+ if isinstance(_sv, list):
1066
+ self._sliding_chunk_nums = [int(v) for v in _sv]
1067
+ else:
1068
+ self._sliding_chunk_nums = [int(_sv)] * config.num_hidden_layers
1069
+
1070
+ # Initialize weights and apply final processing
1071
+ self.post_init()
1072
+
1073
+ def _expand_input_with_summary_tokens(self, input_ids):
1074
+ """Expand input_ids with summary tokens for prefill phase (vectorized).
1075
+
1076
+ Returns:
1077
+ Tuple of (expanded_input_ids, position_ids, text_only_mask)
1078
+ """
1079
+ summary_chunk = self.config.summary_chunk_size
1080
+ summary_num = self.config.summary_token_num
1081
+ summary_begin = self.config.summary_token_begin
1082
+
1083
+ if summary_chunk == 0 or summary_num == 0:
1084
+ return input_ids, None, None
1085
+
1086
+ batch_size, seq_len = input_ids.shape
1087
+ device = input_ids.device
1088
+ dtype = input_ids.dtype
1089
+ block = summary_chunk + summary_num
1090
+
1091
+ # Number of full chunks and remainder
1092
+ n_full_chunks = seq_len // summary_chunk
1093
+ remainder = seq_len % summary_chunk
1094
+ has_remainder = remainder > 0
1095
+
1096
+ # Total expanded length: full_chunks * block + remainder
1097
+ expanded_len = n_full_chunks * block + (remainder if has_remainder else 0)
1098
+
1099
+ # --- Build expanded_input_ids ---
1100
+ expanded_ids = torch.empty((batch_size, expanded_len), dtype=dtype, device=device)
1101
+ text_only_mask = torch.zeros((batch_size, expanded_len), dtype=torch.bool, device=device)
1102
+
1103
+ # Compute text positions: for chunk i, text goes to [i*block, i*block+summary_chunk)
1104
+ # Summary positions: [i*block+summary_chunk, (i+1)*block)
1105
+ if n_full_chunks > 0:
1106
+ chunk_indices = torch.arange(n_full_chunks, device=device)
1107
+ # Text source positions in original input_ids
1108
+ text_src_offsets = (chunk_indices * summary_chunk).unsqueeze(1) + torch.arange(summary_chunk, device=device).unsqueeze(0) # [n_full_chunks, summary_chunk]
1109
+ # Text dest positions in expanded
1110
+ text_dst_offsets = (chunk_indices * block).unsqueeze(1) + torch.arange(summary_chunk, device=device).unsqueeze(0) # [n_full_chunks, summary_chunk]
1111
+ # Summary dest positions
1112
+ summary_dst_offsets = (chunk_indices * block + summary_chunk).unsqueeze(1) + torch.arange(summary_num, device=device).unsqueeze(0) # [n_full_chunks, summary_num]
1113
+
1114
+ text_src_flat = text_src_offsets.reshape(-1)
1115
+ text_dst_flat = text_dst_offsets.reshape(-1)
1116
+ summary_dst_flat = summary_dst_offsets.reshape(-1)
1117
+
1118
+ # Copy text tokens
1119
+ expanded_ids[:, text_dst_flat] = input_ids[:, text_src_flat]
1120
+ text_only_mask[:, text_dst_flat] = True
1121
+
1122
+ # Fill summary tokens
1123
+ summary_ids_val = torch.arange(summary_num, device=device, dtype=dtype) + summary_begin
1124
+ expanded_ids[:, summary_dst_flat] = summary_ids_val.repeat(n_full_chunks).unsqueeze(0).expand(batch_size, -1)
1125
+
1126
+ # Handle remainder (last partial chunk, no summary tokens)
1127
+ if has_remainder:
1128
+ rem_start_src = n_full_chunks * summary_chunk
1129
+ rem_start_dst = n_full_chunks * block
1130
+ rem_offsets = torch.arange(remainder, device=device)
1131
+ expanded_ids[:, rem_start_dst + rem_offsets] = input_ids[:, rem_start_src + rem_offsets]
1132
+ text_only_mask[:, rem_start_dst + rem_offsets] = True
1133
+
1134
+ # --- Build position_ids ---
1135
+ position_ids = torch.empty((batch_size, expanded_len), dtype=torch.long, device=device)
1136
+
1137
+ if n_full_chunks > 0:
1138
+ # Text position IDs
1139
+ if self.config.summary_chunk_position_ids_type == 'origin':
1140
+ text_pos = text_src_flat.unsqueeze(0).expand(batch_size, -1)
1141
+ elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
1142
+ inner_pos = torch.arange(summary_chunk, device=device).repeat(n_full_chunks)
1143
+ text_pos = inner_pos.unsqueeze(0).expand(batch_size, -1)
1144
+ else:
1145
+ raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
1146
+ position_ids[:, text_dst_flat] = text_pos
1147
+
1148
+ # Summary position IDs
1149
+ if self.config.summary_token_position_ids_type == 'zeros':
1150
+ position_ids[:, summary_dst_flat] = 0
1151
+ elif self.config.summary_token_position_ids_type in ('last_chunk_slice_left', 'last_chunk_slice_right'):
1152
+ # Vectorized slice_ends computation for all chunks at once
1153
+ if self.config.summary_token_position_ids_type == 'last_chunk_slice_left':
1154
+ idx = torch.arange(0, summary_num, device=device, dtype=torch.long)
1155
+ else:
1156
+ idx = torch.arange(1, summary_num + 1, device=device, dtype=torch.long)
1157
+ # For each chunk i: prev_text_end = i * summary_chunk
1158
+ prev_ends = (chunk_indices * summary_chunk).unsqueeze(1) # [n_full_chunks, 1]
1159
+ slice_ends = prev_ends + (idx.unsqueeze(0) * summary_chunk) // summary_num - 1 # [n_full_chunks, summary_num]
1160
+ slice_ends = slice_ends.clamp(min=0)
1161
+ # Clamp per-chunk: min is prev_text_end for that chunk
1162
+ slice_ends = torch.max(slice_ends, prev_ends)
1163
+ position_ids[:, summary_dst_flat] = slice_ends.reshape(-1).unsqueeze(0).expand(batch_size, -1)
1164
+ else:
1165
+ raise ValueError(f'Unknown summary_token_position_ids_type: {self.config.summary_token_position_ids_type}')
1166
+
1167
+ # Remainder position IDs
1168
+ if has_remainder:
1169
+ if self.config.summary_chunk_position_ids_type == 'origin':
1170
+ rem_pos = (rem_start_src + rem_offsets).unsqueeze(0).expand(batch_size, -1)
1171
+ elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
1172
+ rem_pos = rem_offsets.unsqueeze(0).expand(batch_size, -1)
1173
+ else:
1174
+ raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
1175
+ position_ids[:, rem_start_dst + rem_offsets] = rem_pos
1176
+
1177
+ return expanded_ids, position_ids, text_only_mask
1178
+
1179
+ def _build_summary_context(self, input_ids, position_ids, is_prefill, use_cache):
1180
+ """Build summary context for attention layers."""
1181
+ summary_chunk = self.config.summary_chunk_size
1182
+ summary_num = self.config.summary_token_num
1183
+ summary_begin = self.config.summary_token_begin
1184
+
1185
+ if summary_chunk > 0 and summary_num > 0:
1186
+ return build_summary_sliding_context(
1187
+ input_ids=input_ids,
1188
+ position_ids=position_ids,
1189
+ summary_token_num=summary_num,
1190
+ summary_token_begin=summary_begin,
1191
+ )
1192
+ return None
1193
+
1194
+ def _filter_summary_tokens(self, hidden_states, text_only_mask, use_summary, is_decode):
1195
+ """Filter out summary tokens from output hidden states."""
1196
+ if text_only_mask is not None:
1197
+ # Prefill: vectorized filtering using boolean mask
1198
+ batch_size, _, hidden_size = hidden_states.shape
1199
+ text_length = text_only_mask[0].sum().item()
1200
+ return hidden_states[text_only_mask.to(hidden_states.device)].reshape(batch_size, text_length, hidden_size)
1201
+ elif use_summary and is_decode and hidden_states.size(1) > 1:
1202
+ # Decode: if we have multiple tokens, only return the first (text token)
1203
+ return hidden_states[:, :1, :]
1204
+ return hidden_states
1205
+
1206
+ @check_model_inputs()
1207
+ @auto_docstring
1208
+ def forward(
1209
+ self,
1210
+ input_ids: Optional[torch.LongTensor] = None,
1211
+ attention_mask: Optional[torch.Tensor] = None,
1212
+ position_ids: Optional[torch.LongTensor] = None,
1213
+ past_key_values: Optional[Cache] = None,
1214
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1215
+ use_cache: Optional[bool] = None,
1216
+ cache_position: Optional[torch.LongTensor] = None,
1217
+ summary_ctx: Optional[SummaryBatchContext] = None,
1218
+ **kwargs: Unpack[TransformersKwargs],
1219
+ ) -> BaseModelOutputWithPast:
1220
+ if (input_ids is None) ^ (inputs_embeds is not None):
1221
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1222
+ use_summary = getattr(self.config, "use_summary_attention", False)
1223
+ is_prefill = past_key_values is None or past_key_values.get_seq_length() == 0
1224
+
1225
+ # Prefill phase with summary attention: expand input_ids with summary tokens
1226
+ text_only_mask = None
1227
+ if use_summary and input_ids is not None and inputs_embeds is None and is_prefill:
1228
+ input_ids, position_ids, text_only_mask = self._expand_input_with_summary_tokens(input_ids)
1229
+
1230
+ if inputs_embeds is None:
1231
+ inputs_embeds = self.embed_tokens(input_ids)
1232
+
1233
+ # Initialize cache
1234
+ if use_cache and past_key_values is None:
1235
+ if use_summary:
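+ # Summary attention uses a ring-buffer KV cache whose per-layer capacity comes from self._sliding_chunk_nums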
1236
+ past_key_values = Qwen3RingBufferCache(
1237
+ config=self.config, sliding_chunk_nums=self._sliding_chunk_nums)
1238
+ else:
1239
+ past_key_values = DynamicCache(config=self.config)
1240
+
1241
+ if cache_position is None:
1242
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1243
+ cache_position = torch.arange(
1244
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1245
+ )
1246
+
1247
+ if position_ids is None:
1248
+ position_ids = cache_position.unsqueeze(0)
1249
+
1250
+ # Build summary context if needed
1251
+ if use_summary and summary_ctx is None and input_ids is not None:
1252
+ summary_ctx = self._build_summary_context(input_ids, position_ids, is_prefill, use_cache)
1253
+
1254
+ causal_mask_mapping = attention_mask
1255
+ if not isinstance(causal_mask_mapping, (dict, list)):
1256
+ if summary_ctx and summary_ctx.enabled:
1257
+ seq_len = inputs_embeds.shape[1]
1258
+ # During prefill, Qwen3SummaryAttention uses summary_attn_func
1259
+ # which does not need a dense mask. Skip expensive mask construction.
1260
+ # During decode, prepare_inputs_for_generation already computed
1261
+ # per-layer keep_indices and passed them as attention_mask (list).
1262
+ # If we reach here with a non-list, it means no mask is needed.
1263
+ causal_mask_mapping = None
1264
+ else:
1265
+ # Prepare mask arguments
1266
+ mask_kwargs = {
1267
+ "config": self.config,
1268
+ "input_embeds": inputs_embeds,
1269
+ "attention_mask": attention_mask,
1270
+ "cache_position": cache_position,
1271
+ "past_key_values": past_key_values,
1272
+ "position_ids": position_ids,
1273
+ }
1274
+ # Create the standard causal masks (summary context is not enabled on this path)
1275
+ causal_mask_mapping = {
1276
+ "full_attention": create_causal_mask(**mask_kwargs),
1277
+ }
1278
+ # The sliding window alternating layers are not always activated depending on the config
1279
+ if self.has_sliding_layers:
1280
+ causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
1281
+
1282
+ hidden_states = inputs_embeds
1283
+
1284
+ # create position embeddings to be shared across the decoder layers
1285
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1286
+
1287
+ for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
1288
+ if causal_mask_mapping is None:
1289
+ layer_mask = None
1290
+ elif isinstance(causal_mask_mapping, list):
1291
+ layer_mask = causal_mask_mapping[layer_idx]
1292
+ else:
1293
+ layer_mask = causal_mask_mapping[decoder_layer.attention_type]
1294
+ hidden_states = decoder_layer(
1295
+ hidden_states,
1296
+ attention_mask=layer_mask,
1297
+ position_ids=position_ids,
1298
+ past_key_values=past_key_values,
1299
+ use_cache=use_cache,
1300
+ cache_position=cache_position,
1301
+ position_embeddings=position_embeddings,
1302
+ summary_ctx=summary_ctx,
1303
+ **kwargs,
1304
+ )
1305
+
1306
+ hidden_states = self.norm(hidden_states)
1307
+
1308
+ # After prefill: reorganize cache to ring buffer layout
1309
+ if use_cache and use_summary and past_key_values is not None and is_prefill:
1310
+ if hasattr(past_key_values, 'reorganize_after_prefill') and summary_ctx is not None:
1311
+ past_key_values.reorganize_after_prefill(summary_ctx.summary_mask)
1312
+
1313
+ # After chunk boundary decode: reset chunk counters
1314
+ if use_cache and use_summary and past_key_values is not None and not is_prefill:
1315
+ if hasattr(past_key_values, 'reset_chunk_counter'):
1316
+ past_key_values.reset_chunk_counter()
1317
+
1318
+ # Filter out summary tokens from output
1319
+ hidden_states = self._filter_summary_tokens(hidden_states, text_only_mask, use_summary,
1320
+ past_key_values is not None and past_key_values.get_seq_length() > 0)
1321
+
1322
+ return BaseModelOutputWithPast(
1323
+ last_hidden_state=hidden_states,
1324
+ past_key_values=past_key_values if use_cache else None,
1325
+ )
1326
+
1327
+
1328
+ @auto_docstring
1329
+ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
1330
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
1331
+ _tp_plan = {"lm_head": "colwise_rep"}
1332
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
1333
+
1334
+ def __init__(self, config):
1335
+ super().__init__(config)
1336
+ self.model = Qwen3Model(config)
1337
+ self.vocab_size = config.vocab_size
1338
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1339
+
1340
+ # Initialize weights and apply final processing
1341
+ self.post_init()
1342
+
1343
+ @can_return_tuple
1344
+ @auto_docstring
1345
+ def forward(
1346
+ self,
1347
+ input_ids: Optional[torch.LongTensor] = None,
1348
+ attention_mask: Optional[torch.Tensor] = None,
1349
+ position_ids: Optional[torch.LongTensor] = None,
1350
+ past_key_values: Optional[Cache] = None,
1351
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1352
+ labels: Optional[torch.LongTensor] = None,
1353
+ use_cache: Optional[bool] = None,
1354
+ cache_position: Optional[torch.LongTensor] = None,
1355
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1356
+ summary_ctx: Optional[SummaryBatchContext] = None,
1357
+ **kwargs: Unpack[TransformersKwargs],
1358
+ ) -> CausalLMOutputWithPast:
1359
+ r"""
1360
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1361
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1362
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1363
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1364
+
1365
+ Example:
1366
+
1367
+ ```python
1368
+ >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
1369
+
1370
+ >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
1371
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
1372
+
1373
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1374
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1375
+
1376
+ >>> # Generate
1377
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1378
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1379
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1380
+ ```"""
1381
+ outputs: BaseModelOutputWithPast = self.model(
1382
+ input_ids=input_ids,
1383
+ attention_mask=attention_mask,
1384
+ position_ids=position_ids,
1385
+ past_key_values=past_key_values,
1386
+ inputs_embeds=inputs_embeds,
1387
+ use_cache=use_cache,
1388
+ cache_position=cache_position,
1389
+ summary_ctx=summary_ctx,
1390
+ **kwargs,
1391
+ )
1392
+
1393
+ hidden_states = outputs.last_hidden_state
1394
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1395
+ if isinstance(logits_to_keep, int) and logits_to_keep == 0 and labels is None:
1396
+ # Inference: only need last token's logits to avoid OOM from [seq_len, vocab_size]
1397
+ logits = self.lm_head(hidden_states[:, -1:, :])
1398
+ else:
1399
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1400
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1401
+
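+ # Optionally truncate logits to the first truncate_predict_nums entries (defaults to the base text vocab size), e.g. to keep sampled tokens out of the summary-token id range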
1402
+ truncate_n = getattr(self.config, "truncate_predict_nums", 151936)
1403
+ if truncate_n > 0:
1404
+ logits = logits[..., :truncate_n]
1405
+
1406
+ loss = None
1407
+ if labels is not None:
1408
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=logits.shape[-1], **kwargs)
1409
+
1410
+ return CausalLMOutputWithPast(
1411
+ loss=loss,
1412
+ logits=logits,
1413
+ past_key_values=outputs.past_key_values,
1414
+ hidden_states=outputs.hidden_states,
1415
+ attentions=outputs.attentions,
1416
+ )
1417
+
1418
+ def _build_summary_attention_mask_for_generation(
1419
+ self,
1420
+ *,
1421
+ input_ids: torch.LongTensor,
1422
+ past_key_values: Optional[Cache],
1423
+ attention_mask: Optional[torch.Tensor],
1424
+ ) -> Optional[torch.Tensor]:
1425
+ """Ring buffer cache handles attention internally — no mask needed for decode."""
1426
+ if isinstance(past_key_values, Qwen3RingBufferCache):
1427
+ return None
1428
+ return attention_mask
1429
+
1430
+ def prepare_inputs_for_generation(
1431
+ self,
1432
+ input_ids: torch.LongTensor,
1433
+ past_key_values: Optional[Cache] = None,
1434
+ attention_mask: Optional[torch.LongTensor] = None,
1435
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1436
+ cache_position: Optional[torch.LongTensor] = None,
1437
+ position_ids: Optional[torch.LongTensor] = None,
1438
+ **kwargs,
1439
+ ):
1440
+ use_summary = getattr(self.config, "use_summary_attention", False)
1441
+
1442
+ # If not using summary attention, use standard behavior
1443
+ if not use_summary:
1444
+ return super().prepare_inputs_for_generation(
1445
+ input_ids=input_ids,
1446
+ past_key_values=past_key_values,
1447
+ attention_mask=attention_mask,
1448
+ inputs_embeds=inputs_embeds,
1449
+ cache_position=cache_position,
1450
+ position_ids=position_ids,
1451
+ **kwargs,
1452
+ )
1453
+
1454
+ # For summary attention: handle cache-based input slicing
1455
+ summary_chunk_size = getattr(self.config, "summary_chunk_size", 0)
1456
+ summary_token_num = getattr(self.config, "summary_token_num", 0)
1457
+ summary_token_begin = getattr(self.config, "summary_token_begin", 0)
1458
+
1459
+ # Prefill phase: pass full sequence, forward() will handle summary token insertion
1460
+ if past_key_values is None or past_key_values.get_seq_length() == 0:
1461
+ if cache_position is None:
1462
+ cache_position = torch.arange(0, input_ids.shape[1], device=input_ids.device)
1463
+
1464
+ return {
1465
+ "input_ids": input_ids,
1466
+ "attention_mask": attention_mask,
1467
+ "position_ids": position_ids,
1468
+ "past_key_values": past_key_values,
1469
+ "cache_position": cache_position,
1470
+ "use_cache": kwargs.get("use_cache"),
1471
+ }
1472
+
1473
+ # Decode phase: only pass new tokens not in cache
1474
+ # Get current chunk size (number of text tokens in current chunk)
1475
+ cur_chunk = past_key_values.get_cur_chunk_size() if hasattr(past_key_values, "get_cur_chunk_size") else 0
1476
+ true_token_num = past_key_values.get_true_token_num()
1477
+
1478
+ # Only take the new tokens that haven't been processed
1479
+ if input_ids.shape[1] > 1:
1480
+ # Slice to get only new tokens
1481
+ new_token_count = input_ids.shape[1] - true_token_num
1482
+ assert new_token_count > 0, f'new_token_count={new_token_count} should be greater than 0'
1483
+ input_ids = input_ids[:, -new_token_count:]
1484
+ device = input_ids.device
1485
+ # Check if we need to insert summary tokens
1486
+ # When this token completes the current chunk (cur_chunk == summary_chunk_size - 1), append summary tokens after it
1487
+ if cur_chunk == summary_chunk_size - 1:
1488
+ # Insert summary tokens
1489
+ batch_size = input_ids.shape[0]
1490
+ summary_ids = (
1491
+ torch.arange(summary_token_num, device=device, dtype=input_ids.dtype)
1492
+ + summary_token_begin
1493
+ ).unsqueeze(0).repeat(batch_size, 1)
1494
+
1495
+ # Concatenate: [text_token, summary_tokens]
1496
+ input_ids = torch.cat([input_ids, summary_ids], dim=1)
1497
+
1498
+ # Position IDs: the text token follows summary_chunk_position_ids_type; summary tokens follow summary_token_position_ids_type
1499
+ if self.config.summary_chunk_position_ids_type == 'origin':
1500
+ text_pos = torch.full((batch_size, 1), past_key_values.get_true_token_num(), device=device, dtype=torch.long)
1501
+ elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
1502
+ text_pos = torch.full((batch_size, 1), cur_chunk, device=device, dtype=torch.long)
1503
+ else:
1504
+ raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
1505
+
1506
+ if self.config.summary_token_position_ids_type == 'zeros':
1507
+ summary_pos = torch.zeros((batch_size, summary_token_num), device=device, dtype=torch.long)
1508
+ elif self.config.summary_token_position_ids_type == 'last_chunk_slice_left':
1509
+ # Split the chunk into summary_token_num equal slices; each summary token takes the position at the end of its slice
1510
+ prev_text_end = past_key_values.get_true_token_num() + 1 - summary_chunk_size
1511
+ cur_text_end = past_key_values.get_true_token_num() + 1
1512
+ chunk_len = cur_text_end - prev_text_end
1513
+
1514
+ idx = torch.arange(0, summary_token_num, device=device, dtype=torch.long)
1515
+
1516
+ # End of each slice (global position)
1517
+ slice_ends = prev_text_end + (idx * chunk_len) // summary_token_num - 1
1518
+ slice_ends = slice_ends.clamp(min=prev_text_end)
1519
+
1520
+ summary_pos = slice_ends.to(dtype=torch.long, device=device).unsqueeze(0)
1521
+ elif self.config.summary_token_position_ids_type == 'last_chunk_slice_right':
1522
+ # Split the chunk into summary_token_num equal slices; each summary token takes the position at the end of its slice
1523
+ prev_text_end = past_key_values.get_true_token_num() + 1 - summary_chunk_size
1524
+ cur_text_end = past_key_values.get_true_token_num() + 1
1525
+ chunk_len = cur_text_end - prev_text_end
1526
+
1527
+ idx = torch.arange(1, summary_token_num + 1, device=device, dtype=torch.long)
1528
+
1529
+ # End of each slice (global position)
1530
+ slice_ends = prev_text_end + (idx * chunk_len) // summary_token_num - 1
1531
+ slice_ends = slice_ends.clamp(min=prev_text_end)
1532
+
1533
+ summary_pos = slice_ends.to(dtype=torch.long, device=device).unsqueeze(0)
1534
+
1535
+ else:
1536
+ raise ValueError(f'Unknown summary_token_position_ids_type: {self.config.summary_token_position_ids_type}')
1537
+
1538
+ position_ids = torch.cat([text_pos, summary_pos], dim=1)
1539
+ else:
1540
+ # Normal decode: just the new text token with position = cur_chunk
1541
+ if position_ids is None:
1542
+ batch_size = input_ids.shape[0]
1543
+ if self.config.summary_chunk_position_ids_type == 'origin':
1544
+ position_ids = torch.full((batch_size, input_ids.shape[1]), past_key_values.get_true_token_num(), device=input_ids.device, dtype=torch.long)
1545
+ elif self.config.summary_chunk_position_ids_type == 'inner_chunk':
1546
+ position_ids = torch.full((batch_size, input_ids.shape[1]), cur_chunk, device=input_ids.device, dtype=torch.long)
1547
+ else:
1548
+ raise ValueError(f'Check config.summary_chunk_position_ids_type: {self.config.summary_chunk_position_ids_type}')
1549
+ return {
1550
+ "input_ids": input_ids,
1551
+ "attention_mask": self._build_summary_attention_mask_for_generation(
1552
+ input_ids=input_ids,
1553
+ past_key_values=past_key_values,
1554
+ attention_mask=attention_mask,
1555
+ ),
1556
+ "position_ids": position_ids,
1557
+ "past_key_values": past_key_values,
1558
+ "cache_position": cache_position,
1559
+ "use_cache": kwargs.get("use_cache"),
1560
+ }
1561
+
1562
+
1563
+ class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel):
1564
+ pass
1565
+
1566
+
1567
+ class Qwen3ForTokenClassification(GenericForTokenClassification, Qwen3PreTrainedModel):
1568
+ pass
1569
+
1570
+
1571
+ class Qwen3ForQuestionAnswering(GenericForQuestionAnswering, Qwen3PreTrainedModel):
1572
+ base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
1573
+
1574
+
1575
+ __all__ = [
1576
+ "Qwen3ForCausalLM",
1577
+ "Qwen3ForQuestionAnswering",
1578
+ "Qwen3PreTrainedModel",
1579
+ "Qwen3Model",
1580
+ "Qwen3ForSequenceClassification",
1581
+ "Qwen3ForTokenClassification",
1582
+ "Qwen3RingBufferCache",
1583
+ "Qwen3SummaryAttention",
1584
+ "SummaryBatchContext",
1585
+ "build_summary_context",
1586
+ "build_summary_sliding_context",
1587
+ ]
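For reference, a minimal loading sketch for a converted snapshot that ships this modeling code; the local path below is hypothetical, and `trust_remote_code=True` is needed so the `auto_map` entries in `config.json` resolve to `modeling_qwen3sa.py`.

```python
# Minimal sketch; the snapshot directory is a hypothetical example path.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

snapshot_dir = "/path/to/converted/hf"  # hypothetical local snapshot
tokenizer = AutoTokenizer.from_pretrained(snapshot_dir)
model = AutoModelForCausalLM.from_pretrained(
    snapshot_dir,
    trust_remote_code=True,      # imports modeling_qwen3sa.py via auto_map
    torch_dtype=torch.bfloat16,
)

inputs = tokenizer("Hello", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```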
summary_context.py ADDED
@@ -0,0 +1,133 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List
5
+
6
+ import torch
7
+
8
+ __all__ = [
9
+ "SummaryChunkMeta",
10
+ "SummarySampleContext",
11
+ "SummaryBatchContext",
12
+ "build_summary_context",
13
+ "build_summary_sliding_context",
14
+ ]
15
+
16
+
17
+ @dataclass
18
+ class SummaryChunkMeta:
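+ """Index metadata for one chunk: positions of its text tokens, of its own summary tokens, and of all summary tokens from earlier chunks."""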
19
+ text_positions: torch.Tensor
20
+ summary_positions: torch.Tensor
21
+ prefix_summary_positions: torch.Tensor
22
+
23
+ @property
24
+ def window_positions(self) -> torch.Tensor:
25
+ if self.prefix_summary_positions.numel() == 0:
26
+ if self.summary_positions.numel() == 0:
27
+ return self.text_positions
28
+ return torch.cat((self.text_positions, self.summary_positions), dim=0)
29
+ if self.summary_positions.numel() == 0:
30
+ return torch.cat((self.prefix_summary_positions, self.text_positions), dim=0)
31
+ return torch.cat(
32
+ (self.prefix_summary_positions, self.text_positions, self.summary_positions),
33
+ dim=0,
34
+ )
35
+
36
+
37
+ @dataclass
38
+ class SummarySampleContext:
39
+ chunks: List[SummaryChunkMeta]
40
+
41
+
42
+ @dataclass
43
+ class SummaryBatchContext:
44
+ samples: List[SummarySampleContext]
45
+ position_ids: torch.Tensor
46
+ summary_mask: torch.Tensor
47
+
48
+ @property
49
+ def enabled(self) -> bool:
50
+ return self.summary_mask.numel() > 0
51
+
52
+
53
+ def build_summary_context(
54
+ input_ids: torch.Tensor,
55
+ position_ids: torch.Tensor,
56
+ summary_chunk_size: int,
57
+ summary_token_num: int,
58
+ summary_token_begin: int,
59
+ ) -> SummaryBatchContext:
60
+ """
61
+ Build SummaryBatchContext from already-expanded sequences: each chunk should
62
+ be text tokens (<= chunk_size) followed by summary_token_num summary tokens.
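+ Example with summary_chunk_size=4 and summary_token_num=1 (S = summary token):
+ [t0 t1 t2 t3 S | t4 t5 t6 t7 S | t8 t9]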
63
+ """
64
+ batch_size, seq_len = input_ids.shape
65
+ block_size = summary_chunk_size + summary_token_num
66
+
67
+ summary_mask = torch.zeros_like(input_ids, dtype=torch.bool)
68
+ samples: List[SummarySampleContext] = []
69
+
70
+ for b in range(batch_size):
71
+ chunks: List[SummaryChunkMeta] = []
72
+ prefix_summary_positions: List[torch.Tensor] = []
73
+ cursor = 0
74
+ while cursor < seq_len:
75
+ text_len = min(summary_chunk_size, seq_len - cursor)
76
+ if text_len <= 0:
77
+ break
78
+
79
+ text_positions = torch.arange(cursor, cursor + text_len, device=input_ids.device)
80
+ summary_start = cursor + text_len
81
+ summary_end = min(cursor + block_size, seq_len)
82
+
83
+ # Keep only true summary tokens (in case of ragged last block).
84
+ summary_positions = torch.arange(summary_start, summary_end, device=input_ids.device)
85
+ if summary_positions.numel() > 0:
86
+ summary_tokens = input_ids[b, summary_positions]
87
+ valid = (summary_tokens >= summary_token_begin) & (
88
+ summary_tokens < summary_token_begin + summary_token_num
89
+ )
90
+ summary_positions = summary_positions[valid]
91
+ if summary_positions.numel() > 0:
92
+ summary_mask[b, summary_positions] = True
93
+
94
+ prefix_tensor = (
95
+ torch.cat(prefix_summary_positions, dim=0)
96
+ if prefix_summary_positions
97
+ else torch.empty(0, device=input_ids.device, dtype=torch.long)
98
+ )
99
+
100
+ chunk_meta = SummaryChunkMeta(
101
+ text_positions=text_positions,
102
+ summary_positions=summary_positions,
103
+ prefix_summary_positions=prefix_tensor,
104
+ )
105
+ chunks.append(chunk_meta)
106
+ if summary_positions.numel() > 0:
107
+ prefix_summary_positions.append(summary_positions)
108
+
109
+ cursor += block_size
110
+
111
+ samples.append(SummarySampleContext(chunks=chunks))
112
+
113
+ return SummaryBatchContext(
114
+ samples=samples,
115
+ position_ids=position_ids,
116
+ summary_mask=summary_mask,
117
+ )
118
+
119
+
120
+ def build_summary_sliding_context(
121
+ input_ids: torch.Tensor,
122
+ position_ids: torch.Tensor,
123
+ summary_token_num: int,
124
+ summary_token_begin: int,
125
+ ) -> SummaryBatchContext:
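+ """Sliding-window variant: only mark which positions hold summary tokens; per-chunk metadata is left empty (samples=[])."""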
126
+ summary_mask = (input_ids >= summary_token_begin) & (
127
+ input_ids < summary_token_begin + summary_token_num
128
+ )
129
+ return SummaryBatchContext(
130
+ samples=[],
131
+ position_ids=position_ids,
132
+ summary_mask=summary_mask,
133
+ )
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and 
enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
231
+ "clean_up_tokenization_spaces": false,
232
+ "eos_token": "<|endoftext|>",
233
+ "errors": "replace",
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff